file_service.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. package service
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/binary"
  6. "fmt"
  7. "image"
  8. _ "image/gif"
  9. _ "image/jpeg"
  10. _ "image/png"
  11. "io"
  12. "net/http"
  13. "strings"
  14. "github.com/QuantumNous/new-api/common"
  15. "github.com/QuantumNous/new-api/constant"
  16. "github.com/QuantumNous/new-api/logger"
  17. "github.com/QuantumNous/new-api/types"
  18. "github.com/gin-gonic/gin"
  19. "golang.org/x/image/webp"
  20. )
  21. // FileService 统一的文件处理服务
  22. // 提供文件下载、解码、缓存等功能的统一入口
  23. // getContextCacheKey 生成 URL context 缓存的 key
  24. func getContextCacheKey(url string) string {
  25. return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url))
  26. }
  27. // getBase64ContextCacheKey 生成 base64 context 缓存的 key
  28. // 使用 length + MIME + 前 128 字符作为输入,避免对整个 base64 数据做 hash
  29. func getBase64ContextCacheKey(data string, mimeType string) string {
  30. keyMaterial := fmt.Sprintf("%d:%s:", len(data), mimeType)
  31. if len(data) > 128 {
  32. keyMaterial += data[:128]
  33. } else {
  34. keyMaterial += data
  35. }
  36. return fmt.Sprintf("b64_cache_%s", common.GenerateHMAC(keyMaterial))
  37. }
  38. // LoadFileSource 加载文件源数据
  39. // 这是统一的入口,会自动处理缓存和不同的来源类型
  40. func LoadFileSource(c *gin.Context, source types.FileSource, reason ...string) (*types.CachedFileData, error) {
  41. if source == nil {
  42. return nil, fmt.Errorf("file source is nil")
  43. }
  44. if common.DebugEnabled {
  45. logger.LogDebug(c, fmt.Sprintf("LoadFileSource starting for: %s", source.GetIdentifier()))
  46. }
  47. // 1. 快速检查内部缓存
  48. if source.HasCache() {
  49. if c != nil {
  50. registerSourceForCleanup(c, source)
  51. }
  52. return source.GetCache(), nil
  53. }
  54. // 2. 加锁保护加载过程
  55. source.Mu().Lock()
  56. defer source.Mu().Unlock()
  57. // 3. 双重检查
  58. if source.HasCache() {
  59. if c != nil {
  60. registerSourceForCleanup(c, source)
  61. }
  62. return source.GetCache(), nil
  63. }
  64. // 4. 根据来源类型加载(含 URL context 缓存查找)
  65. var cachedData *types.CachedFileData
  66. var contextKey string
  67. var err error
  68. switch s := source.(type) {
  69. case *types.URLSource:
  70. if c != nil {
  71. contextKey = getContextCacheKey(s.URL)
  72. if cached, exists := c.Get(contextKey); exists {
  73. data := cached.(*types.CachedFileData)
  74. source.SetCache(data)
  75. registerSourceForCleanup(c, source)
  76. return data, nil
  77. }
  78. }
  79. cachedData, err = loadFromURL(c, s.URL, reason...)
  80. case *types.Base64Source:
  81. if c != nil {
  82. contextKey = getBase64ContextCacheKey(s.Base64Data, s.MimeType)
  83. if cached, exists := c.Get(contextKey); exists {
  84. data := cached.(*types.CachedFileData)
  85. source.SetCache(data)
  86. registerSourceForCleanup(c, source)
  87. return data, nil
  88. }
  89. }
  90. cachedData, err = loadFromBase64(s.Base64Data, s.MimeType)
  91. default:
  92. return nil, fmt.Errorf("unsupported file source type: %T", source)
  93. }
  94. if err != nil {
  95. return nil, err
  96. }
  97. // 5. 设置缓存
  98. source.SetCache(cachedData)
  99. if contextKey != "" && c != nil {
  100. c.Set(contextKey, cachedData)
  101. }
  102. // 6. 注册到 context 以便请求结束时自动清理
  103. if c != nil {
  104. registerSourceForCleanup(c, source)
  105. }
  106. return cachedData, nil
  107. }
  108. // registerSourceForCleanup 注册 FileSource 到 context 以便请求结束时清理
  109. func registerSourceForCleanup(c *gin.Context, source types.FileSource) {
  110. if source.IsRegistered() {
  111. return
  112. }
  113. key := string(constant.ContextKeyFileSourcesToCleanup)
  114. var sources []types.FileSource
  115. if existing, exists := c.Get(key); exists {
  116. sources = existing.([]types.FileSource)
  117. }
  118. sources = append(sources, source)
  119. c.Set(key, sources)
  120. source.SetRegistered(true)
  121. }
  122. // CleanupFileSources 清理请求中所有注册的 FileSource
  123. // 应在请求结束时调用(通常由中间件自动调用)
  124. func CleanupFileSources(c *gin.Context) {
  125. key := string(constant.ContextKeyFileSourcesToCleanup)
  126. if sources, exists := c.Get(key); exists {
  127. for _, source := range sources.([]types.FileSource) {
  128. if cache := source.GetCache(); cache != nil {
  129. cache.Close()
  130. }
  131. }
  132. c.Set(key, nil)
  133. }
  134. }
  135. // loadFromURL 从 URL 加载文件
  136. func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) {
  137. // 下载文件
  138. var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
  139. if common.DebugEnabled {
  140. logger.LogDebug(c, "loadFromURL: initiating download")
  141. }
  142. resp, err := DoDownloadRequest(url, reason...)
  143. if err != nil {
  144. return nil, fmt.Errorf("failed to download file from %s: %w", url, err)
  145. }
  146. defer resp.Body.Close()
  147. if resp.StatusCode != 200 {
  148. return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
  149. }
  150. // 读取文件内容(限制大小)
  151. if common.DebugEnabled {
  152. logger.LogDebug(c, "loadFromURL: reading response body")
  153. }
  154. fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
  155. if err != nil {
  156. return nil, fmt.Errorf("failed to read file content: %w", err)
  157. }
  158. if len(fileBytes) > maxFileSize {
  159. return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
  160. }
  161. // 转换为 base64
  162. base64Data := base64.StdEncoding.EncodeToString(fileBytes)
  163. // 智能获取 MIME 类型
  164. mimeType := smartDetectMimeType(resp, url, fileBytes)
  165. // 判断是否使用磁盘缓存
  166. base64Size := int64(len(base64Data))
  167. var cachedData *types.CachedFileData
  168. if shouldUseDiskCache(base64Size) {
  169. // 使用磁盘缓存
  170. diskPath, err := writeToDiskCache(base64Data)
  171. if err != nil {
  172. // 磁盘缓存失败,回退到内存
  173. logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err))
  174. cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
  175. } else {
  176. cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(fileBytes)))
  177. cachedData.DiskSize = base64Size
  178. cachedData.OnClose = func(size int64) {
  179. common.DecrementDiskFiles(size)
  180. }
  181. common.IncrementDiskFiles(base64Size)
  182. if common.DebugEnabled {
  183. logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size))
  184. }
  185. }
  186. } else {
  187. // 使用内存缓存
  188. cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes)))
  189. }
  190. // 如果是图片,尝试获取图片配置
  191. if strings.HasPrefix(mimeType, "image/") {
  192. if common.DebugEnabled {
  193. logger.LogDebug(c, "loadFromURL: decoding image config")
  194. }
  195. config, format, err := decodeImageConfig(fileBytes)
  196. if err == nil {
  197. cachedData.ImageConfig = &config
  198. cachedData.ImageFormat = format
  199. // 如果通过图片解码获取了更准确的格式,更新 MIME 类型
  200. if mimeType == "application/octet-stream" || mimeType == "" {
  201. cachedData.MimeType = "image/" + format
  202. }
  203. }
  204. }
  205. return cachedData, nil
  206. }
  207. // shouldUseDiskCache 判断是否应该使用磁盘缓存
  208. func shouldUseDiskCache(dataSize int64) bool {
  209. return common.ShouldUseDiskCache(dataSize)
  210. }
  211. // writeToDiskCache 将数据写入磁盘缓存
  212. func writeToDiskCache(base64Data string) (string, error) {
  213. return common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data)
  214. }
  215. // smartDetectMimeType 智能检测 MIME 类型
  216. func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string {
  217. // 1. 尝试从 Content-Type header 获取
  218. mimeType := resp.Header.Get("Content-Type")
  219. if idx := strings.Index(mimeType, ";"); idx != -1 {
  220. mimeType = strings.TrimSpace(mimeType[:idx])
  221. }
  222. if mimeType != "" && mimeType != "application/octet-stream" {
  223. return mimeType
  224. }
  225. // 2. 尝试从 Content-Disposition header 的 filename 获取
  226. if cd := resp.Header.Get("Content-Disposition"); cd != "" {
  227. parts := strings.Split(cd, ";")
  228. for _, part := range parts {
  229. part = strings.TrimSpace(part)
  230. if strings.HasPrefix(strings.ToLower(part), "filename=") {
  231. name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
  232. // 移除引号
  233. if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
  234. name = name[1 : len(name)-1]
  235. }
  236. if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
  237. ext := strings.ToLower(name[dot+1:])
  238. if ext != "" {
  239. mt := GetMimeTypeByExtension(ext)
  240. if mt != "application/octet-stream" {
  241. return mt
  242. }
  243. }
  244. }
  245. break
  246. }
  247. }
  248. }
  249. // 3. 尝试从 URL 路径获取扩展名
  250. mt := guessMimeTypeFromURL(url)
  251. if mt != "application/octet-stream" {
  252. return mt
  253. }
  254. // 4. 使用 http.DetectContentType 内容嗅探
  255. if len(fileBytes) > 0 {
  256. sniffed := http.DetectContentType(fileBytes)
  257. if sniffed != "" && sniffed != "application/octet-stream" {
  258. // 去除可能的 charset 参数
  259. if idx := strings.Index(sniffed, ";"); idx != -1 {
  260. sniffed = strings.TrimSpace(sniffed[:idx])
  261. }
  262. return sniffed
  263. }
  264. // 4.5 尝试 HEIF/HEIC 检测(Go 标准库不识别)
  265. if heifMime := detectHEIF(fileBytes); heifMime != "" {
  266. return heifMime
  267. }
  268. }
  269. // 5. 尝试作为图片解码获取格式
  270. if len(fileBytes) > 0 {
  271. if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" {
  272. return "image/" + strings.ToLower(format)
  273. }
  274. }
  275. // 最终回退
  276. return "application/octet-stream"
  277. }
  278. // loadFromBase64 从 base64 字符串加载文件
  279. func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) {
  280. var mimeType string
  281. var cleanBase64 string
  282. // 处理 data: 前缀
  283. if strings.HasPrefix(base64String, "data:") {
  284. idx := strings.Index(base64String, ",")
  285. if idx != -1 {
  286. header := base64String[:idx]
  287. cleanBase64 = base64String[idx+1:]
  288. if strings.Contains(header, ":") && strings.Contains(header, ";") {
  289. mimeStart := strings.Index(header, ":") + 1
  290. mimeEnd := strings.Index(header, ";")
  291. if mimeStart < mimeEnd {
  292. mimeType = header[mimeStart:mimeEnd]
  293. }
  294. }
  295. } else {
  296. cleanBase64 = base64String
  297. }
  298. } else {
  299. cleanBase64 = base64String
  300. }
  301. if providedMimeType != "" {
  302. mimeType = providedMimeType
  303. }
  304. decodedData, err := base64.StdEncoding.DecodeString(cleanBase64)
  305. if err != nil {
  306. return nil, fmt.Errorf("failed to decode base64 data: %w", err)
  307. }
  308. base64Size := int64(len(cleanBase64))
  309. var cachedData *types.CachedFileData
  310. if shouldUseDiskCache(base64Size) {
  311. diskPath, err := writeToDiskCache(cleanBase64)
  312. if err != nil {
  313. cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
  314. } else {
  315. cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(decodedData)))
  316. cachedData.DiskSize = base64Size
  317. cachedData.OnClose = func(size int64) {
  318. common.DecrementDiskFiles(size)
  319. }
  320. common.IncrementDiskFiles(base64Size)
  321. }
  322. } else {
  323. cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData)))
  324. }
  325. if mimeType == "" || strings.HasPrefix(mimeType, "image/") {
  326. config, format, err := decodeImageConfig(decodedData)
  327. if err == nil {
  328. cachedData.ImageConfig = &config
  329. cachedData.ImageFormat = format
  330. if mimeType == "" {
  331. cachedData.MimeType = "image/" + format
  332. }
  333. }
  334. }
  335. return cachedData, nil
  336. }
  337. // GetImageConfig 获取图片配置
  338. func GetImageConfig(c *gin.Context, source types.FileSource) (image.Config, string, error) {
  339. cachedData, err := LoadFileSource(c, source, "get_image_config")
  340. if err != nil {
  341. return image.Config{}, "", err
  342. }
  343. if cachedData.ImageConfig != nil {
  344. return *cachedData.ImageConfig, cachedData.ImageFormat, nil
  345. }
  346. base64Str, err := cachedData.GetBase64Data()
  347. if err != nil {
  348. return image.Config{}, "", fmt.Errorf("failed to get base64 data: %w", err)
  349. }
  350. decodedData, err := base64.StdEncoding.DecodeString(base64Str)
  351. if err != nil {
  352. return image.Config{}, "", fmt.Errorf("failed to decode base64 for image config: %w", err)
  353. }
  354. config, format, err := decodeImageConfig(decodedData)
  355. if err != nil {
  356. return image.Config{}, "", err
  357. }
  358. cachedData.ImageConfig = &config
  359. cachedData.ImageFormat = format
  360. return config, format, nil
  361. }
  362. // GetBase64Data 获取 base64 编码的数据
  363. func GetBase64Data(c *gin.Context, source types.FileSource, reason ...string) (string, string, error) {
  364. cachedData, err := LoadFileSource(c, source, reason...)
  365. if err != nil {
  366. return "", "", err
  367. }
  368. base64Str, err := cachedData.GetBase64Data()
  369. if err != nil {
  370. return "", "", fmt.Errorf("failed to get base64 data: %w", err)
  371. }
  372. return base64Str, cachedData.MimeType, nil
  373. }
  374. // GetMimeType 获取文件的 MIME 类型
  375. func GetMimeType(c *gin.Context, source types.FileSource) (string, error) {
  376. if source.HasCache() {
  377. return source.GetCache().MimeType, nil
  378. }
  379. if urlSource, ok := source.(*types.URLSource); ok {
  380. mimeType, err := GetFileTypeFromUrl(c, urlSource.URL, "get_mime_type")
  381. if err == nil && mimeType != "" && mimeType != "application/octet-stream" {
  382. return mimeType, nil
  383. }
  384. }
  385. cachedData, err := LoadFileSource(c, source, "get_mime_type")
  386. if err != nil {
  387. return "", err
  388. }
  389. return cachedData.MimeType, nil
  390. }
  391. // DetectFileType 检测文件类型
  392. func DetectFileType(mimeType string) types.FileType {
  393. if strings.HasPrefix(mimeType, "image/") {
  394. return types.FileTypeImage
  395. }
  396. if strings.HasPrefix(mimeType, "audio/") {
  397. return types.FileTypeAudio
  398. }
  399. if strings.HasPrefix(mimeType, "video/") {
  400. return types.FileTypeVideo
  401. }
  402. return types.FileTypeFile
  403. }
  404. // decodeImageConfig 从字节数据解码图片配置
  405. func decodeImageConfig(data []byte) (image.Config, string, error) {
  406. reader := bytes.NewReader(data)
  407. config, format, err := image.DecodeConfig(reader)
  408. if err == nil {
  409. return config, format, nil
  410. }
  411. reader.Seek(0, io.SeekStart)
  412. config, err = webp.DecodeConfig(reader)
  413. if err == nil {
  414. return config, "webp", nil
  415. }
  416. // Try HEIF/HEIC: parse ISOBMFF ispe box for dimensions
  417. if heifMime := detectHEIF(data); heifMime != "" {
  418. formatName := "heif"
  419. if heifMime == "image/heic" {
  420. formatName = "heic"
  421. }
  422. if w, h, ok := parseHEIFDimensions(data); ok {
  423. return image.Config{Width: w, Height: h}, formatName, nil
  424. }
  425. return image.Config{}, "", fmt.Errorf("failed to decode HEIF/HEIC image dimensions")
  426. }
  427. return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format")
  428. }
  // detectHEIF inspects the leading ISOBMFF "ftyp" box to recognise HEIC/HEIF
  // files. Bytes 4-7 must be "ftyp"; the brand in bytes 8-11 selects the
  // result: "image/heic", "image/heif", or "" when the data is too short or the
  // brand is not one of the known values.
  func detectHEIF(data []byte) string {
  	const headerLen = 12 // 4-byte size + "ftyp" + 4-byte brand

  	if len(data) < headerLen || string(data[4:8]) != "ftyp" {
  		return ""
  	}

  	switch string(data[8:headerLen]) {
  	case "mif1", "msf1":
  		return "image/heif"
  	case "heic", "heix", "hevc", "hevx", "heim", "heis":
  		return "image/heic"
  	}
  	return ""
  }
  // parseHEIFDimensions walks the top-level ISOBMFF boxes in data, locates the
  // "meta" box and extracts the image size from the ispe box inside it.
  // It returns (width, height, true) on success and (0, 0, false) when the
  // data is malformed, truncated, or has no usable ispe box.
  //
  // Box sizes come from untrusted input, so every size is validated against the
  // bytes remaining (as an unsigned comparison) before it is converted to int
  // or added to the offset. A crafted 64-bit size can therefore not overflow
  // the offset arithmetic and cause an out-of-range slice panic.
  func parseHEIFDimensions(data []byte) (int, int, bool) {
  	size := len(data)
  	if size < 12 {
  		return 0, 0, false
  	}

  	offset := 0
  	for size-offset >= 8 {
  		remaining := size - offset
  		boxSize64 := uint64(binary.BigEndian.Uint32(data[offset : offset+4]))
  		boxType := string(data[offset+4 : offset+8])
  		headerLen := 8

  		switch boxSize64 {
  		case 1:
  			// Size 1 means the real size is a 64-bit field after the type.
  			if remaining < 16 {
  				return 0, 0, false
  			}
  			boxSize64 = binary.BigEndian.Uint64(data[offset+8 : offset+16])
  			headerLen = 16
  		case 0:
  			// Size 0 means the box extends to the end of the data.
  			boxSize64 = uint64(remaining)
  		}

  		if boxSize64 < uint64(headerLen) || boxSize64 > uint64(remaining) {
  			return 0, 0, false
  		}
  		boxSize := int(boxSize64) // safe: bounded by remaining

  		if boxType == "meta" {
  			// meta is a full box: 4 bytes of version/flags follow the header.
  			metaData := data[offset+headerLen : offset+boxSize]
  			if len(metaData) < 4 {
  				return 0, 0, false
  			}
  			return findISPE(metaData[4:])
  		}

  		offset += boxSize
  	}
  	return 0, 0, false
  }

  // findISPE searches a sequence of boxes for an ispe box, descending into
  // "iprp" and "ipco" containers (expected path: meta -> iprp -> ipco -> ispe).
  // It returns the first ispe whose width and height are both positive.
  func findISPE(data []byte) (int, int, bool) {
  	return findISPEAtDepth(data, 0)
  }

  // findISPEAtDepth implements findISPE with a nesting limit so that input made
  // of deeply nested container boxes cannot drive unbounded recursion.
  func findISPEAtDepth(data []byte, depth int) (int, int, bool) {
  	// The expected path needs two container levels; 16 leaves ample headroom.
  	const maxDepth = 16
  	if depth > maxDepth {
  		return 0, 0, false
  	}

  	offset := 0
  	size := len(data)
  	for size-offset >= 8 {
  		rawSize := binary.BigEndian.Uint32(data[offset : offset+4])
  		boxType := string(data[offset+4 : offset+8])
  		// Validate before converting so the check cannot overflow.
  		if rawSize < 8 || uint64(rawSize) > uint64(size-offset) {
  			break
  		}
  		boxSize := int(rawSize)
  		content := data[offset+8 : offset+boxSize]

  		switch boxType {
  		case "iprp", "ipco":
  			if w, h, ok := findISPEAtDepth(content, depth+1); ok {
  				return w, h, true
  			}
  		case "ispe":
  			// ispe is a full box: 4 bytes version/flags, 4 bytes width, 4 bytes height.
  			if len(content) >= 12 {
  				w := int(binary.BigEndian.Uint32(content[4:8]))
  				h := int(binary.BigEndian.Uint32(content[8:12]))
  				if w > 0 && h > 0 {
  					return w, h, true
  				}
  			}
  		}

  		offset += boxSize
  	}
  	return 0, 0, false
  }
  519. // guessMimeTypeFromURL 从 URL 猜测 MIME 类型
  520. func guessMimeTypeFromURL(url string) string {
  521. cleanedURL := url
  522. if q := strings.Index(cleanedURL, "?"); q != -1 {
  523. cleanedURL = cleanedURL[:q]
  524. }
  525. if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
  526. last := cleanedURL[slash+1:]
  527. if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
  528. ext := strings.ToLower(last[dot+1:])
  529. return GetMimeTypeByExtension(ext)
  530. }
  531. }
  532. return "application/octet-stream"
  533. }