package service import ( "bytes" "encoding/base64" "fmt" "image" _ "image/gif" _ "image/jpeg" _ "image/png" "io" "net/http" "strings" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" "golang.org/x/image/webp" ) // FileService 统一的文件处理服务 // 提供文件下载、解码、缓存等功能的统一入口 // getContextCacheKey 生成 context 缓存的 key func getContextCacheKey(url string) string { return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url)) } // LoadFileSource 加载文件源数据 // 这是统一的入口,会自动处理缓存和不同的来源类型 func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) { if source == nil { return nil, fmt.Errorf("file source is nil") } // 如果已有缓存,直接返回 if source.HasCache() { return source.GetCache(), nil } var cachedData *types.CachedFileData var err error if source.IsURL() { cachedData, err = loadFromURL(c, source.URL, reason...) } else { cachedData, err = loadFromBase64(source.Base64Data, source.MimeType) } if err != nil { return nil, err } // 设置缓存 source.SetCache(cachedData) // 注册到 context 以便请求结束时自动清理 if c != nil { registerSourceForCleanup(c, source) } return cachedData, nil } // registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理 func registerSourceForCleanup(c *gin.Context, source *types.FileSource) { key := string(constant.ContextKeyFileSourcesToCleanup) var sources []*types.FileSource if existing, exists := c.Get(key); exists { sources = existing.([]*types.FileSource) } sources = append(sources, source) c.Set(key, sources) } // CleanupFileSources 清理请求中所有注册的 FileSource // 应在请求结束时调用(通常由中间件自动调用) func CleanupFileSources(c *gin.Context) { key := string(constant.ContextKeyFileSourcesToCleanup) if sources, exists := c.Get(key); exists { for _, source := range sources.([]*types.FileSource) { if cache := source.GetCache(); cache != nil { if cache.IsDisk() { common.DecrementDiskFiles(cache.Size) } cache.Close() } } c.Set(key, nil) // 清除引用 } } // loadFromURL 从 URL 加载文件 // 支持磁盘缓存:当文件大小超过阈值且磁盘缓存可用时,将数据存储到磁盘 func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) { contextKey := getContextCacheKey(url) // 检查 context 缓存 if cachedData, exists := c.Get(contextKey); exists { if common.DebugEnabled { logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url)) } return cachedData.(*types.CachedFileData), nil } // 下载文件 var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024 resp, err := DoDownloadRequest(url, reason...) if err != nil { return nil, fmt.Errorf("failed to download file from %s: %w", url, err) } defer resp.Body.Close() if resp.StatusCode != 200 { return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode) } // 读取文件内容(限制大小) fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1))) if err != nil { return nil, fmt.Errorf("failed to read file content: %w", err) } if len(fileBytes) > maxFileSize { return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB) } // 转换为 base64 base64Data := base64.StdEncoding.EncodeToString(fileBytes) // 智能获取 MIME 类型 mimeType := smartDetectMimeType(resp, url, fileBytes) // 判断是否使用磁盘缓存 base64Size := int64(len(base64Data)) var cachedData *types.CachedFileData if shouldUseDiskCache(base64Size) { // 使用磁盘缓存 diskPath, err := writeToDiskCache(base64Data) if err != nil { // 磁盘缓存失败,回退到内存 logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err)) cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes))) } else { cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(fileBytes))) common.IncrementDiskFiles(base64Size) if common.DebugEnabled { logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size)) } } } else { // 使用内存缓存 cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes))) } // 如果是图片,尝试获取图片配置 if strings.HasPrefix(mimeType, "image/") { config, format, err := decodeImageConfig(fileBytes) if err == nil { cachedData.ImageConfig = &config cachedData.ImageFormat = format // 如果通过图片解码获取了更准确的格式,更新 MIME 类型 if mimeType == "application/octet-stream" || mimeType == "" { cachedData.MimeType = "image/" + format } } } // 存入 context 缓存 c.Set(contextKey, cachedData) return cachedData, nil } // shouldUseDiskCache 判断是否应该使用磁盘缓存 func shouldUseDiskCache(dataSize int64) bool { return common.ShouldUseDiskCache(dataSize) } // writeToDiskCache 将数据写入磁盘缓存 func writeToDiskCache(base64Data string) (string, error) { return common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data) } // smartDetectMimeType 智能检测 MIME 类型 // 优先级:Content-Type header > Content-Disposition filename > URL 路径 > 内容嗅探 > 图片解码 func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string { // 1. 尝试从 Content-Type header 获取 mimeType := resp.Header.Get("Content-Type") if idx := strings.Index(mimeType, ";"); idx != -1 { mimeType = strings.TrimSpace(mimeType[:idx]) } if mimeType != "" && mimeType != "application/octet-stream" { return mimeType } // 2. 尝试从 Content-Disposition header 的 filename 获取 if cd := resp.Header.Get("Content-Disposition"); cd != "" { parts := strings.Split(cd, ";") for _, part := range parts { part = strings.TrimSpace(part) if strings.HasPrefix(strings.ToLower(part), "filename=") { name := strings.TrimSpace(strings.TrimPrefix(part, "filename=")) // 移除引号 if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' { name = name[1 : len(name)-1] } if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) { ext := strings.ToLower(name[dot+1:]) if ext != "" { mt := GetMimeTypeByExtension(ext) if mt != "application/octet-stream" { return mt } } } break } } } // 3. 尝试从 URL 路径获取扩展名 mt := guessMimeTypeFromURL(url) if mt != "application/octet-stream" { return mt } // 4. 使用 http.DetectContentType 内容嗅探 if len(fileBytes) > 0 { sniffed := http.DetectContentType(fileBytes) if sniffed != "" && sniffed != "application/octet-stream" { // 去除可能的 charset 参数 if idx := strings.Index(sniffed, ";"); idx != -1 { sniffed = strings.TrimSpace(sniffed[:idx]) } return sniffed } } // 5. 尝试作为图片解码获取格式 if len(fileBytes) > 0 { if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" { return "image/" + strings.ToLower(format) } } // 最终回退 return "application/octet-stream" } // loadFromBase64 从 base64 字符串加载文件 func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) { var mimeType string var cleanBase64 string // 处理 data: 前缀 if strings.HasPrefix(base64String, "data:") { // 格式: data:mime/type;base64,xxxxx idx := strings.Index(base64String, ",") if idx != -1 { header := base64String[:idx] cleanBase64 = base64String[idx+1:] // 从 header 提取 MIME 类型 if strings.Contains(header, ":") && strings.Contains(header, ";") { mimeStart := strings.Index(header, ":") + 1 mimeEnd := strings.Index(header, ";") if mimeStart < mimeEnd { mimeType = header[mimeStart:mimeEnd] } } } else { cleanBase64 = base64String } } else { cleanBase64 = base64String } // 使用提供的 MIME 类型(如果有) if providedMimeType != "" { mimeType = providedMimeType } // 解码 base64 decodedData, err := base64.StdEncoding.DecodeString(cleanBase64) if err != nil { return nil, fmt.Errorf("failed to decode base64 data: %w", err) } // 判断是否使用磁盘缓存(对于 base64 内联数据也支持磁盘缓存) base64Size := int64(len(cleanBase64)) var cachedData *types.CachedFileData if shouldUseDiskCache(base64Size) { // 使用磁盘缓存 diskPath, err := writeToDiskCache(cleanBase64) if err != nil { // 磁盘缓存失败,回退到内存 cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData))) } else { cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(decodedData))) common.IncrementDiskFiles(base64Size) } } else { cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData))) } // 如果是图片或 MIME 类型未知,尝试解码图片获取更多信息 if mimeType == "" || strings.HasPrefix(mimeType, "image/") { config, format, err := decodeImageConfig(decodedData) if err == nil { cachedData.ImageConfig = &config cachedData.ImageFormat = format if mimeType == "" { cachedData.MimeType = "image/" + format } } } return cachedData, nil } // GetImageConfig 获取图片配置(宽高等信息) // 会自动处理缓存,避免重复下载/解码 func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) { cachedData, err := LoadFileSource(c, source, "get_image_config") if err != nil { return image.Config{}, "", err } if cachedData.ImageConfig != nil { return *cachedData.ImageConfig, cachedData.ImageFormat, nil } // 如果缓存中没有图片配置,尝试解码 base64Str, err := cachedData.GetBase64Data() if err != nil { return image.Config{}, "", fmt.Errorf("failed to get base64 data: %w", err) } decodedData, err := base64.StdEncoding.DecodeString(base64Str) if err != nil { return image.Config{}, "", fmt.Errorf("failed to decode base64 for image config: %w", err) } config, format, err := decodeImageConfig(decodedData) if err != nil { return image.Config{}, "", err } // 更新缓存 cachedData.ImageConfig = &config cachedData.ImageFormat = format return config, format, nil } // GetBase64Data 获取 base64 编码的数据 // 会自动处理缓存,避免重复下载 // 支持内存缓存和磁盘缓存 func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) { cachedData, err := LoadFileSource(c, source, reason...) if err != nil { return "", "", err } base64Str, err := cachedData.GetBase64Data() if err != nil { return "", "", fmt.Errorf("failed to get base64 data: %w", err) } return base64Str, cachedData.MimeType, nil } // GetMimeType 获取文件的 MIME 类型 func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) { // 如果已经有缓存,直接返回 if source.HasCache() { return source.GetCache().MimeType, nil } // 如果是 URL,尝试只获取 header 而不下载完整文件 if source.IsURL() { mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type") if err == nil && mimeType != "" && mimeType != "application/octet-stream" { return mimeType, nil } } // 否则加载完整数据 cachedData, err := LoadFileSource(c, source, "get_mime_type") if err != nil { return "", err } return cachedData.MimeType, nil } // DetectFileType 检测文件类型(image/audio/video/file) func DetectFileType(mimeType string) types.FileType { if strings.HasPrefix(mimeType, "image/") { return types.FileTypeImage } if strings.HasPrefix(mimeType, "audio/") { return types.FileTypeAudio } if strings.HasPrefix(mimeType, "video/") { return types.FileTypeVideo } return types.FileTypeFile } // decodeImageConfig 从字节数据解码图片配置 func decodeImageConfig(data []byte) (image.Config, string, error) { reader := bytes.NewReader(data) // 尝试标准格式 config, format, err := image.DecodeConfig(reader) if err == nil { return config, format, nil } // 尝试 webp reader.Seek(0, io.SeekStart) config, err = webp.DecodeConfig(reader) if err == nil { return config, "webp", nil } return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format") } // guessMimeTypeFromURL 从 URL 猜测 MIME 类型 func guessMimeTypeFromURL(url string) string { // 移除查询参数 cleanedURL := url if q := strings.Index(cleanedURL, "?"); q != -1 { cleanedURL = cleanedURL[:q] } // 获取最后一段 if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) { last := cleanedURL[slash+1:] if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) { ext := strings.ToLower(last[dot+1:]) return GetMimeTypeByExtension(ext) } } return "application/octet-stream" }