file_service.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. package service
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "fmt"
  6. "image"
  7. _ "image/gif"
  8. _ "image/jpeg"
  9. _ "image/png"
  10. "io"
  11. "net/http"
  12. "strings"
  13. "github.com/QuantumNous/new-api/common"
  14. "github.com/QuantumNous/new-api/constant"
  15. "github.com/QuantumNous/new-api/logger"
  16. "github.com/QuantumNous/new-api/types"
  17. "github.com/gin-gonic/gin"
  18. "golang.org/x/image/webp"
  19. )
  20. // FileService 统一的文件处理服务
  21. // 提供文件下载、解码、缓存等功能的统一入口
  22. // getContextCacheKey 生成 context 缓存的 key
  23. func getContextCacheKey(url string) string {
  24. return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url))
  25. }
  26. // LoadFileSource 加载文件源数据
  27. // 这是统一的入口,会自动处理缓存和不同的来源类型
  28. func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) {
  29. if source == nil {
  30. return nil, fmt.Errorf("file source is nil")
  31. }
  32. // 如果已有缓存,直接返回
  33. if source.HasCache() {
  34. return source.GetCache(), nil
  35. }
  36. var cachedData *types.CachedFileData
  37. var err error
  38. if source.IsURL() {
  39. cachedData, err = loadFromURL(c, source.URL, reason...)
  40. } else {
  41. cachedData, err = loadFromBase64(source.Base64Data, source.MimeType)
  42. }
  43. if err != nil {
  44. return nil, err
  45. }
  46. // 设置缓存
  47. source.SetCache(cachedData)
  48. // 注册到 context 以便请求结束时自动清理
  49. if c != nil {
  50. registerSourceForCleanup(c, source)
  51. }
  52. return cachedData, nil
  53. }
  54. // registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理
  55. func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
  56. key := string(constant.ContextKeyFileSourcesToCleanup)
  57. var sources []*types.FileSource
  58. if existing, exists := c.Get(key); exists {
  59. sources = existing.([]*types.FileSource)
  60. }
  61. sources = append(sources, source)
  62. c.Set(key, sources)
  63. }
  64. // CleanupFileSources 清理请求中所有注册的 FileSource
  65. // 应在请求结束时调用(通常由中间件自动调用)
  66. func CleanupFileSources(c *gin.Context) {
  67. key := string(constant.ContextKeyFileSourcesToCleanup)
  68. if sources, exists := c.Get(key); exists {
  69. for _, source := range sources.([]*types.FileSource) {
  70. if cache := source.GetCache(); cache != nil {
  71. if cache.IsDisk() {
  72. common.DecrementDiskFiles(cache.Size)
  73. }
  74. cache.Close()
  75. }
  76. }
  77. c.Set(key, nil) // 清除引用
  78. }
  79. }
  80. // loadFromURL 从 URL 加载文件
  81. // 支持磁盘缓存:当文件大小超过阈值且磁盘缓存可用时,将数据存储到磁盘
  82. func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) {
  83. contextKey := getContextCacheKey(url)
  84. // 检查 context 缓存
  85. if cachedData, exists := c.Get(contextKey); exists {
  86. if common.DebugEnabled {
  87. logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
  88. }
  89. return cachedData.(*types.CachedFileData), nil
  90. }
  91. // 下载文件
  92. var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
  93. resp, err := DoDownloadRequest(url, reason...)
  94. if err != nil {
  95. return nil, fmt.Errorf("failed to download file from %s: %w", url, err)
  96. }
  97. defer resp.Body.Close()
  98. if resp.StatusCode != 200 {
  99. return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
  100. }
  101. // 读取文件内容(限制大小)
  102. fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
  103. if err != nil {
  104. return nil, fmt.Errorf("failed to read file content: %w", err)
  105. }
  106. if len(fileBytes) > maxFileSize {
  107. return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
  108. }
  109. // 转换为 base64
  110. base64Data := base64.StdEncoding.EncodeToString(fileBytes)
  111. // 智能获取 MIME 类型
  112. mimeType := smartDetectMimeType(resp, url, fileBytes)
  113. // 判断是否使用磁盘缓存
  114. base64Size := int64(len(base64Data))
  115. var cachedData *types.CachedFileData
  116. if shouldUseDiskCache(base64Size) {
  117. // 使用磁盘缓存
  118. diskPath, err := writeToDiskCache(base64Data)
  119. if err != nil {
  120. // 磁盘缓存失败,回退到内存
  121. logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err))
  122. cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
  123. } else {
  124. cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(fileBytes)))
  125. common.IncrementDiskFiles(base64Size)
  126. if common.DebugEnabled {
  127. logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size))
  128. }
  129. }
  130. } else {
  131. // 使用内存缓存
  132. cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
  133. }
  134. // 如果是图片,尝试获取图片配置
  135. if strings.HasPrefix(mimeType, "image/") {
  136. config, format, err := decodeImageConfig(fileBytes)
  137. if err == nil {
  138. cachedData.ImageConfig = &config
  139. cachedData.ImageFormat = format
  140. // 如果通过图片解码获取了更准确的格式,更新 MIME 类型
  141. if mimeType == "application/octet-stream" || mimeType == "" {
  142. cachedData.MimeType = "image/" + format
  143. }
  144. }
  145. }
  146. // 存入 context 缓存
  147. c.Set(contextKey, cachedData)
  148. return cachedData, nil
  149. }
  150. // shouldUseDiskCache 判断是否应该使用磁盘缓存
  151. func shouldUseDiskCache(dataSize int64) bool {
  152. return common.ShouldUseDiskCache(dataSize)
  153. }
  154. // writeToDiskCache 将数据写入磁盘缓存
  155. func writeToDiskCache(base64Data string) (string, error) {
  156. return common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data)
  157. }
  158. // smartDetectMimeType 智能检测 MIME 类型
  159. // 优先级:Content-Type header > Content-Disposition filename > URL 路径 > 内容嗅探 > 图片解码
  160. func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string {
  161. // 1. 尝试从 Content-Type header 获取
  162. mimeType := resp.Header.Get("Content-Type")
  163. if idx := strings.Index(mimeType, ";"); idx != -1 {
  164. mimeType = strings.TrimSpace(mimeType[:idx])
  165. }
  166. if mimeType != "" && mimeType != "application/octet-stream" {
  167. return mimeType
  168. }
  169. // 2. 尝试从 Content-Disposition header 的 filename 获取
  170. if cd := resp.Header.Get("Content-Disposition"); cd != "" {
  171. parts := strings.Split(cd, ";")
  172. for _, part := range parts {
  173. part = strings.TrimSpace(part)
  174. if strings.HasPrefix(strings.ToLower(part), "filename=") {
  175. name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
  176. // 移除引号
  177. if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
  178. name = name[1 : len(name)-1]
  179. }
  180. if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
  181. ext := strings.ToLower(name[dot+1:])
  182. if ext != "" {
  183. mt := GetMimeTypeByExtension(ext)
  184. if mt != "application/octet-stream" {
  185. return mt
  186. }
  187. }
  188. }
  189. break
  190. }
  191. }
  192. }
  193. // 3. 尝试从 URL 路径获取扩展名
  194. mt := guessMimeTypeFromURL(url)
  195. if mt != "application/octet-stream" {
  196. return mt
  197. }
  198. // 4. 使用 http.DetectContentType 内容嗅探
  199. if len(fileBytes) > 0 {
  200. sniffed := http.DetectContentType(fileBytes)
  201. if sniffed != "" && sniffed != "application/octet-stream" {
  202. // 去除可能的 charset 参数
  203. if idx := strings.Index(sniffed, ";"); idx != -1 {
  204. sniffed = strings.TrimSpace(sniffed[:idx])
  205. }
  206. return sniffed
  207. }
  208. }
  209. // 5. 尝试作为图片解码获取格式
  210. if len(fileBytes) > 0 {
  211. if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" {
  212. return "image/" + strings.ToLower(format)
  213. }
  214. }
  215. // 最终回退
  216. return "application/octet-stream"
  217. }
  218. // loadFromBase64 从 base64 字符串加载文件
  219. func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) {
  220. var mimeType string
  221. var cleanBase64 string
  222. // 处理 data: 前缀
  223. if strings.HasPrefix(base64String, "data:") {
  224. // 格式: data:mime/type;base64,xxxxx
  225. idx := strings.Index(base64String, ",")
  226. if idx != -1 {
  227. header := base64String[:idx]
  228. cleanBase64 = base64String[idx+1:]
  229. // 从 header 提取 MIME 类型
  230. if strings.Contains(header, ":") && strings.Contains(header, ";") {
  231. mimeStart := strings.Index(header, ":") + 1
  232. mimeEnd := strings.Index(header, ";")
  233. if mimeStart < mimeEnd {
  234. mimeType = header[mimeStart:mimeEnd]
  235. }
  236. }
  237. } else {
  238. cleanBase64 = base64String
  239. }
  240. } else {
  241. cleanBase64 = base64String
  242. }
  243. // 使用提供的 MIME 类型(如果有)
  244. if providedMimeType != "" {
  245. mimeType = providedMimeType
  246. }
  247. // 解码 base64
  248. decodedData, err := base64.StdEncoding.DecodeString(cleanBase64)
  249. if err != nil {
  250. return nil, fmt.Errorf("failed to decode base64 data: %w", err)
  251. }
  252. // 判断是否使用磁盘缓存(对于 base64 内联数据也支持磁盘缓存)
  253. base64Size := int64(len(cleanBase64))
  254. var cachedData *types.CachedFileData
  255. if shouldUseDiskCache(base64Size) {
  256. // 使用磁盘缓存
  257. diskPath, err := writeToDiskCache(cleanBase64)
  258. if err != nil {
  259. // 磁盘缓存失败,回退到内存
  260. cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
  261. } else {
  262. cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(decodedData)))
  263. common.IncrementDiskFiles(base64Size)
  264. }
  265. } else {
  266. cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
  267. }
  268. // 如果是图片或 MIME 类型未知,尝试解码图片获取更多信息
  269. if mimeType == "" || strings.HasPrefix(mimeType, "image/") {
  270. config, format, err := decodeImageConfig(decodedData)
  271. if err == nil {
  272. cachedData.ImageConfig = &config
  273. cachedData.ImageFormat = format
  274. if mimeType == "" {
  275. cachedData.MimeType = "image/" + format
  276. }
  277. }
  278. }
  279. return cachedData, nil
  280. }
  281. // GetImageConfig 获取图片配置(宽高等信息)
  282. // 会自动处理缓存,避免重复下载/解码
  283. func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) {
  284. cachedData, err := LoadFileSource(c, source, "get_image_config")
  285. if err != nil {
  286. return image.Config{}, "", err
  287. }
  288. if cachedData.ImageConfig != nil {
  289. return *cachedData.ImageConfig, cachedData.ImageFormat, nil
  290. }
  291. // 如果缓存中没有图片配置,尝试解码
  292. base64Str, err := cachedData.GetBase64Data()
  293. if err != nil {
  294. return image.Config{}, "", fmt.Errorf("failed to get base64 data: %w", err)
  295. }
  296. decodedData, err := base64.StdEncoding.DecodeString(base64Str)
  297. if err != nil {
  298. return image.Config{}, "", fmt.Errorf("failed to decode base64 for image config: %w", err)
  299. }
  300. config, format, err := decodeImageConfig(decodedData)
  301. if err != nil {
  302. return image.Config{}, "", err
  303. }
  304. // 更新缓存
  305. cachedData.ImageConfig = &config
  306. cachedData.ImageFormat = format
  307. return config, format, nil
  308. }
  309. // GetBase64Data 获取 base64 编码的数据
  310. // 会自动处理缓存,避免重复下载
  311. // 支持内存缓存和磁盘缓存
  312. func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) {
  313. cachedData, err := LoadFileSource(c, source, reason...)
  314. if err != nil {
  315. return "", "", err
  316. }
  317. base64Str, err := cachedData.GetBase64Data()
  318. if err != nil {
  319. return "", "", fmt.Errorf("failed to get base64 data: %w", err)
  320. }
  321. return base64Str, cachedData.MimeType, nil
  322. }
  323. // GetMimeType 获取文件的 MIME 类型
  324. func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) {
  325. // 如果已经有缓存,直接返回
  326. if source.HasCache() {
  327. return source.GetCache().MimeType, nil
  328. }
  329. // 如果是 URL,尝试只获取 header 而不下载完整文件
  330. if source.IsURL() {
  331. mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type")
  332. if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
  333. return mimeType, nil
  334. }
  335. }
  336. // 否则加载完整数据
  337. cachedData, err := LoadFileSource(c, source, "get_mime_type")
  338. if err != nil {
  339. return "", err
  340. }
  341. return cachedData.MimeType, nil
  342. }
  343. // DetectFileType 检测文件类型(image/audio/video/file)
  344. func DetectFileType(mimeType string) types.FileType {
  345. if strings.HasPrefix(mimeType, "image/") {
  346. return types.FileTypeImage
  347. }
  348. if strings.HasPrefix(mimeType, "audio/") {
  349. return types.FileTypeAudio
  350. }
  351. if strings.HasPrefix(mimeType, "video/") {
  352. return types.FileTypeVideo
  353. }
  354. return types.FileTypeFile
  355. }
  356. // decodeImageConfig 从字节数据解码图片配置
  357. func decodeImageConfig(data []byte) (image.Config, string, error) {
  358. reader := bytes.NewReader(data)
  359. // 尝试标准格式
  360. config, format, err := image.DecodeConfig(reader)
  361. if err == nil {
  362. return config, format, nil
  363. }
  364. // 尝试 webp
  365. reader.Seek(0, io.SeekStart)
  366. config, err = webp.DecodeConfig(reader)
  367. if err == nil {
  368. return config, "webp", nil
  369. }
  370. return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format")
  371. }
  372. // guessMimeTypeFromURL 从 URL 猜测 MIME 类型
  373. func guessMimeTypeFromURL(url string) string {
  374. // 移除查询参数
  375. cleanedURL := url
  376. if q := strings.Index(cleanedURL, "?"); q != -1 {
  377. cleanedURL = cleanedURL[:q]
  378. }
  379. // 获取最后一段
  380. if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
  381. last := cleanedURL[slash+1:]
  382. if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
  383. ext := strings.ToLower(last[dot+1:])
  384. return GetMimeTypeByExtension(ext)
  385. }
  386. }
  387. return "application/octet-stream"
  388. }