203 lines
5.4 KiB
Go
203 lines
5.4 KiB
Go
package api
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
|
||
"contentSecurityDemo/conf"
|
||
"contentSecurityDemo/internal/model"
|
||
|
||
"github.com/aliyun/alibaba-cloud-sdk-go/sdk/requests"
|
||
"github.com/aliyun/alibaba-cloud-sdk-go/services/green"
|
||
)
|
||
|
||
// TextScanner is the text content-security scanner (Content Security 2.0).
//
// Values are normally built by NewTextScanner. The scan methods use Client to
// send requests and read the endpoint and region from Config.
type TextScanner struct {
	// Client is the Alibaba Cloud Green SDK client that performs the requests.
	Client *green.Client

	// Config supplies the endpoint (request domain) and region for each call.
	Config *conf.Config
}
|
||
|
||
// TextServiceType names the text moderation service to invoke. Its string
// value is placed in the request's Services list by the scan methods.
type TextServiceType string

// NOTE(review): confirm these identifiers against the Content Security 2.0
// documented service list for text in the target region.
const (
	// TextBaselineCheckGlobal is the general baseline text check.
	TextBaselineCheckGlobal TextServiceType = "baselineCheck_global"

	// TextPostCheckByVLGlobal is the text moderation service that combines
	// large and small models.
	TextPostCheckByVLGlobal TextServiceType = "postTextCheckByVL_global"
)
|
||
|
||
// NewTextScanner 创建文本扫描器
|
||
func NewTextScanner(config *conf.Config) (*TextScanner, error) {
|
||
// 获取有效的访问凭证
|
||
accessKeyID, accessKeySecret, securityToken := config.GetEffectiveCredentials()
|
||
|
||
var client *green.Client
|
||
var err error
|
||
|
||
if securityToken != "" {
|
||
// 使用STS临时凭证
|
||
client, err = green.NewClientWithStsToken(config.Region, accessKeyID, accessKeySecret, securityToken)
|
||
} else {
|
||
// 使用直接凭证
|
||
client, err = green.NewClientWithAccessKey(config.Region, accessKeyID, accessKeySecret)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建客户端失败: %v", err)
|
||
}
|
||
|
||
return &TextScanner{
|
||
Client: client,
|
||
Config: config,
|
||
}, nil
|
||
}
|
||
|
||
// ScanText 扫描文本内容(2.0版本)
|
||
func (s *TextScanner) ScanText(content string, dataID string, serviceType TextServiceType) (*model.TextScanResponse, error) {
|
||
request := requests.NewCommonRequest()
|
||
request.Method = "POST"
|
||
request.Scheme = "https"
|
||
request.Domain = s.Config.Endpoint
|
||
request.Version = "2022-03-02" // 内容安全2.0文本审核API版本
|
||
request.ApiName = "TextScan"
|
||
request.QueryParams["RegionId"] = s.Config.Region
|
||
|
||
scanRequest := model.TextScanRequest{
|
||
Tasks: []model.TextTask{
|
||
{
|
||
DataID: dataID,
|
||
Content: content,
|
||
},
|
||
},
|
||
Services: []string{string(serviceType)}, // 使用服务类型而不是场景
|
||
}
|
||
|
||
requestBody, err := json.Marshal(scanRequest)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
request.Content = requestBody
|
||
|
||
response, err := s.Client.ProcessCommonRequest(request)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
|
||
var scanResponse model.TextScanResponse
|
||
if err := json.Unmarshal(response.GetHttpContentBytes(), &scanResponse); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
return &scanResponse, nil
|
||
}
|
||
|
||
// ScanTextBatch 批量扫描文本内容(2.0版本)
|
||
func (s *TextScanner) ScanTextBatch(texts []string, dataIDs []string, serviceType TextServiceType) (*model.TextScanResponse, error) {
|
||
if len(texts) != len(dataIDs) {
|
||
return nil, fmt.Errorf("文本数量和ID数量不匹配")
|
||
}
|
||
|
||
request := requests.NewCommonRequest()
|
||
request.Method = "POST"
|
||
request.Scheme = "https"
|
||
request.Domain = s.Config.Endpoint
|
||
request.Version = "2022-03-02" // 内容安全2.0文本审核API版本
|
||
request.ApiName = "TextScan"
|
||
request.QueryParams["RegionId"] = s.Config.Region
|
||
|
||
var tasks []model.TextTask
|
||
for i, text := range texts {
|
||
tasks = append(tasks, model.TextTask{
|
||
DataID: dataIDs[i],
|
||
Content: text,
|
||
})
|
||
}
|
||
|
||
scanRequest := model.TextScanRequest{
|
||
Tasks: tasks,
|
||
Services: []string{string(serviceType)}, // 使用服务类型而不是场景
|
||
}
|
||
|
||
requestBody, err := json.Marshal(scanRequest)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
request.Content = requestBody
|
||
|
||
response, err := s.Client.ProcessCommonRequest(request)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
|
||
var scanResponse model.TextScanResponse
|
||
if err := json.Unmarshal(response.GetHttpContentBytes(), &scanResponse); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
return &scanResponse, nil
|
||
}
|
||
|
||
// PrintResult 打印扫描结果(2.0版本)
|
||
func (s *TextScanner) PrintResult(response *model.TextScanResponse) {
|
||
fmt.Println("=== 文本内容安全审核结果(2.0版本)===")
|
||
fmt.Printf("状态码: %d\n", response.Code)
|
||
fmt.Printf("消息: %s\n", response.Message)
|
||
|
||
for _, data := range response.Data {
|
||
fmt.Printf("\n数据ID: %s\n", data.DataID)
|
||
fmt.Printf("处理状态: %d - %s\n", data.Code, data.Message)
|
||
|
||
if len(data.Results) == 0 {
|
||
fmt.Println("✅ 未检测到违规内容")
|
||
continue
|
||
}
|
||
|
||
for _, result := range data.Results {
|
||
fmt.Printf("场景: %s\n", result.Scene)
|
||
fmt.Printf("标签: %s\n", result.Label)
|
||
if result.SubLabel != "" {
|
||
fmt.Printf("子标签: %s\n", result.SubLabel)
|
||
}
|
||
fmt.Printf("建议: %s\n", result.Suggestion)
|
||
fmt.Printf("置信度: %.2f\n", result.Rate)
|
||
|
||
// 打印详细信息
|
||
if len(result.Details) > 0 {
|
||
fmt.Println("违规详情:")
|
||
for _, detail := range result.Details {
|
||
fmt.Printf(" 内容: %s\n", detail.Context.Context)
|
||
if len(detail.Context.Pos) > 0 {
|
||
fmt.Printf(" 位置: %v\n", detail.Context.Pos)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 打印扩展信息
|
||
if len(result.Extras) > 0 {
|
||
fmt.Println("扩展信息:")
|
||
for key, value := range result.Extras {
|
||
fmt.Printf(" %s: %v\n", key, value)
|
||
}
|
||
}
|
||
fmt.Println("---")
|
||
}
|
||
}
|
||
}
|
||
|
||
// suggestionTexts maps moderation suggestion codes to their Chinese display text.
var suggestionTexts = map[string]string{
	"pass":   "通过",
	"review": "需要人工审核",
	"block":  "拒绝",
}

// GetSuggestionText returns the display text for a moderation suggestion
// code: "pass", "review" and "block" are translated; any other input
// (including the empty string) is returned unchanged.
func GetSuggestionText(suggestion string) string {
	text, known := suggestionTexts[suggestion]
	if !known {
		return suggestion
	}
	return text
}
|