diff --git a/pkg/utils/files.go b/pkg/utils/files.go index 2f9c436..1f780c4 100644 --- a/pkg/utils/files.go +++ b/pkg/utils/files.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "strconv" "strings" "go.uber.org/zap" @@ -50,3 +51,50 @@ func SaveUrlFileDisk(url string, path string, filename string) (fullPath string, } return } + +// GetRemoteFileSize 通过HTTP HEAD请求获取远程文件大小(不下载文件) +func GetRemoteFileSize(url string) (size int64, err error) { + // 创建HEAD请求 + req, err := http.NewRequest("HEAD", url, nil) + if err != nil { + zap.L().Error("GetRemoteFileSize create request err", zap.String("url", url), zap.Error(err)) + err = errors.New(e.GetMsg(e.ERROR_DOWNLOAD_FILE)) + return + } + + // 发送请求 + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + zap.L().Error("GetRemoteFileSize request err", zap.String("url", url), zap.Error(err)) + err = errors.New(e.GetMsg(e.ERROR_DOWNLOAD_FILE)) + return + } + defer resp.Body.Close() + + // 检查HTTP状态码 + if resp.StatusCode != http.StatusOK { + zap.L().Error("GetRemoteFileSize status code err", zap.String("url", url), zap.Int("status", resp.StatusCode)) + err = errors.New(e.GetMsg(e.ERROR_DOWNLOAD_FILE)) + return + } + + // 获取Content-Length头部 + contentLength := resp.Header.Get("Content-Length") + if contentLength == "" { + zap.L().Error("GetRemoteFileSize Content-Length header not found", zap.String("url", url)) + err = errors.New("无法获取文件大小") + return + } + + // 解析文件大小 + size, err = strconv.ParseInt(contentLength, 10, 64) + size = size / 1024 / 1024 + if err != nil { + zap.L().Error("GetRemoteFileSize parse size err", zap.String("url", url), zap.String("contentLength", contentLength), zap.Error(err)) + err = errors.New("解析文件大小失败") + return + } + + return +}