569 lines
15 KiB
Go
569 lines
15 KiB
Go
package imports
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"fonchain-fiee/pkg/config"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"os"
|
||
"path/filepath"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/fonchain_enterprise/utils/objstorage"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// TaskStatus is the lifecycle state of a single ImageTask.
type TaskStatus string

const (
	TaskPending    TaskStatus = "PENDING"    // queued, waiting to be processed
	TaskRunning    TaskStatus = "RUNNING"    // currently being processed
	TaskSuccessful TaskStatus = "SUCCESSFUL" // finished successfully
	TaskFailed     TaskStatus = "FAILED"     // finished with a failure
	TaskCancelled  TaskStatus = "CANCELLED"  // cancelled
	TaskUnknown    TaskStatus = "UNKNOWN"    // the task does not exist

	// TaskCanceled is the historical name of TaskUnknown. Despite its
	// spelling it does NOT mean "cancelled" (that is TaskCancelled); it is
	// kept only so existing callers keep compiling.
	//
	// Deprecated: use TaskUnknown.
	TaskCanceled = TaskUnknown
)
|
||
|
||
var (
|
||
batchProcessor *BatchProcessor
|
||
instanceMutex sync.Mutex
|
||
)
|
||
|
||
// Lifecycle states of a BatchProcessor (stored in BatchProcessor.status).
const (
	StatusIdle       = iota // idle: a new batch may be started
	StatusProcessing        // processing: only progress can be read
	StatusCompleted         // completed: results can be read
)

// StatusMap maps each processor status to its human-readable (Chinese) label.
var StatusMap = map[int]string{
	StatusIdle:       "空闲中",
	StatusProcessing: "处理中",
	StatusCompleted:  "已完成",
}
|
||
|
||
type BatchProcessor struct {
|
||
mu sync.RWMutex
|
||
tasks map[string]*ImageTask //任务
|
||
inProgress map[string]bool //是否成功
|
||
pollInterval time.Duration
|
||
status int // 当前状态
|
||
}
|
||
|
||
type ImageTask struct {
|
||
Data excelData
|
||
TaskID string
|
||
Status TaskStatus
|
||
StartTime time.Time
|
||
EndTime time.Time
|
||
RetryCount int
|
||
Error error
|
||
Title string //标题
|
||
Content string //内容
|
||
}
|
||
|
||
// GetBatchProcessorRead returns the shared processor, lazily creating an idle
// one when none exists yet. Unlike GetBatchProcessor it never replaces a
// processor that has reached StatusCompleted, so the last batch stays
// readable.
func GetBatchProcessorRead() *BatchProcessor {
	instanceMutex.Lock()
	defer instanceMutex.Unlock()

	if batchProcessor != nil {
		return batchProcessor
	}

	batchProcessor = &BatchProcessor{
		status:       StatusIdle,
		pollInterval: time.Second,
		inProgress:   make(map[string]bool),
		tasks:        make(map[string]*ImageTask),
	}
	return batchProcessor
}
|
||
|
||
// GetBatchProcessor returns the shared processor for starting or continuing a
// batch. A fresh idle processor is created when none exists yet or when the
// current one has reached StatusCompleted.
func GetBatchProcessor() *BatchProcessor {
	instanceMutex.Lock()
	defer instanceMutex.Unlock()

	// Read the status through GetStatus (which takes p.mu) rather than the raw
	// field: the polling goroutine writes it via SetStatus concurrently, so an
	// unsynchronized read is a data race.
	if batchProcessor == nil || batchProcessor.GetStatus() == StatusCompleted {
		batchProcessor = &BatchProcessor{
			tasks:        make(map[string]*ImageTask),
			inProgress:   make(map[string]bool),
			pollInterval: 1 * time.Second,
			status:       StatusIdle,
		}
	}
	return batchProcessor
}
|
||
|
||
func (p *BatchProcessor) SetStatus(status int) {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
p.status = status
|
||
|
||
}
|
||
|
||
// GetStatus 获取当前状态
|
||
func (p *BatchProcessor) GetStatus() int {
|
||
p.mu.RLock()
|
||
defer p.mu.RUnlock()
|
||
return p.status
|
||
}
|
||
|
||
// 提交任务
|
||
func (p *BatchProcessor) submitSingleTask(req excelData, num int) error {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
//title, content, err := p.generateTitleAndContent(req)
|
||
//if err != nil {
|
||
// return fmt.Errorf("生成标题和内容失败: %v", err)
|
||
//}
|
||
var taskId string
|
||
if req.PhotoUrl != "" {
|
||
taskId = strconv.Itoa(num)
|
||
} else {
|
||
id, err := p.generateImage(req)
|
||
if err != nil {
|
||
return fmt.Errorf("生成图片失败: %v", err)
|
||
}
|
||
taskId = id
|
||
}
|
||
|
||
task := &ImageTask{
|
||
//Title: title,
|
||
//Content: content,
|
||
Data: req,
|
||
TaskID: taskId,
|
||
Status: TaskPending,
|
||
StartTime: time.Now(),
|
||
}
|
||
p.tasks[task.TaskID] = task
|
||
p.inProgress[task.TaskID] = false
|
||
return nil
|
||
}
|
||
|
||
// generateTitleAndContent produces a title and body text for one row.
//
// When the row has a PhotoUrl, the AI generator derives both from the image
// (image-to-text) using the row's title/content requirements. Otherwise the
// title and content are generated separately from text (generateTitle and
// generateContent). Errors are wrapped with %w so callers can inspect the cause.
func (p *BatchProcessor) generateTitleAndContent(req excelData) (string, string, error) {
	if req.PhotoUrl != "" {
		title, content, err := NewAiGenerator().GenerateTitleAndContentFromImage(
			req.PhotoUrl,
			req.TitleRequire,
			req.ContentRequire,
		)
		if err != nil {
			return "", "", fmt.Errorf("图生文失败: %w", err)
		}
		return title, content, nil
	}

	// No image: text-to-text.
	title, err := p.generateTitle(req)
	if err != nil {
		return "", "", fmt.Errorf("生成标题失败: %w", err)
	}

	content, err := p.generateContent(req)
	if err != nil {
		return "", "", fmt.Errorf("生成内容失败: %w", err)
	}

	return title, content, nil
}
|
||
|
||
func (p *BatchProcessor) generateTitle(req excelData) (string, error) {
|
||
prompt := fmt.Sprintf("请根据以下要求生成一个标题:%s", req.TitleRequire)
|
||
if req.Desc != "" {
|
||
prompt += fmt.Sprintf("\n艺人简介:%s", req.Desc)
|
||
}
|
||
prompt += "\n请直接输出标题,不要包含任何其他文字。"
|
||
|
||
result, err := NewAiGenerator().GenerateTextSync(prompt)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if len(result.Output.Choices) == 0 {
|
||
return "", errors.New("AI未生成标题内容")
|
||
}
|
||
|
||
return strings.TrimSpace(result.Output.Choices[0].Message.Content), nil
|
||
}
|
||
|
||
func (p *BatchProcessor) generateContent(req excelData) (string, error) {
|
||
prompt := fmt.Sprintf("请根据以下要求生成内容:%s", req.ContentRequire)
|
||
if req.Desc != "" {
|
||
prompt += fmt.Sprintf("\n艺人简介:%s", req.Desc)
|
||
}
|
||
prompt += "\n请直接输出内容,不要包含任何其他文字。"
|
||
|
||
result, err := NewAiGenerator().GenerateTextSync(prompt)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if len(result.Output.Choices) == 0 {
|
||
return "", errors.New("AI未生成内容")
|
||
}
|
||
|
||
return strings.TrimSpace(result.Output.Choices[0].Message.Content), nil
|
||
}
|
||
// generateImage starts an asynchronous AI text-to-image job for the row and
// returns the job's task ID, which is later polled via GetTaskDetail.
//
// The prompt combines the row's photo requirement with the artist description
// when present. "1024*1024" and req.PhotoNum are passed to TextToImage —
// presumably image size and image count; confirm against TextToImage.
func (p *BatchProcessor) generateImage(req excelData) (string, error) {
	prompt := fmt.Sprintf("请根据以下要求生成内容:%s", req.PhotoRequire)
	if req.Desc != "" {
		prompt += fmt.Sprintf("\n艺人简介:%s", req.Desc)
	}

	result, err := NewAiGenerator().TextToImage(
		prompt,
		"1024*1024",
		req.PhotoNum,
	)
	if err != nil {
		return "", err
	}

	// An empty ID would be registered as a task and then polled with
	// GetTaskDetail("") forever: lookup errors never mark a task finished,
	// so the batch would never complete.
	if result.Output.TaskID == "" {
		return "", errors.New("文生图未返回任务ID")
	}
	return result.Output.TaskID, nil
}
|
||
|
||
func (p *BatchProcessor) GetTaskStatistics() (completed, pending, total int, completedTasks, failedTasks []*ImageTask) {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
total = len(p.tasks)
|
||
for _, task := range p.tasks {
|
||
if p.inProgress[task.TaskID] { //是否转换成功
|
||
completed++
|
||
if task.Status == TaskSuccessful && task.Error == nil { //转换成功 并且 发布成功
|
||
completedTasks = append(completedTasks, task)
|
||
} else if task.Status == TaskFailed || task.Error != nil { //转换失败 或者 发布失败
|
||
failedTasks = append(failedTasks, task)
|
||
}
|
||
} else {
|
||
pending++
|
||
}
|
||
}
|
||
|
||
return completed, pending, total, completedTasks, failedTasks
|
||
}
|
||
|
||
// func (p *BatchProcessor) StartPolling() {
|
||
// go func() {
|
||
// ticker := time.NewTicker(p.pollInterval)
|
||
// defer ticker.Stop()
|
||
//
|
||
// for range ticker.C {
|
||
// if p.IsAllCompleted() {
|
||
// p.SetStatus(StatusCompleted)
|
||
// zap.L().Info("所有任务已完成,停止轮询")
|
||
// ticker.Stop()
|
||
// break
|
||
// }
|
||
// for i, v := range p.inProgress {
|
||
// if !v {
|
||
// if err := p.UpdateTaskStatuses(i); err != nil {
|
||
// zap.L().Error("批量更新任务状态失败: %v", zap.Error(err))
|
||
// continue
|
||
// }
|
||
// }
|
||
// continue
|
||
// }
|
||
// }
|
||
// }()
|
||
// }
|
||
func (p *BatchProcessor) StartPolling() {
|
||
go func() {
|
||
ticker := time.NewTicker(p.pollInterval) // 1秒间隔
|
||
defer ticker.Stop()
|
||
|
||
// 令牌桶,控制每秒最多10个请求
|
||
tokenBucket := make(chan struct{}, 10)
|
||
|
||
// 每秒补充令牌
|
||
go func() {
|
||
refillTicker := time.NewTicker(time.Second)
|
||
defer refillTicker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-refillTicker.C:
|
||
// 每秒补充到10个令牌
|
||
for i := 0; i < 10-len(tokenBucket); i++ {
|
||
select {
|
||
case tokenBucket <- struct{}{}:
|
||
default:
|
||
// 桶已满,跳过
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
|
||
for range ticker.C {
|
||
if p.IsAllCompleted() {
|
||
p.SetStatus(StatusCompleted)
|
||
zap.L().Info("所有任务已完成,停止轮询")
|
||
ticker.Stop()
|
||
break
|
||
}
|
||
|
||
// 获取未完成的任务
|
||
incompleteTasks := p.getIncompleteTasks()
|
||
if len(incompleteTasks) == 0 {
|
||
continue
|
||
}
|
||
|
||
// 处理当前可用的任务(最多10个)
|
||
processedCount := 0
|
||
for _, taskID := range incompleteTasks {
|
||
if processedCount >= 10 {
|
||
break // 本秒已达到10个请求限制
|
||
}
|
||
|
||
select {
|
||
case <-tokenBucket:
|
||
// 获取到令牌,可以发送请求
|
||
processedCount++
|
||
go p.updateTaskWithToken(taskID, tokenBucket)
|
||
default:
|
||
// 没有令牌了,跳过
|
||
break
|
||
}
|
||
}
|
||
|
||
zap.L().Debug("本轮处理任务数量",
|
||
zap.Int("processed", processedCount),
|
||
zap.Int("remaining", len(incompleteTasks)-processedCount))
|
||
}
|
||
}()
|
||
}
|
||
|
||
// 使用令牌更新任务状态
|
||
func (p *BatchProcessor) updateTaskWithToken(taskID string, tokenBucket chan struct{}) {
|
||
defer func() {
|
||
// 任务完成后不返还令牌,由定时器统一补充
|
||
}()
|
||
|
||
if err := p.UpdateTaskStatuses(taskID); err != nil {
|
||
zap.L().Error("更新任务状态失败",
|
||
zap.String("task_id", taskID),
|
||
zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// 获取未完成的任务列表
|
||
func (p *BatchProcessor) getIncompleteTasks() []string {
|
||
p.mu.RLock()
|
||
defer p.mu.RUnlock()
|
||
|
||
var incomplete []string
|
||
for taskID, completed := range p.inProgress {
|
||
if !completed {
|
||
incomplete = append(incomplete, taskID)
|
||
}
|
||
}
|
||
return incomplete
|
||
}
|
||
func (p *BatchProcessor) IsAllCompleted() bool {
|
||
p.mu.RLock()
|
||
defer p.mu.RUnlock()
|
||
|
||
if len(p.inProgress) == 0 {
|
||
return true
|
||
}
|
||
|
||
// 检查是否所有任务都标记为完成
|
||
for _, completed := range p.inProgress {
|
||
if !completed {
|
||
return false
|
||
}
|
||
}
|
||
return true
|
||
}
|
||
|
||
func (p *BatchProcessor) UpdateTaskStatuses(taskId string) (err error) {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
if p.tasks[taskId].Data.PhotoUrl != "" {
|
||
//生成标题
|
||
title, content, err := p.generateTitleAndContent(p.tasks[taskId].Data)
|
||
if err != nil {
|
||
p.tasks[taskId].Status = TaskFailed
|
||
p.inProgress[taskId] = true
|
||
p.tasks[taskId].EndTime = time.Now()
|
||
return fmt.Errorf("生成标题和内容失败: %v", err)
|
||
}
|
||
p.tasks[taskId].Title = title
|
||
p.tasks[taskId].Content = content
|
||
|
||
if err = publishImage(publishImageReq{
|
||
ArtistName: p.tasks[taskId].Data.ArtistName,
|
||
SubNum: p.tasks[taskId].Data.SubNum,
|
||
Title: p.tasks[taskId].Title,
|
||
Content: p.tasks[taskId].Content,
|
||
GeneratePhotoUrl: []string{p.tasks[taskId].Data.PhotoUrl},
|
||
}); err != nil {
|
||
p.tasks[taskId].Error = err
|
||
p.tasks[taskId].Status = TaskFailed
|
||
p.inProgress[taskId] = true
|
||
p.tasks[taskId].EndTime = time.Now()
|
||
}
|
||
p.tasks[taskId].Status = TaskSuccessful
|
||
p.inProgress[taskId] = true
|
||
p.tasks[taskId].EndTime = time.Now()
|
||
return err
|
||
}
|
||
|
||
getTaskDetailRes, err := NewAiGenerator().GetTaskDetail(taskId)
|
||
if err != nil {
|
||
return fmt.Errorf("查看图片生成结果失败: %v", err)
|
||
}
|
||
// 更新本地任务状态
|
||
if localTask, exists := p.tasks[getTaskDetailRes.Output.TaskID]; exists {
|
||
switch getTaskDetailRes.Output.TaskStatus {
|
||
case "SUCCEEDED":
|
||
if localTask.Status != TaskSuccessful {
|
||
//生成标题
|
||
title, content, err := p.generateTitleAndContent(p.tasks[taskId].Data)
|
||
if err != nil {
|
||
zap.L().Debug("生成标题失败")
|
||
localTask.Status = TaskFailed
|
||
p.tasks[getTaskDetailRes.Output.TaskID].Error = err
|
||
p.inProgress[getTaskDetailRes.Output.TaskID] = true
|
||
p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now()
|
||
return fmt.Errorf("生成标题和内容失败: %v", err)
|
||
}
|
||
p.tasks[taskId].Title = title
|
||
p.tasks[taskId].Content = content
|
||
|
||
// 直接下载并上传到桶
|
||
localTask.EndTime = time.Now()
|
||
urls := make([]string, 0, len(getTaskDetailRes.Output.Results))
|
||
for _, v1 := range getTaskDetailRes.Output.Results {
|
||
urls = append(urls, v1.URL)
|
||
}
|
||
uploadedURLs, err := downloadAndUploadImages(urls)
|
||
if err != nil {
|
||
zap.L().Debug("图片上传失败")
|
||
localTask.Status = TaskFailed
|
||
p.tasks[getTaskDetailRes.Output.TaskID].Error = err
|
||
p.inProgress[getTaskDetailRes.Output.TaskID] = true
|
||
p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now()
|
||
localTask.Error = fmt.Errorf("图片上传失败: %v", err)
|
||
return err
|
||
}
|
||
var messages string
|
||
for _, v1 := range getTaskDetailRes.Output.Results {
|
||
messages += v1.Message
|
||
}
|
||
if err = publishImage(publishImageReq{
|
||
ArtistName: p.tasks[getTaskDetailRes.Output.TaskID].Data.ArtistName,
|
||
SubNum: p.tasks[getTaskDetailRes.Output.TaskID].Data.SubNum,
|
||
Title: p.tasks[getTaskDetailRes.Output.TaskID].Title,
|
||
Content: p.tasks[getTaskDetailRes.Output.TaskID].Content,
|
||
GeneratePhotoUrl: uploadedURLs,
|
||
}); err != nil {
|
||
localTask.Status = TaskFailed
|
||
p.tasks[getTaskDetailRes.Output.TaskID].Error = err
|
||
p.inProgress[getTaskDetailRes.Output.TaskID] = true
|
||
p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now()
|
||
}
|
||
localTask.Status = TaskSuccessful
|
||
p.inProgress[getTaskDetailRes.Output.TaskID] = true
|
||
p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now()
|
||
|
||
}
|
||
case "FAILED":
|
||
if localTask.Status != TaskFailed {
|
||
localTask.Status = TaskFailed
|
||
p.tasks[getTaskDetailRes.Output.TaskID].Error = errors.New("转换失败")
|
||
p.inProgress[getTaskDetailRes.Output.TaskID] = true
|
||
p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now()
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// downloadAndUploadImages copies each non-empty source URL into the bucket via
// downloadAndUploadToBucket and returns the resulting URLs in input order.
// Individual failures are logged and skipped. An error is returned only when
// nothing was uploaded.
func downloadAndUploadImages(urls []string) ([]string, error) {
	var uploaded []string

	for _, src := range urls {
		if src == "" {
			continue
		}
		if dst, err := downloadAndUploadToBucket(src); err != nil {
			log.Printf("图片上传失败 [%s]: %v", src, err)
		} else {
			uploaded = append(uploaded, dst)
			log.Printf("图片上传成功: %s -> %s", src, dst)
		}
	}

	if len(uploaded) == 0 {
		return nil, errors.New("所有图片上传失败")
	}
	return uploaded, nil
}
|
||
|
||
// downloadAndUploadToBucket downloads imageURL and re-uploads it to the
// configured OSS bucket, returning "<CdnHost>/<object key>".
//
// The image is staged in a private temporary directory that is removed on
// return. The object key combines a nanosecond timestamp with that
// directory's unique name. This way several images processed in the same
// second — or concurrently — never overwrite one another in the bucket
// (second-resolution names did exactly that).
func downloadAndUploadToBucket(imageURL string) (string, error) {
	// Bound the download so a stalled server cannot hang the caller forever.
	// A Client with a nil Transport still uses http.DefaultTransport's pool.
	client := &http.Client{Timeout: 60 * time.Second}

	// Per-call directory: a shared "tmp" removed by every call would delete
	// files other calls (or unrelated code) are still using.
	tempDir, err := os.MkdirTemp("", "aigc-image-")
	if err != nil {
		return "", fmt.Errorf("创建临时目录失败: %w", err)
	}
	defer os.RemoveAll(tempDir)

	fileName := fmt.Sprintf("%d-%s.jpg", time.Now().UnixNano(), filepath.Base(tempDir))
	imgPath := filepath.Join(tempDir, fileName)

	resp, err := client.Get(imageURL)
	if err != nil {
		return "", fmt.Errorf("下载图片失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("下载失败,状态码: %d", resp.StatusCode)
	}

	file, err := os.Create(imgPath)
	if err != nil {
		return "", fmt.Errorf("创建文件失败: %w", err)
	}
	log.Printf("文件创建在: %s", imgPath)

	// Close before reading back so all data is flushed; report either error.
	_, copyErr := io.Copy(file, resp.Body)
	closeErr := file.Close()
	if copyErr != nil {
		return "", fmt.Errorf("保存文件失败: %w", copyErr)
	}
	if closeErr != nil {
		return "", fmt.Errorf("保存文件失败: %w", closeErr)
	}

	fileBytes, err := os.ReadFile(imgPath)
	if err != nil {
		return "", fmt.Errorf("读取本地文件失败: %w", err)
	}

	ossClient, err := objstorage.NewOSS(
		os.Getenv(config.ConfigData.Oss.AccessKeyId),
		os.Getenv(config.ConfigData.Oss.AccessKeySecret),
		os.Getenv(config.ConfigData.Oss.Endpoint),
	)
	if err != nil {
		// Previously ignored; using the client after a failed construction
		// risks a nil dereference.
		return "", fmt.Errorf("创建OSS客户端失败: %w", err)
	}

	if _, err := ossClient.PutObjectFromBytes(os.Getenv(config.ConfigData.Oss.BucketName), fileName, fileBytes); err != nil {
		return "", fmt.Errorf("上传文件失败: %w", err)
	}

	return fmt.Sprintf("%s/%s", os.Getenv(config.ConfigData.Oss.CdnHost), fileName), nil
}
|