230 lines
6.3 KiB
Go
230 lines
6.3 KiB
Go
package api
|
||
|
||
import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"

	"contentSecurityDemo/conf"
	"contentSecurityDemo/internal/model"

	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/requests"
	"github.com/aliyun/alibaba-cloud-sdk-go/services/green"
)
|
||
|
||
// ImageScanner 图片内容安全扫描器2.0
|
||
type ImageScanner struct {
|
||
Client *green.Client
|
||
Config *conf.Config
|
||
}
|
||
|
||
// ImageServiceType identifies an image moderation service. Its string value
// is placed as-is into the Services list of a scan request.
type ImageServiceType string

// Image moderation services known to this package.
const (
	// BaselineCheckGlobal selects the general baseline check.
	BaselineCheckGlobal ImageServiceType = "baselineCheck_global"

	// PostImageCheckByVLGlobal selects the moderation service that fuses
	// large and small models.
	PostImageCheckByVLGlobal ImageServiceType = "postImageCheckByVL_global"

	// AigcDetectorGlobal selects detection of AI-generated images.
	AigcDetectorGlobal ImageServiceType = "aigcDetector_global"
)
|
||
|
||
// NewImageScanner 创建图片扫描器
|
||
func NewImageScanner(config *conf.Config) (*ImageScanner, error) {
|
||
// 获取有效的访问凭证
|
||
accessKeyID, accessKeySecret, securityToken := config.GetEffectiveCredentials()
|
||
|
||
var client *green.Client
|
||
var err error
|
||
|
||
if securityToken != "" {
|
||
// 使用STS临时凭证
|
||
client, err = green.NewClientWithStsToken(config.Region, accessKeyID, accessKeySecret, securityToken)
|
||
} else {
|
||
// 使用直接凭证
|
||
client, err = green.NewClientWithAccessKey(config.Region, accessKeyID, accessKeySecret)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建客户端失败: %v", err)
|
||
}
|
||
|
||
return &ImageScanner{
|
||
Client: client,
|
||
Config: config,
|
||
}, nil
|
||
}
|
||
|
||
// ScanImageByURL 通过URL扫描图片(2.0版本)
|
||
func (s *ImageScanner) ScanImageByURL(imageURL string, dataID string, serviceType ImageServiceType) (*model.ImageScanResponse, error) {
|
||
request := requests.NewCommonRequest()
|
||
request.Method = "POST"
|
||
request.Scheme = "https"
|
||
request.Domain = s.Config.Endpoint
|
||
request.Version = "2022-03-02" // 内容安全2.0图片审核API版本
|
||
request.ApiName = "ImageSyncScan"
|
||
request.QueryParams["RegionId"] = s.Config.Region
|
||
|
||
scanRequest := model.ImageScanRequest{
|
||
Tasks: []model.ImageTask{
|
||
{
|
||
DataID: dataID,
|
||
URL: imageURL,
|
||
},
|
||
},
|
||
Services: []string{string(serviceType)}, // 使用服务类型而不是场景
|
||
}
|
||
|
||
requestBody, err := json.Marshal(scanRequest)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
request.Content = requestBody
|
||
|
||
response, err := s.Client.ProcessCommonRequest(request)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
|
||
var scanResponse model.ImageScanResponse
|
||
if err := json.Unmarshal(response.GetHttpContentBytes(), &scanResponse); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
return &scanResponse, nil
|
||
}
|
||
|
||
// ScanImageByFile 通过文件扫描图片(2.0版本)
|
||
func (s *ImageScanner) ScanImageByFile(filePath string, dataID string, serviceType ImageServiceType) (*model.ImageScanResponse, error) {
|
||
// 读取文件
|
||
file, err := os.Open(filePath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("打开文件失败: %v", err)
|
||
}
|
||
defer file.Close()
|
||
|
||
// 检查文件大小(限制为9MB)
|
||
fileInfo, err := file.Stat()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取文件信息失败: %v", err)
|
||
}
|
||
|
||
if fileInfo.Size() > 9*1024*1024 {
|
||
return nil, fmt.Errorf("文件大小超过9MB限制")
|
||
}
|
||
|
||
// 读取文件内容并转换为base64
|
||
fileContent, err := io.ReadAll(file)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取文件失败: %v", err)
|
||
}
|
||
|
||
// 检查文件类型 - 2.0版本支持更多格式
|
||
ext := strings.ToLower(filepath.Ext(filePath))
|
||
supportedExts := []string{".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff", ".svg", ".ico", ".heic"}
|
||
isSupported := false
|
||
for _, supportedExt := range supportedExts {
|
||
if ext == supportedExt {
|
||
isSupported = true
|
||
break
|
||
}
|
||
}
|
||
|
||
if !isSupported {
|
||
return nil, fmt.Errorf("不支持的文件格式: %s", ext)
|
||
}
|
||
|
||
// 将文件内容转换为base64
|
||
base64Content := fmt.Sprintf("data:image/%s;base64,%s", ext[1:],
|
||
strings.ReplaceAll(string(fileContent), "\n", ""))
|
||
|
||
request := requests.NewCommonRequest()
|
||
request.Method = "POST"
|
||
request.Scheme = "https"
|
||
request.Domain = s.Config.Endpoint
|
||
request.Version = "2022-03-02" // 内容安全2.0图片审核API版本
|
||
request.ApiName = "ImageSyncScan"
|
||
request.QueryParams["RegionId"] = s.Config.Region
|
||
|
||
scanRequest := model.ImageScanRequest{
|
||
Tasks: []model.ImageTask{
|
||
{
|
||
DataID: dataID,
|
||
Content: base64Content,
|
||
},
|
||
},
|
||
Services: []string{string(serviceType)}, // 使用服务类型而不是场景
|
||
}
|
||
|
||
requestBody, err := json.Marshal(scanRequest)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
request.Content = requestBody
|
||
|
||
response, err := s.Client.ProcessCommonRequest(request)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
|
||
var scanResponse model.ImageScanResponse
|
||
if err := json.Unmarshal(response.GetHttpContentBytes(), &scanResponse); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
return &scanResponse, nil
|
||
}
|
||
|
||
// PrintResult 打印扫描结果(2.0版本)
|
||
func (s *ImageScanner) PrintResult(response *model.ImageScanResponse) {
|
||
fmt.Println("=== 图片内容安全审核结果(2.0版本)===")
|
||
fmt.Printf("状态码: %d\n", response.Code)
|
||
fmt.Printf("消息: %s\n", response.Message)
|
||
|
||
for _, data := range response.Data {
|
||
fmt.Printf("\n数据ID: %s\n", data.DataID)
|
||
fmt.Printf("处理状态: %d - %s\n", data.Code, data.Message)
|
||
|
||
if len(data.Results) == 0 {
|
||
fmt.Println("✅ 未检测到违规内容")
|
||
continue
|
||
}
|
||
|
||
for _, result := range data.Results {
|
||
fmt.Printf("场景: %s\n", result.Scene)
|
||
fmt.Printf("标签: %s\n", result.Label)
|
||
if result.SubLabel != "" {
|
||
fmt.Printf("子标签: %s\n", result.SubLabel)
|
||
}
|
||
fmt.Printf("建议: %s\n", result.Suggestion)
|
||
fmt.Printf("置信度: %.2f\n", result.Rate)
|
||
|
||
// 打印详细信息
|
||
if len(result.Details) > 0 {
|
||
fmt.Println("详细信息:")
|
||
for _, detail := range result.Details {
|
||
fmt.Printf(" 内容: %s\n", detail.Context.Context)
|
||
if len(detail.Context.Pos) > 0 {
|
||
fmt.Printf(" 位置: %v\n", detail.Context.Pos)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 打印扩展信息
|
||
if len(result.Extras) > 0 {
|
||
fmt.Println("扩展信息:")
|
||
for key, value := range result.Extras {
|
||
fmt.Printf(" %s: %v\n", key, value)
|
||
}
|
||
}
|
||
fmt.Println("---")
|
||
}
|
||
}
|
||
}
|