file_decoder.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package service
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "fmt"
  6. "image"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/logger"
  12. "one-api/types"
  13. "strings"
  14. "github.com/gin-gonic/gin"
  15. )
  16. // GetFileTypeFromUrl 获取文件类型,返回 mime type, 例如 image/jpeg, image/png, image/gif, image/bmp, image/tiff, application/pdf
  17. // 如果获取失败,返回 application/octet-stream
  18. func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, error) {
  19. response, err := DoDownloadRequest(url, reason...)
  20. if err != nil {
  21. common.SysLog(fmt.Sprintf("fail to get file type from url: %s, error: %s", url, err.Error()))
  22. return "", err
  23. }
  24. defer response.Body.Close()
  25. if response.StatusCode != 200 {
  26. logger.LogError(c, fmt.Sprintf("failed to download file from %s, status code: %d", url, response.StatusCode))
  27. return "", fmt.Errorf("failed to download file, status code: %d", response.StatusCode)
  28. }
  29. if headerType := strings.TrimSpace(response.Header.Get("Content-Type")); headerType != "" {
  30. if i := strings.Index(headerType, ";"); i != -1 {
  31. headerType = headerType[:i]
  32. }
  33. if headerType != "application/octet-stream" {
  34. return headerType, nil
  35. }
  36. }
  37. if cd := response.Header.Get("Content-Disposition"); cd != "" {
  38. parts := strings.Split(cd, ";")
  39. for _, part := range parts {
  40. part = strings.TrimSpace(part)
  41. if strings.HasPrefix(strings.ToLower(part), "filename=") {
  42. name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
  43. if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
  44. name = name[1 : len(name)-1]
  45. }
  46. if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
  47. ext := strings.ToLower(name[dot+1:])
  48. if ext != "" {
  49. mt := GetMimeTypeByExtension(ext)
  50. if mt != "application/octet-stream" {
  51. return mt, nil
  52. }
  53. }
  54. }
  55. break
  56. }
  57. }
  58. }
  59. cleanedURL := url
  60. if q := strings.Index(cleanedURL, "?"); q != -1 {
  61. cleanedURL = cleanedURL[:q]
  62. }
  63. if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
  64. last := cleanedURL[slash+1:]
  65. if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
  66. ext := strings.ToLower(last[dot+1:])
  67. if ext != "" {
  68. mt := GetMimeTypeByExtension(ext)
  69. if mt != "application/octet-stream" {
  70. return mt, nil
  71. }
  72. }
  73. }
  74. }
  75. var readData []byte
  76. limits := []int{512, 8 * 1024, 24 * 1024, 64 * 1024}
  77. for _, limit := range limits {
  78. logger.LogDebug(c, fmt.Sprintf("Trying to read %d bytes to determine file type", limit))
  79. if len(readData) < limit {
  80. need := limit - len(readData)
  81. tmp := make([]byte, need)
  82. n, _ := io.ReadFull(response.Body, tmp)
  83. if n > 0 {
  84. readData = append(readData, tmp[:n]...)
  85. }
  86. }
  87. if len(readData) == 0 {
  88. continue
  89. }
  90. sniffed := http.DetectContentType(readData)
  91. if sniffed != "" && sniffed != "application/octet-stream" {
  92. return sniffed, nil
  93. }
  94. if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil {
  95. switch strings.ToLower(format) {
  96. case "jpeg", "jpg":
  97. return "image/jpeg", nil
  98. case "png":
  99. return "image/png", nil
  100. case "gif":
  101. return "image/gif", nil
  102. case "bmp":
  103. return "image/bmp", nil
  104. case "tiff":
  105. return "image/tiff", nil
  106. default:
  107. if format != "" {
  108. return "image/" + strings.ToLower(format), nil
  109. }
  110. }
  111. }
  112. }
  113. // Fallback
  114. return "application/octet-stream", nil
  115. }
  116. func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
  117. contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
  118. // Check if the file has already been downloaded in this request
  119. if cachedData, exists := c.Get(contextKey); exists {
  120. if common.DebugEnabled {
  121. logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
  122. }
  123. return cachedData.(*types.LocalFileData), nil
  124. }
  125. var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
  126. resp, err := DoDownloadRequest(url, reason...)
  127. if err != nil {
  128. return nil, err
  129. }
  130. defer resp.Body.Close()
  131. // Always use LimitReader to prevent oversized downloads
  132. fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
  133. if err != nil {
  134. return nil, err
  135. }
  136. // Check actual size after reading
  137. if len(fileBytes) > maxFileSize {
  138. return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
  139. }
  140. // Convert to base64
  141. base64Data := base64.StdEncoding.EncodeToString(fileBytes)
  142. mimeType := resp.Header.Get("Content-Type")
  143. if len(strings.Split(mimeType, ";")) > 1 {
  144. // If Content-Type has parameters, take the first part
  145. mimeType = strings.Split(mimeType, ";")[0]
  146. }
  147. if mimeType == "application/octet-stream" {
  148. logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url))
  149. // try to guess the MIME type from the url last segment
  150. urlParts := strings.Split(url, "/")
  151. if len(urlParts) > 0 {
  152. lastSegment := urlParts[len(urlParts)-1]
  153. if strings.Contains(lastSegment, ".") {
  154. // Extract the file extension
  155. filename := strings.Split(lastSegment, ".")
  156. if len(filename) > 1 {
  157. ext := strings.ToLower(filename[len(filename)-1])
  158. // Guess MIME type based on file extension
  159. mimeType = GetMimeTypeByExtension(ext)
  160. }
  161. }
  162. } else {
  163. // try to guess the MIME type from the file extension
  164. fileName := resp.Header.Get("Content-Disposition")
  165. if fileName != "" {
  166. // Extract the filename from the Content-Disposition header
  167. parts := strings.Split(fileName, ";")
  168. for _, part := range parts {
  169. if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
  170. fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
  171. // Remove quotes if present
  172. if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
  173. fileName = fileName[1 : len(fileName)-1]
  174. }
  175. // Guess MIME type based on file extension
  176. if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
  177. mimeType = GetMimeTypeByExtension(ext)
  178. }
  179. break
  180. }
  181. }
  182. }
  183. }
  184. }
  185. data := &types.LocalFileData{
  186. Base64Data: base64Data,
  187. MimeType: mimeType,
  188. Size: int64(len(fileBytes)),
  189. }
  190. // Store the file data in the context to avoid re-downloading
  191. c.Set(contextKey, data)
  192. return data, nil
  193. }
  // GetMimeTypeByExtension maps a file extension to the MIME type used by this
  // package for it.
  //
  // ext is matched case-insensitively; surrounding whitespace and one leading
  // dot are ignored, so "png", "PNG" and ".png" are equivalent. Unknown or empty
  // extensions yield "application/octet-stream".
  //
  // Note that all text-like formats (including json, xml, html and csv) are
  // reported as "text/plain", and that several audio/video values are not the
  // IANA-registered names (e.g. "audio/mp3", "video/mov"). These are presumably
  // the spellings downstream consumers expect — confirm before normalizing them.
  func GetMimeTypeByExtension(ext string) string {
  	// Normalize: callers may pass "PNG", ".png" or " png ".
  	ext = strings.ToLower(strings.TrimPrefix(strings.TrimSpace(ext), "."))

  	switch ext {
  	// Text files
  	case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
  		return "text/plain"

  	// Image files
  	case "jpg", "jpeg":
  		return "image/jpeg"
  	case "png":
  		return "image/png"
  	case "gif":
  		return "image/gif"
  	case "bmp":
  		return "image/bmp"
  	case "tif", "tiff":
  		return "image/tiff"
  	case "webp":
  		return "image/webp"

  	// Audio files
  	case "mp3":
  		return "audio/mp3"
  	case "wav":
  		return "audio/wav"
  	case "mpeg":
  		return "audio/mpeg"

  	// Video files
  	case "mp4":
  		return "video/mp4"
  	case "wmv":
  		return "video/wmv"
  	case "flv":
  		return "video/flv"
  	case "mov":
  		return "video/mov"
  	case "mpg":
  		return "video/mpg"
  	case "avi":
  		return "video/avi"
  	case "mpegps":
  		return "video/mpegps"

  	// Document files
  	case "pdf":
  		return "application/pdf"

  	default:
  		return "application/octet-stream" // Default for unknown types
  	}
  }