file_service.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. package service
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/binary"
  6. "fmt"
  7. "image"
  8. _ "image/gif"
  9. _ "image/jpeg"
  10. _ "image/png"
  11. "io"
  12. "net/http"
  13. "strings"
  14. "github.com/QuantumNous/new-api/common"
  15. "github.com/QuantumNous/new-api/constant"
  16. "github.com/QuantumNous/new-api/logger"
  17. "github.com/QuantumNous/new-api/types"
  18. "github.com/gin-gonic/gin"
  19. "golang.org/x/image/webp"
  20. )
  21. // FileService 统一的文件处理服务
  22. // 提供文件下载、解码、缓存等功能的统一入口
  23. // getContextCacheKey 生成 context 缓存的 key
  24. func getContextCacheKey(url string) string {
  25. return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url))
  26. }
  // LoadFileSource loads the content behind source and returns its cached form.
//
// It is the single entry point for file loading and handles caching for both
// URL and inline base64 sources:
//   - the result is memoized on the source itself (guarded by source.Mu());
//   - for URL sources with a non-nil c, the result is also stored on the gin
//     context, so other FileSource values for the same URL in the same request
//     reuse it instead of loading again;
//   - every source that obtains a cache is registered on c so that
//     CleanupFileSources can release it when the request ends.
//
// c may be nil; context caching and cleanup registration are then skipped.
// reason is forwarded to the download helper for URL sources. An error is
// returned when source is nil or when loading/decoding fails.
func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) {
	if source == nil {
		return nil, fmt.Errorf("file source is nil")
	}
	if common.DebugEnabled {
		logger.LogDebug(c, fmt.Sprintf("LoadFileSource starting for: %s", source.GetIdentifier()))
	}

	// 1. Fast path: the source is already loaded. Even on a hit, make sure it
	// is registered for cleanup (registration is skipped if already done).
	if source.HasCache() {
		if c != nil {
			registerSourceForCleanup(c, source)
		}
		return source.GetCache(), nil
	}

	// 2. Serialize loading of this source.
	source.Mu().Lock()
	defer source.Mu().Unlock()

	// 3. Re-check under the lock: another caller may have completed the load
	// while this one was waiting.
	if source.HasCache() {
		if c != nil {
			registerSourceForCleanup(c, source)
		}
		return source.GetCache(), nil
	}

	// 4. URL sources: reuse data already loaded for the same URL in this request.
	var contextKey string
	if source.IsURL() && c != nil {
		contextKey = getContextCacheKey(source.URL)
		if value, exists := c.Get(contextKey); exists {
			// Checked assertion: an unexpected value under this key falls through
			// to a fresh load instead of panicking the request.
			if data, ok := value.(*types.CachedFileData); ok && data != nil {
				source.SetCache(data)
				registerSourceForCleanup(c, source)
				return data, nil
			}
		}
	}

	// 5. Load from the underlying source.
	var cachedData *types.CachedFileData
	var err error
	if source.IsURL() {
		cachedData, err = loadFromURL(c, source.URL, reason...)
	} else {
		cachedData, err = loadFromBase64(source.Base64Data, source.MimeType)
	}
	if err != nil {
		return nil, err
	}

	// 6. Memoize on the source and, for URLs, on the request context.
	source.SetCache(cachedData)
	if contextKey != "" && c != nil {
		c.Set(contextKey, cachedData)
	}

	// 7. Register so the cache is released when the request finishes.
	if c != nil {
		registerSourceForCleanup(c, source)
	}
	return cachedData, nil
}
  87. // registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理
  88. func registerSourceForCleanup(c *gin.Context, source *types.FileSource) {
  89. if source.IsRegistered() {
  90. return
  91. }
  92. key := string(constant.ContextKeyFileSourcesToCleanup)
  93. var sources []*types.FileSource
  94. if existing, exists := c.Get(key); exists {
  95. sources = existing.([]*types.FileSource)
  96. }
  97. sources = append(sources, source)
  98. c.Set(key, sources)
  99. source.SetRegistered(true)
  100. }
  101. // CleanupFileSources 清理请求中所有注册的 FileSource
  102. // 应在请求结束时调用(通常由中间件自动调用)
  103. func CleanupFileSources(c *gin.Context) {
  104. key := string(constant.ContextKeyFileSourcesToCleanup)
  105. if sources, exists := c.Get(key); exists {
  106. for _, source := range sources.([]*types.FileSource) {
  107. if cache := source.GetCache(); cache != nil {
  108. cache.Close()
  109. }
  110. }
  111. c.Set(key, nil) // 清除引用
  112. }
  113. }
  // loadFromURL downloads url (via DoDownloadRequest, forwarding reason) and wraps
// the body in a CachedFileData.
//
// Failure conditions:
//   - a non-200 status is an error;
//   - a body larger than constant.MaxFileDownloadMB MiB is an error.
//
// Storage: the content is kept base64-encoded. It goes on disk when
// shouldUseDiskCache approves the encoded size, falling back to memory (with a
// warning) if the disk write fails. Otherwise it is kept in memory.
//
// MIME type: determined by smartDetectMimeType. For image/* types the image
// header is decoded to fill ImageConfig/ImageFormat. A decode failure there is
// not an error.
func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) {
	var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
	if common.DebugEnabled {
		logger.LogDebug(c, "loadFromURL: initiating download")
	}
	resp, err := DoDownloadRequest(url, reason...)
	if err != nil {
		return nil, fmt.Errorf("failed to download file from %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
	}

	if common.DebugEnabled {
		logger.LogDebug(c, "loadFromURL: reading response body")
	}
	// Read one byte past the limit so an oversized body can be detected.
	fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
	if err != nil {
		return nil, fmt.Errorf("failed to read file content: %w", err)
	}
	if len(fileBytes) > maxFileSize {
		return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
	}

	base64Data := base64.StdEncoding.EncodeToString(fileBytes)
	mimeType := smartDetectMimeType(resp, url, fileBytes)

	// The disk/memory decision is made on the encoded size, since that is what is stored.
	base64Size := int64(len(base64Data))
	var cachedData *types.CachedFileData
	if shouldUseDiskCache(base64Size) {
		diskPath, err := writeToDiskCache(base64Data)
		if err != nil {
			// Disk caching is best-effort: keep the data in memory instead.
			logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err))
			cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
		} else {
			cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(fileBytes)))
			cachedData.DiskSize = base64Size
			// Pair the IncrementDiskFiles below with a DecrementDiskFiles when the cache is closed.
			cachedData.OnClose = func(size int64) {
				common.DecrementDiskFiles(size)
			}
			common.IncrementDiskFiles(base64Size)
			if common.DebugEnabled {
				logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size))
			}
		}
	} else {
		cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
	}

	// For images, record dimensions and the decoded format. No MIME correction
	// is needed here. smartDetectMimeType already falls back to image decoding
	// before returning "application/octet-stream", and an image/* type is never
	// empty or generic.
	if strings.HasPrefix(mimeType, "image/") {
		if common.DebugEnabled {
			logger.LogDebug(c, "loadFromURL: decoding image config")
		}
		if config, format, err := decodeImageConfig(fileBytes); err == nil {
			cachedData.ImageConfig = &config
			cachedData.ImageFormat = format
		}
	}
	return cachedData, nil
}
  186. // shouldUseDiskCache 判断是否应该使用磁盘缓存
  187. func shouldUseDiskCache(dataSize int64) bool {
  188. return common.ShouldUseDiskCache(dataSize)
  189. }
  190. // writeToDiskCache 将数据写入磁盘缓存
  191. func writeToDiskCache(base64Data string) (string, error) {
  192. return common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data)
  193. }
  194. // smartDetectMimeType 智能检测 MIME 类型
  195. func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string {
  196. // 1. 尝试从 Content-Type header 获取
  197. mimeType := resp.Header.Get("Content-Type")
  198. if idx := strings.Index(mimeType, ";"); idx != -1 {
  199. mimeType = strings.TrimSpace(mimeType[:idx])
  200. }
  201. if mimeType != "" && mimeType != "application/octet-stream" {
  202. return mimeType
  203. }
  204. // 2. 尝试从 Content-Disposition header 的 filename 获取
  205. if cd := resp.Header.Get("Content-Disposition"); cd != "" {
  206. parts := strings.Split(cd, ";")
  207. for _, part := range parts {
  208. part = strings.TrimSpace(part)
  209. if strings.HasPrefix(strings.ToLower(part), "filename=") {
  210. name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
  211. // 移除引号
  212. if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
  213. name = name[1 : len(name)-1]
  214. }
  215. if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
  216. ext := strings.ToLower(name[dot+1:])
  217. if ext != "" {
  218. mt := GetMimeTypeByExtension(ext)
  219. if mt != "application/octet-stream" {
  220. return mt
  221. }
  222. }
  223. }
  224. break
  225. }
  226. }
  227. }
  228. // 3. 尝试从 URL 路径获取扩展名
  229. mt := guessMimeTypeFromURL(url)
  230. if mt != "application/octet-stream" {
  231. return mt
  232. }
  233. // 4. 使用 http.DetectContentType 内容嗅探
  234. if len(fileBytes) > 0 {
  235. sniffed := http.DetectContentType(fileBytes)
  236. if sniffed != "" && sniffed != "application/octet-stream" {
  237. // 去除可能的 charset 参数
  238. if idx := strings.Index(sniffed, ";"); idx != -1 {
  239. sniffed = strings.TrimSpace(sniffed[:idx])
  240. }
  241. return sniffed
  242. }
  243. // 4.5 尝试 HEIF/HEIC 检测(Go 标准库不识别)
  244. if heifMime := detectHEIF(fileBytes); heifMime != "" {
  245. return heifMime
  246. }
  247. }
  248. // 5. 尝试作为图片解码获取格式
  249. if len(fileBytes) > 0 {
  250. if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" {
  251. return "image/" + strings.ToLower(format)
  252. }
  253. }
  254. // 最终回退
  255. return "application/octet-stream"
  256. }
  // loadFromBase64 builds a CachedFileData from inline base64 content.
//
// Input: base64String is either a bare base64 payload or a data URI of the
// form "data:<mime>;...,<payload>". For a data URI the MIME type is the text
// between "data:" and the first ';' of the header. A non-empty
// providedMimeType always takes precedence. The payload is decoded with
// base64.StdEncoding, and a decoding failure is returned as an error.
//
// Storage: the still-encoded payload is kept on disk when shouldUseDiskCache
// approves its length, falling back silently to memory if the disk write
// fails. Otherwise it is kept in memory.
//
// Images: if the MIME type is empty or image/*, the image header is decoded to
// fill ImageConfig/ImageFormat, and an empty MIME type becomes "image/<format>".
func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) {
	mimeType := ""
	payload := base64String
	if strings.HasPrefix(base64String, "data:") {
		if header, body, found := strings.Cut(base64String, ","); found {
			payload = body
			colon := strings.Index(header, ":")
			semicolon := strings.Index(header, ";")
			if colon >= 0 && semicolon >= 0 && colon+1 < semicolon {
				mimeType = header[colon+1 : semicolon]
			}
		}
	}
	if providedMimeType != "" {
		mimeType = providedMimeType
	}

	raw, err := base64.StdEncoding.DecodeString(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to decode base64 data: %w", err)
	}
	encodedLen := int64(len(payload))
	rawLen := int64(len(raw))

	var cached *types.CachedFileData
	if shouldUseDiskCache(encodedLen) {
		// Disk caching is best-effort. On a write error cached stays nil and
		// the in-memory branch below takes over.
		if diskPath, writeErr := writeToDiskCache(payload); writeErr == nil {
			cached = types.NewDiskCachedData(diskPath, mimeType, rawLen)
			cached.DiskSize = encodedLen
			cached.OnClose = func(size int64) {
				common.DecrementDiskFiles(size)
			}
			common.IncrementDiskFiles(encodedLen)
		}
	}
	if cached == nil {
		cached = types.NewMemoryCachedData(payload, mimeType, rawLen)
	}

	// Only untyped or image content is probed for image metadata.
	if mimeType != "" && !strings.HasPrefix(mimeType, "image/") {
		return cached, nil
	}
	cfg, format, decodeErr := decodeImageConfig(raw)
	if decodeErr != nil {
		return cached, nil
	}
	cached.ImageConfig = &cfg
	cached.ImageFormat = format
	if mimeType == "" {
		cached.MimeType = "image/" + format
	}
	return cached, nil
}
  // GetImageConfig returns the image configuration and format name of source,
// loading the source first if needed.
//
// If loading already produced ImageConfig, that result is returned. Otherwise
// the cached base64 content is decoded and its image header parsed. A
// successful result is stored back on the cached data for later calls.
//
// NOTE(review): the write-back to the cached data is not synchronized; confirm
// the same CachedFileData is not used from several goroutines at once.
func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) {
	data, err := LoadFileSource(c, source, "get_image_config")
	if err != nil {
		return image.Config{}, "", err
	}
	if known := data.ImageConfig; known != nil {
		return *known, data.ImageFormat, nil
	}

	encoded, err := data.GetBase64Data()
	if err != nil {
		return image.Config{}, "", fmt.Errorf("failed to get base64 data: %w", err)
	}
	raw, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return image.Config{}, "", fmt.Errorf("failed to decode base64 for image config: %w", err)
	}
	cfg, format, err := decodeImageConfig(raw)
	if err != nil {
		return image.Config{}, "", err
	}
	data.ImageConfig, data.ImageFormat = &cfg, format
	return cfg, format, nil
}
  // GetBase64Data returns the base64-encoded content of source together with its
// MIME type, loading the source first if needed. reason is forwarded to
// LoadFileSource.
func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) {
	data, loadErr := LoadFileSource(c, source, reason...)
	if loadErr != nil {
		return "", "", loadErr
	}
	encoded, readErr := data.GetBase64Data()
	if readErr != nil {
		return "", "", fmt.Errorf("failed to get base64 data: %w", readErr)
	}
	return encoded, data.MimeType, nil
}
  // GetMimeType returns the MIME type of source.
//
// Resolution order:
//  1. A source that is already loaded answers from its cache.
//  2. For URL sources, GetFileTypeFromUrl is tried next. If it errors or
//     reports an empty or "application/octet-stream" type, the code falls
//     through to the next step.
//  3. The source is loaded through LoadFileSource, and the MIME type detected
//     during loading is returned.
//
// A nil source is reported as an error, matching LoadFileSource, instead of
// being dereferenced.
func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) {
	if source == nil {
		return "", fmt.Errorf("file source is nil")
	}
	if source.HasCache() {
		return source.GetCache().MimeType, nil
	}
	if source.IsURL() {
		// Best-effort probe; any failure falls through to a full load.
		mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type")
		if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
			return mimeType, nil
		}
	}
	cachedData, err := LoadFileSource(c, source, "get_mime_type")
	if err != nil {
		return "", err
	}
	return cachedData.MimeType, nil
}
  // DetectFileType maps a MIME type to a coarse FileType category:
//   - image/* maps to FileTypeImage;
//   - audio/* maps to FileTypeAudio;
//   - video/* maps to FileTypeVideo;
//   - anything else, including an empty string, maps to FileTypeFile.
//
// Matching is case-insensitive because MIME type names are, so "Image/PNG" is
// still classified as an image.
func DetectFileType(mimeType string) types.FileType {
	normalized := strings.ToLower(mimeType)
	switch {
	case strings.HasPrefix(normalized, "image/"):
		return types.FileTypeImage
	case strings.HasPrefix(normalized, "audio/"):
		return types.FileTypeAudio
	case strings.HasPrefix(normalized, "video/"):
		return types.FileTypeVideo
	default:
		return types.FileTypeFile
	}
}
  // decodeImageConfig parses only the header of an encoded image. It returns the
// image configuration (dimensions, colour model) and a short format name.
//
// Decoders are tried in order:
//  1. the decoders registered with package image (gif, jpeg and png via this
//     file's blank imports);
//  2. golang.org/x/image/webp;
//  3. HEIF/HEIC, whose dimensions are read from the ISOBMFF "ispe" box. These
//     results carry only Width and Height.
//
// An error is returned when no decoder accepts the data.
func decodeImageConfig(data []byte) (image.Config, string, error) {
	if cfg, name, err := image.DecodeConfig(bytes.NewReader(data)); err == nil {
		return cfg, name, nil
	}
	// A fresh reader is equivalent to rewinding the previous one.
	if cfg, err := webp.DecodeConfig(bytes.NewReader(data)); err == nil {
		return cfg, "webp", nil
	}

	heifMime := detectHEIF(data)
	if heifMime == "" {
		return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format")
	}
	formatName := "heif"
	if heifMime == "image/heic" {
		formatName = "heic"
	}
	width, height, ok := parseHEIFDimensions(data)
	if !ok {
		return image.Config{}, "", fmt.Errorf("failed to decode HEIF/HEIC image dimensions")
	}
	return image.Config{Width: width, Height: height}, formatName, nil
}
  // detectHEIF recognizes HEIC/HEIF files by their leading ISOBMFF "ftyp" box.
//
// It requires "ftyp" at bytes 4..8 and then inspects only the major brand at
// bytes 8..12:
//   - heic, heix, hevc, hevx, heim and heis give "image/heic";
//   - mif1 and msf1 give "image/heif";
//   - anything else, including input shorter than 12 bytes, gives "".
func detectHEIF(data []byte) string {
	if len(data) < 12 || string(data[4:8]) != "ftyp" {
		return ""
	}
	mimeType := ""
	switch string(data[8:12]) {
	case "heic", "heix", "hevc", "hevx", "heim", "heis":
		mimeType = "image/heic"
	case "mif1", "msf1":
		mimeType = "image/heif"
	}
	return mimeType
}
  // parseHEIFDimensions extracts image dimensions from an ISOBMFF (HEIF/HEIC)
// file. It walks the top-level boxes until it finds "meta", skips that full
// box's 4-byte version/flags field, and passes the remainder to findISPE.
//
// Special box sizes:
//   - size 1: a 64-bit largesize follows the box type;
//   - size 0: the box runs to the end of data.
//
// Returns (width, height, true) on success, and (0, 0, false) for short,
// malformed or meta-less input.
//
// Bounds are checked as "boxSize > size-offset" rather than
// "offset+boxSize > size". A huge declared size (notably a 64-bit largesize)
// could otherwise overflow the addition, pass the check, and panic on the
// following slice. The input may be untrusted downloaded data.
func parseHEIFDimensions(data []byte) (int, int, bool) {
	size := len(data)
	if size < 12 {
		return 0, 0, false
	}
	offset := 0
	for size-offset >= 8 {
		remaining := size - offset
		boxSize := int(binary.BigEndian.Uint32(data[offset : offset+4]))
		boxType := string(data[offset+4 : offset+8])
		headerLen := 8
		switch boxSize {
		case 1:
			// 64-bit extended size. Compare as uint64 before converting so a
			// value outside the int range cannot wrap around.
			if remaining < 16 {
				return 0, 0, false
			}
			largeSize := binary.BigEndian.Uint64(data[offset+8 : offset+16])
			if largeSize > uint64(remaining) {
				return 0, 0, false
			}
			boxSize = int(largeSize)
			headerLen = 16
		case 0:
			// The box extends to the end of the data.
			boxSize = remaining
		}
		if boxSize < headerLen || boxSize > remaining {
			return 0, 0, false
		}
		if boxType == "meta" {
			// meta is a full box: skip its 4-byte version/flags field.
			metaData := data[offset+headerLen : offset+boxSize]
			if len(metaData) < 4 {
				return 0, 0, false
			}
			return findISPE(metaData[4:])
		}
		offset += boxSize
	}
	return 0, 0, false
}
  467. // findISPE recursively searches for the ispe box within container boxes.
  468. // Path: meta -> iprp -> ipco -> ispe
  469. func findISPE(data []byte) (int, int, bool) {
  470. offset := 0
  471. size := len(data)
  472. for offset+8 <= size {
  473. boxSize := int(binary.BigEndian.Uint32(data[offset : offset+4]))
  474. boxType := string(data[offset+4 : offset+8])
  475. if boxSize < 8 || offset+boxSize > size {
  476. break
  477. }
  478. content := data[offset+8 : offset+boxSize]
  479. switch boxType {
  480. case "iprp", "ipco":
  481. if w, h, ok := findISPE(content); ok {
  482. return w, h, true
  483. }
  484. case "ispe":
  485. // ispe is a full box: 4 bytes version/flags, then 4 bytes width, 4 bytes height
  486. if len(content) >= 12 {
  487. w := int(binary.BigEndian.Uint32(content[4:8]))
  488. h := int(binary.BigEndian.Uint32(content[8:12]))
  489. if w > 0 && h > 0 {
  490. return w, h, true
  491. }
  492. }
  493. }
  494. offset += boxSize
  495. }
  496. return 0, 0, false
  497. }
  // guessMimeTypeFromURL guesses a MIME type from the file extension of the last
// path segment of url, via GetMimeTypeByExtension (extension lower-cased).
//
// Steps:
//  1. The query string and fragment are ignored.
//  2. For "scheme://authority/..." URLs the authority is skipped, so a bare
//     host such as "https://example.com" is not read as a file named
//     "example.com".
//  3. The extension of the last path segment is looked up.
//
// Returns "application/octet-stream" when no extension can be found.
func guessMimeTypeFromURL(url string) string {
	p := url
	if cut := strings.IndexAny(p, "?#"); cut != -1 {
		p = p[:cut]
	}
	if _, rest, ok := strings.Cut(p, "://"); ok {
		if slash := strings.IndexByte(rest, '/'); slash != -1 {
			p = rest[slash:]
		} else {
			p = "" // no path at all
		}
	}
	if slash := strings.LastIndex(p, "/"); slash != -1 && slash+1 < len(p) {
		last := p[slash+1:]
		if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
			return GetMimeTypeByExtension(strings.ToLower(last[dot+1:]))
		}
	}
	return "application/octet-stream"
}