feat:ai生成图文

This commit is contained in:
bx1834938347-prog 2025-12-05 19:12:09 +08:00
parent a68502d957
commit dea013471a
7 changed files with 542 additions and 461 deletions

View File

@ -1,91 +1,92 @@
package imports package imports
import ( //
"bytes" //import (
"encoding/json" // "bytes"
"fmt" // "encoding/json"
"net/http" // "fmt"
) // "net/http"
//)
//------------------------------------------图生图 //
////------------------------------------------图生图
type Image2ImageRequest struct { //
Model string `json:"model"` //type Image2ImageRequest struct {
Input ImageInput `json:"input"` // Model string `json:"model"`
Params ImageParams `json:"parameters"` // Input ImageInput `json:"input"`
} // Params ImageParams `json:"parameters"`
//}
type ImageInput struct { //
Images []string `json:"images"` //type ImageInput struct {
Prompt string `json:"prompt"` // 可选的条件文本 // Images []string `json:"images"`
} // Prompt string `json:"prompt"` // 可选的条件文本
//}
type ImageParams struct { //
Size string `json:"size,omitempty"` // 输出尺寸 //type ImageParams struct {
Strength float64 `json:"strength"` // 重绘强度0-1 // Size string `json:"size,omitempty"` // 输出尺寸
N int `json:"n,omitempty"` // 生成数量 // Strength float64 `json:"strength"` // 重绘强度0-1
} // N int `json:"n,omitempty"` // 生成数量
//}
type ImageResponse struct { //
Output struct { //type ImageResponse struct {
TaskID string `json:"task_id"` // Output struct {
Results []struct { // TaskID string `json:"task_id"`
URL string `json:"url"` // Results []struct {
} `json:"results"` // URL string `json:"url"`
} `json:"output"` // } `json:"results"`
RequestID string `json:"request_id"` // } `json:"output"`
} // RequestID string `json:"request_id"`
//}
// Image2image 图生图 //
func (g *AiGenerator) Image2image(imagePath string, prompt string, strength float64, size string, n int) (*ImageResponse, error) { //// Image2image 图生图
if g.cfg.APIKey == "" { //func (g *AiGenerator) Image2image(imagePath string, prompt string, strength float64, size string, n int) (*ImageResponse, error) {
return nil, fmt.Errorf("API密钥未配置") // if g.cfg.APIKey == "" {
} // return nil, fmt.Errorf("API密钥未配置")
// }
// 构建请求 //
req := Image2ImageRequest{ // // 构建请求
Model: g.cfg.ImageModel, // req := Image2ImageRequest{
Input: ImageInput{ // Model: g.cfg.ImageModel,
Images: []string{imagePath}, // Input: ImageInput{
Prompt: prompt, // Images: []string{imagePath},
}, // Prompt: prompt,
Params: ImageParams{ // },
Size: size, // Params: ImageParams{
Strength: strength, // Size: size,
N: n, // Strength: strength,
}, // N: n,
} // },
// }
url := g.cfg.BaseURL + g.cfg.Image2ImageURL //
jsonData, err := json.Marshal(req) // url := g.cfg.BaseURL + g.cfg.Image2ImageURL
if err != nil { // jsonData, err := json.Marshal(req)
return nil, err // if err != nil {
} // return nil, err
// }
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) //
if err != nil { // httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
return nil, err // if err != nil {
} // return nil, err
// }
httpReq.Header.Set("Content-Type", "application/json") //
httpReq.Header.Set("Authorization", "Bearer "+g.cfg.APIKey) // httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("X-DashScope-Async", "enable") // httpReq.Header.Set("Authorization", "Bearer "+g.cfg.APIKey)
// httpReq.Header.Set("X-DashScope-Async", "enable")
resp, err := g.client.Do(httpReq) //
if err != nil { // resp, err := g.client.Do(httpReq)
return nil, err // if err != nil {
} // return nil, err
defer resp.Body.Close() // }
// defer resp.Body.Close()
// 解析响应 //
var result ImageResponse // // 解析响应
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { // var result ImageResponse
return nil, fmt.Errorf("响应解析失败: %v", err) // if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
} // return nil, fmt.Errorf("响应解析失败: %v", err)
// }
if resp.StatusCode != http.StatusOK { //
return nil, fmt.Errorf("API错误: %d", resp.StatusCode) // if resp.StatusCode != http.StatusOK {
} // return nil, fmt.Errorf("API错误: %d", resp.StatusCode)
// }
return &result, nil //
} // return &result, nil
//}

View File

@ -249,22 +249,76 @@ func (g *AiGenerator) chatWithText(prompt string) (string, error) {
return strings.TrimSpace(result.Choices[0].Message.Content), nil return strings.TrimSpace(result.Choices[0].Message.Content), nil
} }
func (g *AiGenerator) GenerateImageFromText(prompt, size string, n int) (string, error) { //func (g *AiGenerator) GenerateImageFromText(prompt, size string, n int) (string, error) {
// 构建图片生成提示词 // // 构建图片生成提示词
imagePrompt := fmt.Sprintf(`请根据以下描述生成图片 // imagePrompt := fmt.Sprintf(`请根据以下描述生成图片:
//
//图片描述:%s
//生成数量:%d张
//图片尺寸:%s
//
//请直接生成图片,不要返回任何文字描述。`,
// prompt, n, size)
//
// // 使用文生图API
// result, err := g.TextToImage(imagePrompt, size, n)
// if err != nil {
// return "", err
// }
//
// return result.Output.TaskID, nil
//}
图片描述%s // 文本生成图像
生成数量%d张 //func (g *AiGenerator) TextToImage(prompt, size string, n int) (ImageGenerationResponse, error) {
图片尺寸%s // // 构建图像生成请求
// reqBody := map[string]interface{}{
请直接生成图片不要返回任何文字描述`, // "prompt": prompt,
prompt, n, size) // "n": n,
// "size": size,
// 使用文生图API // "response_format": "url", // 假设返回的格式为图像 URL可以根据实际 API 调整
result, err := g.TextToImage(imagePrompt, size, n) // }
if err != nil { //
return "", err // // 使用图像生成接口
} // url := g.cfg.BaseURL + "/v1/images/generations"
// jsonData, err := json.Marshal(reqBody)
return result.Output.TaskID, nil // if err != nil {
} // return ImageGenerationResponse{}, fmt.Errorf("JSON序列化失败: %v", err)
// }
//
// req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
// if err != nil {
// return ImageGenerationResponse{}, fmt.Errorf("创建请求失败: %v", err)
// }
//
// req.Header.Set("Content-Type", "application/json")
// req.Header.Set("Authorization", "Bearer "+g.cfg.APIKey)
//
// resp, err := g.client.Do(req)
// if err != nil {
// return ImageGenerationResponse{}, fmt.Errorf("API请求失败: %v", err)
// }
// defer resp.Body.Close()
//
// // 读取响应体
// body, err := io.ReadAll(resp.Body)
// if err != nil {
// return ImageGenerationResponse{}, fmt.Errorf("读取响应失败: %v", err)
// }
//
// if resp.StatusCode != http.StatusOK {
// return ImageGenerationResponse{}, fmt.Errorf("API错误: %d, 响应: %s", resp.StatusCode, string(body))
// }
//
// // 解析图像生成响应
// var result ImageGenerationResponse
// if err := json.Unmarshal(body, &result); err != nil {
// return ImageGenerationResponse{}, fmt.Errorf("JSON解析失败: %v, 响应: %s", err, string(body))
// }
//
// if len(result.Data) == 0 {
// return ImageGenerationResponse{}, errors.New("未生成任何图像")
// }
//
// return result, nil
//}

View File

@ -1,99 +1,100 @@
package imports package imports
import ( //
"bytes" //import (
"encoding/json" // "bytes"
"fmt" // "encoding/json"
"net/http" // "fmt"
) // "net/http"
//)
// ----------------------------文生文 //
//// ----------------------------文生文
// 文本生成请求结构 //
type TextGenerationRequest struct { //// 文本生成请求结构
Model string `json:"model"` //type TextGenerationRequest struct {
Input TextInput `json:"input"` // Model string `json:"model"`
Params TextParams `json:"parameters"` // Input TextInput `json:"input"`
} // Params TextParams `json:"parameters"`
//}
type TextInput struct { //
Messages []Message `json:"messages"` //type TextInput struct {
} // Messages []Message `json:"messages"`
//}
type TextParams struct { //
ResultFormat string `json:"result_format,omitempty"` // 结果格式 //type TextParams struct {
MaxTokens int `json:"max_tokens,omitempty"` // 最大token数 // ResultFormat string `json:"result_format,omitempty"` // 结果格式
Temperature float64 `json:"temperature,omitempty"` // 温度参数 // MaxTokens int `json:"max_tokens,omitempty"` // 最大token数
TopP float64 `json:"top_p,omitempty"` // 核采样参数 // Temperature float64 `json:"temperature,omitempty"` // 温度参数
} // TopP float64 `json:"top_p,omitempty"` // 核采样参数
//}
type TextResponse struct { //
Output struct { //type TextResponse struct {
Text string `json:"text"` // Output struct {
FinishReason string `json:"finish_reason"` // Text string `json:"text"`
} `json:"output"` // FinishReason string `json:"finish_reason"`
Usage struct { // } `json:"output"`
InputTokens int `json:"input_tokens"` // Usage struct {
OutputTokens int `json:"output_tokens"` // InputTokens int `json:"input_tokens"`
TotalTokens int `json:"total_tokens"` // OutputTokens int `json:"output_tokens"`
} `json:"usage"` // TotalTokens int `json:"total_tokens"`
RequestID string `json:"request_id"` // } `json:"usage"`
} // RequestID string `json:"request_id"`
//}
// GenerateText 生成文本 //
func (g *AiGenerator) GenerateText(prompt string) (*TextResponse, error) { //// GenerateText 生成文本
if g.cfg.APIKey == "" { //func (g *AiGenerator) GenerateText(prompt string) (*TextResponse, error) {
return nil, fmt.Errorf("API密钥未配置") // if g.cfg.APIKey == "" {
} // return nil, fmt.Errorf("API密钥未配置")
// }
// 构建请求 //
req := TextGenerationRequest{ // // 构建请求
Model: g.cfg.TextModel, // req := TextGenerationRequest{
Input: TextInput{ // Model: g.cfg.TextModel,
Messages: []Message{ // Input: TextInput{
{ // Messages: []Message{
Role: "user", // {
Content: prompt, // Role: "user",
}, // Content: prompt,
}, // },
}, // },
Params: TextParams{ // },
ResultFormat: "message", // Params: TextParams{
MaxTokens: g.cfg.MaxTokens, // ResultFormat: "message",
Temperature: g.cfg.Temperature, // MaxTokens: g.cfg.MaxTokens,
TopP: g.cfg.TopP, // Temperature: g.cfg.Temperature,
}, // TopP: g.cfg.TopP,
} // },
// }
url := g.cfg.BaseURL + g.cfg.TextGenerationURL //
jsonData, err := json.Marshal(req) // url := g.cfg.BaseURL + g.cfg.TextGenerationURL
if err != nil { // jsonData, err := json.Marshal(req)
return nil, err // if err != nil {
} // return nil, err
// }
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) //
if err != nil { // httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
return nil, err // if err != nil {
} // return nil, err
// }
httpReq.Header.Set("Content-Type", "application/json") //
httpReq.Header.Set("Authorization", "Bearer "+g.cfg.APIKey) // httpReq.Header.Set("Content-Type", "application/json")
// httpReq.Header.Set("Authorization", "Bearer "+g.cfg.APIKey)
resp, err := g.client.Do(httpReq) //
if err != nil { // resp, err := g.client.Do(httpReq)
return nil, err // if err != nil {
} // return nil, err
defer resp.Body.Close() // }
// defer resp.Body.Close()
// 解析响应 //
var result TextResponse // // 解析响应
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { // var result TextResponse
return nil, fmt.Errorf("文本生成响应解析失败: %v", err) // if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
} // return nil, fmt.Errorf("文本生成响应解析失败: %v", err)
// }
if resp.StatusCode != http.StatusOK { //
return nil, fmt.Errorf("API错误: %d", resp.StatusCode) // if resp.StatusCode != http.StatusOK {
} // return nil, fmt.Errorf("API错误: %d", resp.StatusCode)
// }
return &result, nil //
} // return &result, nil
//}

View File

@ -110,13 +110,3 @@ func (g *AiGenerator) TextToImage(prompt string, size string, n int) (*TextToIma
return &result, nil return &result, nil
} }
func isValidSize(size string) bool {
validSizes := []string{"1024x1024", "720x1280", "1280x720", "512x512"}
for _, validSize := range validSizes {
if size == validSize {
return true
}
}
return false
}

View File

@ -1,115 +1,116 @@
package imports package imports
import ( //
"bytes" //import (
"encoding/json" // "bytes"
"fmt" // "encoding/json"
"io" // "fmt"
"net/http" // "io"
) // "net/http"
//)
// Message 结构体定义 //
type Message struct { //// Message 结构体定义
Role string `json:"role"` //type Message struct {
Content string `json:"content"` // Role string `json:"role"`
} // Content string `json:"content"`
//}
// 同步文本生成请求 //
type SyncTextGenerationRequest struct { //// 同步文本生成请求
Model string `json:"model"` //type SyncTextGenerationRequest struct {
Input SyncTextInput `json:"input"` // Model string `json:"model"`
Parameters SyncTextGenerationParams `json:"parameters"` // Input SyncTextInput `json:"input"`
} // Parameters SyncTextGenerationParams `json:"parameters"`
//}
type SyncTextInput struct { //
Messages []Message `json:"messages"` //type SyncTextInput struct {
} // Messages []Message `json:"messages"`
//}
type SyncTextGenerationParams struct { //
ResultFormat string `json:"result_format,omitempty"` //type SyncTextGenerationParams struct {
MaxTokens int `json:"max_tokens,omitempty"` // ResultFormat string `json:"result_format,omitempty"`
Temperature float64 `json:"temperature,omitempty"` // MaxTokens int `json:"max_tokens,omitempty"`
TopP float64 `json:"top_p,omitempty"` // Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"` // TopP float64 `json:"top_p,omitempty"`
Seed int64 `json:"seed,omitempty"` // TopK int `json:"top_k,omitempty"`
} // Seed int64 `json:"seed,omitempty"`
//}
// 同步文本生成响应 //
type SyncTextGenerationResponse struct { //// 同步文本生成响应
Output struct { //type SyncTextGenerationResponse struct {
Choices []struct { // Output struct {
Message Message `json:"message"` // Choices []struct {
} `json:"choices"` // Message Message `json:"message"`
Text string `json:"text"` // } `json:"choices"`
FinishReason string `json:"finish_reason"` // Text string `json:"text"`
} `json:"output"` // FinishReason string `json:"finish_reason"`
Usage struct { // } `json:"output"`
InputTokens int `json:"input_tokens"` // Usage struct {
OutputTokens int `json:"output_tokens"` // InputTokens int `json:"input_tokens"`
TotalTokens int `json:"total_tokens"` // OutputTokens int `json:"output_tokens"`
} `json:"usage"` // TotalTokens int `json:"total_tokens"`
RequestID string `json:"request_id"` // } `json:"usage"`
} // RequestID string `json:"request_id"`
//}
// 同步文本生成URL //
const DefaultSyncTextGenerationURL = "/api/v1/services/aigc/text-generation/generation" //// 同步文本生成URL
//const DefaultSyncTextGenerationURL = "/api/v1/services/aigc/text-generation/generation"
// 同步生成文本 //
func (g *AiGenerator) GenerateTextSync(prompt string) (*SyncTextGenerationResponse, error) { //// 同步生成文本
if g.cfg.APIKey == "" { //func (g *AiGenerator) GenerateTextSync(prompt string) (*SyncTextGenerationResponse, error) {
return nil, fmt.Errorf("API密钥未配置") // if g.cfg.APIKey == "" {
} // return nil, fmt.Errorf("API密钥未配置")
// }
// 构建请求 //
req := SyncTextGenerationRequest{ // // 构建请求
Model: g.cfg.TextModel, // req := SyncTextGenerationRequest{
Input: SyncTextInput{ // Model: g.cfg.TextModel,
Messages: []Message{ // Input: SyncTextInput{
{ // Messages: []Message{
Role: "user", // {
Content: prompt, // Role: "user",
}, // Content: prompt,
}, // },
}, // },
Parameters: SyncTextGenerationParams{ // },
ResultFormat: "message", // Parameters: SyncTextGenerationParams{
MaxTokens: g.cfg.MaxTokens, // ResultFormat: "message",
Temperature: g.cfg.Temperature, // MaxTokens: g.cfg.MaxTokens,
TopP: g.cfg.TopP, // Temperature: g.cfg.Temperature,
}, // TopP: g.cfg.TopP,
} // },
// }
url := g.cfg.BaseURL + DefaultSyncTextGenerationURL //
jsonData, err := json.Marshal(req) // url := g.cfg.BaseURL + DefaultSyncTextGenerationURL
if err != nil { // jsonData, err := json.Marshal(req)
return nil, fmt.Errorf("JSON序列化失败: %v", err) // if err != nil {
} // return nil, fmt.Errorf("JSON序列化失败: %v", err)
// }
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) //
if err != nil { // httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
return nil, fmt.Errorf("创建请求失败: %v", err) // if err != nil {
} // return nil, fmt.Errorf("创建请求失败: %v", err)
// }
httpReq.Header.Set("Content-Type", "application/json") //
httpReq.Header.Set("Authorization", "Bearer "+g.cfg.APIKey) // httpReq.Header.Set("Content-Type", "application/json")
// 注意:这里不设置 X-DashScope-Async 头,使用同步模式 // httpReq.Header.Set("Authorization", "Bearer "+g.cfg.APIKey)
// // 注意:这里不设置 X-DashScope-Async 头,使用同步模式
resp, err := g.client.Do(httpReq) //
if err != nil { // resp, err := g.client.Do(httpReq)
return nil, fmt.Errorf("API请求失败: %v", err) // if err != nil {
} // return nil, fmt.Errorf("API请求失败: %v", err)
defer resp.Body.Close() // }
// defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { //
body, _ := io.ReadAll(resp.Body) // if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API错误: %d, 响应: %s", resp.StatusCode, string(body)) // body, _ := io.ReadAll(resp.Body)
} // return nil, fmt.Errorf("API错误: %d, 响应: %s", resp.StatusCode, string(body))
// }
// 解析响应 //
var result SyncTextGenerationResponse // // 解析响应
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { // var result SyncTextGenerationResponse
return nil, fmt.Errorf("响应解析失败: %v", err) // if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
} // return nil, fmt.Errorf("响应解析失败: %v", err)
// }
return &result, nil //
} // return &result, nil
//}

View File

@ -120,15 +120,15 @@ func ImageContentImport(c *gin.Context) {
time.Sleep(interval) time.Sleep(interval)
} }
if err := processor.submitSingleTask(&v, i); err != nil { if err := processor.submitSingleTask(&v); err != nil {
task := &ImageTask{ task := &ImageTask{
Data: &v, Data: &v,
TaskID: strconv.Itoa(i), TaskID: i,
Error: err, Error: err,
StartTime: time.Now(), StartTime: time.Now(),
} }
processor.tasks[strconv.Itoa(i)] = task processor.tasks[v.LineNum] = task
processor.inProgress[strconv.Itoa(i)] = true processor.inProgress[v.LineNum] = true
} }
} }
processor.StartPolling() processor.StartPolling()

View File

@ -14,7 +14,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@ -48,15 +47,16 @@ var StatusMap = map[int]string{StatusIdle: "空闲中", StatusProcessing: "处
type BatchProcessor struct { type BatchProcessor struct {
mu sync.RWMutex mu sync.RWMutex
tasks map[string]*ImageTask //任务 tasks map[int]*ImageTask //任务 编号id
inProgress map[string]bool //是否成功 idFindTaskId map[int]string
inProgress map[int]bool //是否成功
pollInterval time.Duration pollInterval time.Duration
status int // 当前状态 status int // 当前状态
} }
type ImageTask struct { type ImageTask struct {
Data *excelData Data *excelData
TaskID string TaskID int
Status TaskStatus Status TaskStatus
StartTime time.Time StartTime time.Time
EndTime time.Time EndTime time.Time
@ -76,8 +76,9 @@ func GetBatchProcessor() *BatchProcessor {
if batchProcessor == nil || batchProcessor.status == StatusCompleted { if batchProcessor == nil || batchProcessor.status == StatusCompleted {
batchProcessor = &BatchProcessor{ batchProcessor = &BatchProcessor{
tasks: make(map[string]*ImageTask), tasks: make(map[int]*ImageTask),
inProgress: make(map[string]bool), inProgress: make(map[int]bool),
idFindTaskId: make(map[int]string),
pollInterval: 100 * time.Millisecond, pollInterval: 100 * time.Millisecond,
status: StatusIdle, status: StatusIdle,
} }
@ -100,7 +101,7 @@ func (p *BatchProcessor) GetStatus() int {
} }
// 提交任务 // 提交任务
func (p *BatchProcessor) submitSingleTask(req *excelData, num int) error { func (p *BatchProcessor) submitSingleTask(req *excelData) error {
//var infoResp *accountFiee.UserInfoResponse //var infoResp *accountFiee.UserInfoResponse
list, err := service.AccountFieeProvider.UserList(context.Background(), &accountFiee.UserListRequest{ list, err := service.AccountFieeProvider.UserList(context.Background(), &accountFiee.UserListRequest{
Name: req.ArtistName, Name: req.ArtistName,
@ -178,28 +179,61 @@ func (p *BatchProcessor) submitSingleTask(req *excelData, num int) error {
//if err != nil { //if err != nil {
// return fmt.Errorf("生成标题和内容失败: %v", err) // return fmt.Errorf("生成标题和内容失败: %v", err)
//} //}
var taskId string
if req.PhotoUrl != "" { if req.PhotoUrl == "" {
taskId = strconv.Itoa(num) //生成标题
} else { title, content, err := p.generateTitleAndContent(req)
id, err := p.generateImage(req)
if err != nil { if err != nil {
zap.L().Error("生成图片失败", zap.Error(err)) //task := &ImageTask{
// Data: req,
// Error: err,
// Status: TaskFailed,
// StartTime: time.Now(),
//}
//p.tasks[req.LineNum] = task
//p.inProgress[req.LineNum] = true
//zap.L().Error("生成标题和内容失败: %v", zap.Error(err))
return fmt.Errorf("生成标题失败")
}
req.Title = title
req.Content = content
//请求图片
generateImageRes, err := p.generateImage(req)
if err != nil {
//task := &ImageTask{
// Data: req,
// Error: err,
// Status: TaskFailed,
// StartTime: time.Now(),
//}
//p.tasks[req.LineNum] = task
//p.inProgress[req.LineNum] = true
//zap.L().Error("生成图片失败", zap.Error(err))
return fmt.Errorf("生成图片失败") return fmt.Errorf("生成图片失败")
} }
taskId = id task := &ImageTask{
} Title: title,
Content: content,
Data: req,
TaskID: req.LineNum,
Status: TaskPending,
StartTime: time.Now(),
}
p.idFindTaskId[req.LineNum] = generateImageRes
p.tasks[req.LineNum] = task
p.inProgress[req.LineNum] = false
return nil
}
task := &ImageTask{ task := &ImageTask{
//Title: title,
//Content: content,
Data: req, Data: req,
TaskID: taskId, TaskID: req.LineNum,
Status: TaskPending, Status: TaskPending,
StartTime: time.Now(), StartTime: time.Now(),
} }
p.tasks[task.TaskID] = task p.tasks[req.LineNum] = task
p.inProgress[task.TaskID] = false p.inProgress[req.LineNum] = false
return nil return nil
} }
@ -237,49 +271,49 @@ func (p *BatchProcessor) generateTitleAndContent(req *excelData) (string, string
} }
} }
func (p *BatchProcessor) generateTitle(req *excelData) (string, error) { //func (p *BatchProcessor) generateTitle(req *excelData) (string, error) {
prompt := fmt.Sprintf("请根据以下要求生成一个标题:%s", req.TitleRequire) // prompt := fmt.Sprintf("请根据以下要求生成一个标题:%s", req.TitleRequire)
if req.Desc != "" { // if req.Desc != "" {
prompt += fmt.Sprintf("\n艺人简介%s", req.Desc) // prompt += fmt.Sprintf("\n艺人简介%s", req.Desc)
} // }
prompt += "\n请直接输出标题不要包含任何其他文字。" // prompt += "\n请直接输出标题不要包含任何其他文字。"
//
result, err := NewAiGenerator().GenerateTextSync(prompt) // result, err := NewAiGenerator().GenerateTextSync(prompt)
if err != nil { // if err != nil {
return "", err // return "", err
} // }
//
if len(result.Output.Choices) == 0 { // if len(result.Output.Choices) == 0 {
return "", errors.New("AI未生成标题内容") // return "", errors.New("AI未生成标题内容")
} // }
//
req.Title = strings.TrimSpace(result.Output.Choices[0].Message.Content) // req.Title = strings.TrimSpace(result.Output.Choices[0].Message.Content)
return req.Title, nil // return req.Title, nil
} //}
//
func (p *BatchProcessor) generateContent(req *excelData) (string, error) { //func (p *BatchProcessor) generateContent(req *excelData) (string, error) {
// 使用已生成的标题作为上下文 // // 使用已生成的标题作为上下文
prompt := fmt.Sprintf("请根据以下要求生成内容:%s", req.ContentRequire) // prompt := fmt.Sprintf("请根据以下要求生成内容:%s", req.ContentRequire)
if req.Title != "" { // if req.Title != "" {
prompt += fmt.Sprintf("\n标题%s", req.Title) // 关联标题 // prompt += fmt.Sprintf("\n标题%s", req.Title) // 关联标题
} // }
if req.Desc != "" { // if req.Desc != "" {
prompt += fmt.Sprintf("\n艺人简介%s", req.Desc) // prompt += fmt.Sprintf("\n艺人简介%s", req.Desc)
} // }
prompt += "\n请基于标题生成相关内容直接输出内容不要包含任何其他文字。" // prompt += "\n请基于标题生成相关内容直接输出内容不要包含任何其他文字。"
//
result, err := NewAiGenerator().GenerateTextSync(prompt) // result, err := NewAiGenerator().GenerateTextSync(prompt)
if err != nil { // if err != nil {
return "", err // return "", err
} // }
//
if len(result.Output.Choices) == 0 { // if len(result.Output.Choices) == 0 {
return "", errors.New("AI未生成内容") // return "", errors.New("AI未生成内容")
} // }
//
req.Content = strings.TrimSpace(result.Output.Choices[0].Message.Content) // req.Content = strings.TrimSpace(result.Output.Choices[0].Message.Content)
return req.Content, nil // return req.Content, nil
} //}
//func (p *BatchProcessor) generateImage(req *excelData) (string, error) { //func (p *BatchProcessor) generateImage(req *excelData) (string, error) {
// prompt := fmt.Sprintf("请根据以下要求生成内容:%s", req.PhotoRequire) // prompt := fmt.Sprintf("请根据以下要求生成内容:%s", req.PhotoRequire)
@ -308,17 +342,17 @@ func (p *BatchProcessor) generateContent(req *excelData) (string, error) {
func (p *BatchProcessor) generateImage(req *excelData) (string, error) { func (p *BatchProcessor) generateImage(req *excelData) (string, error) {
prompt := fmt.Sprintf("请根据以下要求生成内容:%s", req.PhotoRequire) prompt := fmt.Sprintf("请根据以下要求生成内容:%s", req.PhotoRequire)
if req.Title != "" { if req.Title != "" {
prompt += fmt.Sprintf("\n标题:%s", req.Title) // 关联标题 prompt += fmt.Sprintf("1标题:%s", req.Title) // 关联标题
} }
if req.Content != "" { if req.Content != "" {
prompt += fmt.Sprintf("\n内容%s", req.Content) // 关联内容 prompt += fmt.Sprintf("2内容%s", req.Content) // 关联内容
}
if req.Desc != "" {
prompt += fmt.Sprintf("\n艺人简介%s", req.Desc)
} }
//if req.Desc != "" {
// prompt += fmt.Sprintf("3艺人简介%s", req.Desc)
//}
prompt += "\n请基于标题和内容生成相关内容" prompt += "\n请基于标题和内容生成相关内容"
result, err := NewAiGenerator().GenerateImageFromText( result, err := NewAiGenerator().TextToImage(
prompt, prompt,
"1024*1024", "1024*1024",
req.PhotoNum, req.PhotoNum,
@ -326,7 +360,7 @@ func (p *BatchProcessor) generateImage(req *excelData) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return result, nil return result.Output.TaskID, nil
} }
func (p *BatchProcessor) GetTaskStatistics() (completed, pending, total int, completedTasks, failedTasks []*ImageTask) { func (p *BatchProcessor) GetTaskStatistics() (completed, pending, total int, completedTasks, failedTasks []*ImageTask) {
@ -455,11 +489,11 @@ func (p *BatchProcessor) StartPolling() {
//} //}
// 获取未完成的任务列表 // 获取未完成的任务列表
func (p *BatchProcessor) getIncompleteTasks() []string { func (p *BatchProcessor) getIncompleteTasks() []int {
p.mu.RLock() p.mu.RLock()
defer p.mu.RUnlock() defer p.mu.RUnlock()
var incomplete []string var incomplete []int
for taskID, completed := range p.inProgress { for taskID, completed := range p.inProgress {
if !completed { if !completed {
incomplete = append(incomplete, taskID) incomplete = append(incomplete, taskID)
@ -484,67 +518,67 @@ func (p *BatchProcessor) IsAllCompleted() bool {
return true return true
} }
func (p *BatchProcessor) UpdateTaskStatuses(taskId string) (err error) { func (p *BatchProcessor) UpdateTaskStatuses(id int) (err error) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
if p.tasks[taskId].Data.PhotoUrl != "" { //如果有图片 if p.tasks[id].Data.PhotoUrl != "" { //如果有图片
title, content, err := p.generateTitleAndContent(p.tasks[taskId].Data) title, content, err := p.generateTitleAndContent(p.tasks[id].Data)
if err != nil { if err != nil {
p.tasks[taskId].Status = TaskFailed p.tasks[id].Status = TaskFailed
p.inProgress[taskId] = true p.inProgress[id] = true
p.tasks[taskId].EndTime = time.Now() p.tasks[id].EndTime = time.Now()
zap.L().Error("生成标题和内容失败: %v", zap.Error(err)) zap.L().Error("生成标题和内容失败: %v", zap.Error(err))
p.tasks[taskId].Error = fmt.Errorf("生成标题和内容失败") p.tasks[id].Error = fmt.Errorf("生成标题和内容失败")
return err return err
} }
p.tasks[taskId].Title = title p.tasks[id].Title = title
p.tasks[taskId].Content = content p.tasks[id].Content = content
if err = publishImage(publishImageReq{ if err = publishImage(publishImageReq{
ArtistName: p.tasks[taskId].Data.ArtistName, ArtistName: p.tasks[id].Data.ArtistName,
SubNum: p.tasks[taskId].Data.SubNum, SubNum: p.tasks[id].Data.SubNum,
Title: p.tasks[taskId].Title, Title: p.tasks[id].Title,
Content: p.tasks[taskId].Content, Content: p.tasks[id].Content,
TikTok: p.tasks[taskId].Data.TikTok, TikTok: p.tasks[id].Data.TikTok,
Instagram: p.tasks[taskId].Data.Instagram, Instagram: p.tasks[id].Data.Instagram,
GeneratePhotoUrl: []string{p.tasks[taskId].Data.PhotoUrl}, GeneratePhotoUrl: []string{p.tasks[id].Data.PhotoUrl},
MediaAccountUuids: p.tasks[taskId].Data.MediaAccountUuids, MediaAccountUuids: p.tasks[id].Data.MediaAccountUuids,
MediaAccountNames: p.tasks[taskId].Data.MediaAccountNames, MediaAccountNames: p.tasks[id].Data.MediaAccountNames,
}); err != nil { }); err != nil {
p.tasks[taskId].Status = TaskFailed p.tasks[id].Status = TaskFailed
p.inProgress[taskId] = true p.inProgress[id] = true
p.tasks[taskId].EndTime = time.Now() p.tasks[id].EndTime = time.Now()
zap.L().Error("发布内容失败: %v", zap.Error(err)) zap.L().Error("发布内容失败: %v", zap.Error(err))
p.tasks[taskId].Error = fmt.Errorf("发布内容失败") p.tasks[id].Error = fmt.Errorf("发布内容失败")
} }
p.tasks[taskId].Status = TaskSuccessful p.tasks[id].Status = TaskSuccessful
p.inProgress[taskId] = true p.inProgress[id] = true
p.tasks[taskId].EndTime = time.Now() p.tasks[id].EndTime = time.Now()
return err return err
} }
getTaskDetailRes, err := NewAiGenerator().GetTaskDetail(taskId) getTaskDetailRes, err := NewAiGenerator().GetTaskDetail(p.idFindTaskId[id])
if err != nil { if err != nil {
zap.L().Error("查看图片生成结果失败: %v", zap.Error(err)) zap.L().Error("查看图片生成结果失败: %v", zap.Error(err))
return fmt.Errorf("查看图片生成结果失败") return fmt.Errorf("查看图片生成结果失败")
} }
// 更新本地任务状态 // 更新本地任务状态
if localTask, exists := p.tasks[getTaskDetailRes.Output.TaskID]; exists { if localTask, exists := p.tasks[id]; exists {
switch getTaskDetailRes.Output.TaskStatus { switch getTaskDetailRes.Output.TaskStatus {
case "SUCCEEDED": case "SUCCEEDED":
if localTask.Status != TaskSuccessful { if localTask.Status != TaskSuccessful {
//生成标题 ////生成标题
title, content, err := p.generateTitleAndContent(p.tasks[taskId].Data) //title, content, err := p.generateTitleAndContent(p.tasks[taskId].Data)
if err != nil { //if err != nil {
localTask.Status = TaskFailed // localTask.Status = TaskFailed
p.tasks[getTaskDetailRes.Output.TaskID].Error = err // p.tasks[getTaskDetailRes.Output.TaskID].Error = err
p.inProgress[getTaskDetailRes.Output.TaskID] = true // p.inProgress[getTaskDetailRes.Output.TaskID] = true
p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now() // p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now()
zap.L().Error("生成标题和内容失败: %v", zap.Error(err)) // zap.L().Error("生成标题和内容失败: %v", zap.Error(err))
return fmt.Errorf("生成标题和内容失败") // return fmt.Errorf("生成标题和内容失败")
} //}
p.tasks[taskId].Title = title //p.tasks[taskId].Title = title
p.tasks[taskId].Content = content //p.tasks[taskId].Content = content
//上传图片 //上传图片
localTask.EndTime = time.Now() localTask.EndTime = time.Now()
@ -555,9 +589,9 @@ func (p *BatchProcessor) UpdateTaskStatuses(taskId string) (err error) {
uploadedURLs, err := downloadAndUploadImages(urls) uploadedURLs, err := downloadAndUploadImages(urls)
if err != nil { if err != nil {
localTask.Status = TaskFailed localTask.Status = TaskFailed
p.tasks[getTaskDetailRes.Output.TaskID].Error = err p.tasks[id].Error = err
p.inProgress[getTaskDetailRes.Output.TaskID] = true p.inProgress[id] = true
p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now() p.tasks[id].EndTime = time.Now()
zap.L().Error("图片上传失败: %v", zap.Error(err)) zap.L().Error("图片上传失败: %v", zap.Error(err))
localTask.Error = fmt.Errorf("图片上传失败") localTask.Error = fmt.Errorf("图片上传失败")
return err return err
@ -569,34 +603,34 @@ func (p *BatchProcessor) UpdateTaskStatuses(taskId string) (err error) {
//发布 //发布
if err = publishImage(publishImageReq{ if err = publishImage(publishImageReq{
ArtistName: p.tasks[getTaskDetailRes.Output.TaskID].Data.ArtistName, ArtistName: p.tasks[id].Data.ArtistName,
SubNum: p.tasks[getTaskDetailRes.Output.TaskID].Data.SubNum, SubNum: p.tasks[id].Data.SubNum,
Title: p.tasks[getTaskDetailRes.Output.TaskID].Title, Title: p.tasks[id].Title,
Content: p.tasks[getTaskDetailRes.Output.TaskID].Content, Content: p.tasks[id].Content,
TikTok: p.tasks[taskId].Data.TikTok, TikTok: p.tasks[id].Data.TikTok,
Instagram: p.tasks[taskId].Data.Instagram, Instagram: p.tasks[id].Data.Instagram,
MediaAccountUuids: p.tasks[taskId].Data.MediaAccountUuids, MediaAccountUuids: p.tasks[id].Data.MediaAccountUuids,
MediaAccountNames: p.tasks[taskId].Data.MediaAccountNames, MediaAccountNames: p.tasks[id].Data.MediaAccountNames,
GeneratePhotoUrl: uploadedURLs, GeneratePhotoUrl: uploadedURLs,
}); err != nil { }); err != nil {
localTask.Status = TaskFailed localTask.Status = TaskFailed
p.tasks[getTaskDetailRes.Output.TaskID].Error = err p.tasks[id].Error = err
p.inProgress[getTaskDetailRes.Output.TaskID] = true p.inProgress[id] = true
p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now() p.tasks[id].EndTime = time.Now()
} }
//成功 //成功
localTask.Status = TaskSuccessful localTask.Status = TaskSuccessful
p.inProgress[getTaskDetailRes.Output.TaskID] = true p.inProgress[id] = true
p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now() p.tasks[id].EndTime = time.Now()
} }
case "FAILED": case "FAILED":
if localTask.Status != TaskFailed { if localTask.Status != TaskFailed {
localTask.Status = TaskFailed localTask.Status = TaskFailed
p.tasks[getTaskDetailRes.Output.TaskID].Error = errors.New("生成失败") p.tasks[id].Error = errors.New("生成失败")
p.inProgress[getTaskDetailRes.Output.TaskID] = true p.inProgress[id] = true
p.tasks[getTaskDetailRes.Output.TaskID].EndTime = time.Now() p.tasks[id].EndTime = time.Now()
} }
} }
} }