183 lines
4.7 KiB
Go
183 lines
4.7 KiB
Go
package api

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"

	"contentSecurityDemo/conf"
	"contentSecurityDemo/internal/model"

	"github.com/aliyun/alibaba-cloud-sdk-go/sdk/requests"
	"github.com/aliyun/alibaba-cloud-sdk-go/services/green"
)

// ImageScanner 图片内容安全扫描器
|
||
type ImageScanner struct {
|
||
Client *green.Client
|
||
Config *conf.Config
|
||
}
|
||
|
||
// NewImageScanner 创建图片扫描器
|
||
func NewImageScanner(config *conf.Config) (*ImageScanner, error) {
|
||
client, err := green.NewClientWithAccessKey(config.Region, config.AccessKeyID, config.AccessKeySecret)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建客户端失败: %v", err)
|
||
}
|
||
|
||
return &ImageScanner{
|
||
Client: client,
|
||
Config: config,
|
||
}, nil
|
||
}
|
||
|
||
// ScanImageByURL 通过URL扫描图片
|
||
func (s *ImageScanner) ScanImageByURL(imageURL string, dataID string) (*model.ImageScanResponse, error) {
|
||
request := requests.NewCommonRequest()
|
||
request.Method = "POST"
|
||
request.Scheme = "https"
|
||
request.Domain = s.Config.Endpoint
|
||
request.Version = "2018-05-09"
|
||
request.ApiName = "ImageSyncScan"
|
||
request.QueryParams["RegionId"] = s.Config.Region
|
||
|
||
scanRequest := model.ImageScanRequest{
|
||
Tasks: []model.ImageTask{
|
||
{
|
||
DataID: dataID,
|
||
URL: imageURL,
|
||
},
|
||
},
|
||
Scenes: []string{"porn", "terrorism", "ad", "live", "logo"},
|
||
}
|
||
|
||
requestBody, err := json.Marshal(scanRequest)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
request.Content = requestBody
|
||
|
||
response, err := s.Client.ProcessCommonRequest(request)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
|
||
var scanResponse model.ImageScanResponse
|
||
if err := json.Unmarshal(response.GetHttpContentBytes(), &scanResponse); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
return &scanResponse, nil
|
||
}
|
||
|
||
// ScanImageByFile 通过文件扫描图片
|
||
func (s *ImageScanner) ScanImageByFile(filePath string, dataID string) (*model.ImageScanResponse, error) {
|
||
// 读取文件
|
||
file, err := os.Open(filePath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("打开文件失败: %v", err)
|
||
}
|
||
defer file.Close()
|
||
|
||
// 检查文件大小(限制为9MB)
|
||
fileInfo, err := file.Stat()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取文件信息失败: %v", err)
|
||
}
|
||
|
||
if fileInfo.Size() > 9*1024*1024 {
|
||
return nil, fmt.Errorf("文件大小超过9MB限制")
|
||
}
|
||
|
||
// 读取文件内容并转换为base64
|
||
fileContent, err := io.ReadAll(file)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取文件失败: %v", err)
|
||
}
|
||
|
||
// 检查文件类型
|
||
ext := strings.ToLower(filepath.Ext(filePath))
|
||
supportedExts := []string{".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}
|
||
isSupported := false
|
||
for _, supportedExt := range supportedExts {
|
||
if ext == supportedExt {
|
||
isSupported = true
|
||
break
|
||
}
|
||
}
|
||
|
||
if !isSupported {
|
||
return nil, fmt.Errorf("不支持的文件格式: %s", ext)
|
||
}
|
||
|
||
// 将文件内容转换为base64
|
||
base64Content := fmt.Sprintf("data:image/%s;base64,%s", ext[1:],
|
||
strings.ReplaceAll(string(fileContent), "\n", ""))
|
||
|
||
request := requests.NewCommonRequest()
|
||
request.Method = "POST"
|
||
request.Scheme = "https"
|
||
request.Domain = s.Config.Endpoint
|
||
request.Version = "2018-05-09"
|
||
request.ApiName = "ImageSyncScan"
|
||
request.QueryParams["RegionId"] = s.Config.Region
|
||
|
||
scanRequest := model.ImageScanRequest{
|
||
Tasks: []model.ImageTask{
|
||
{
|
||
DataID: dataID,
|
||
Content: base64Content,
|
||
},
|
||
},
|
||
Scenes: []string{"porn", "terrorism", "ad", "live", "logo"},
|
||
}
|
||
|
||
requestBody, err := json.Marshal(scanRequest)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
request.Content = requestBody
|
||
|
||
response, err := s.Client.ProcessCommonRequest(request)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
|
||
var scanResponse model.ImageScanResponse
|
||
if err := json.Unmarshal(response.GetHttpContentBytes(), &scanResponse); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
return &scanResponse, nil
|
||
}
|
||
|
||
// PrintResult 打印扫描结果
|
||
func (s *ImageScanner) PrintResult(response *model.ImageScanResponse) {
|
||
fmt.Println("=== 图片内容安全审核结果 ===")
|
||
fmt.Printf("状态码: %d\n", response.Code)
|
||
fmt.Printf("消息: %s\n", response.Message)
|
||
|
||
for _, data := range response.Data {
|
||
fmt.Printf("\n数据ID: %s\n", data.DataID)
|
||
fmt.Printf("处理状态: %d - %s\n", data.Code, data.Message)
|
||
|
||
if len(data.Results) == 0 {
|
||
fmt.Println("✅ 未检测到违规内容")
|
||
continue
|
||
}
|
||
|
||
for _, result := range data.Results {
|
||
fmt.Printf("场景: %s\n", result.Scene)
|
||
fmt.Printf("标签: %s\n", result.Label)
|
||
fmt.Printf("建议: %s\n", result.Suggestion)
|
||
fmt.Printf("置信度: %.2f\n", result.Rate)
|
||
fmt.Println("---")
|
||
}
|
||
}
|
||
}
|