body_storage.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. package common
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "os"
  7. "path/filepath"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. "github.com/google/uuid"
  12. )
  // BodyStorage holds a buffered request body that can be re-read.
//
// On top of the embedded io.ReadSeekCloser (Read, Seek, Close) it exposes the
// whole content, its length, and whether it is backed by disk or memory.
type BodyStorage interface {
	io.ReadSeekCloser

	// Bytes returns the entire stored body.
	Bytes() ([]byte, error)
	// Size returns the number of stored bytes.
	Size() int64
	// IsDisk reports whether the body is backed by a file on disk.
	IsDisk() bool
}
  var (
	// ErrStorageClosed is returned by the storages' Read, Seek and Bytes
	// methods once Close has been called.
	ErrStorageClosed = fmt.Errorf("body storage is closed")
)
  26. // memoryStorage 内存存储实现
  27. type memoryStorage struct {
  28. data []byte
  29. reader *bytes.Reader
  30. size int64
  31. closed int32
  32. mu sync.Mutex
  33. }
  34. func newMemoryStorage(data []byte) *memoryStorage {
  35. size := int64(len(data))
  36. IncrementMemoryBuffers(size)
  37. return &memoryStorage{
  38. data: data,
  39. reader: bytes.NewReader(data),
  40. size: size,
  41. }
  42. }
  43. func (m *memoryStorage) Read(p []byte) (n int, err error) {
  44. m.mu.Lock()
  45. defer m.mu.Unlock()
  46. if atomic.LoadInt32(&m.closed) == 1 {
  47. return 0, ErrStorageClosed
  48. }
  49. return m.reader.Read(p)
  50. }
  51. func (m *memoryStorage) Seek(offset int64, whence int) (int64, error) {
  52. m.mu.Lock()
  53. defer m.mu.Unlock()
  54. if atomic.LoadInt32(&m.closed) == 1 {
  55. return 0, ErrStorageClosed
  56. }
  57. return m.reader.Seek(offset, whence)
  58. }
  59. func (m *memoryStorage) Close() error {
  60. m.mu.Lock()
  61. defer m.mu.Unlock()
  62. if atomic.CompareAndSwapInt32(&m.closed, 0, 1) {
  63. DecrementMemoryBuffers(m.size)
  64. }
  65. return nil
  66. }
  67. func (m *memoryStorage) Bytes() ([]byte, error) {
  68. m.mu.Lock()
  69. defer m.mu.Unlock()
  70. if atomic.LoadInt32(&m.closed) == 1 {
  71. return nil, ErrStorageClosed
  72. }
  73. return m.data, nil
  74. }
  75. func (m *memoryStorage) Size() int64 {
  76. return m.size
  77. }
  78. func (m *memoryStorage) IsDisk() bool {
  79. return false
  80. }
  // diskStorage is the file-backed BodyStorage implementation: the body lives
// in a temporary file that Close deletes again.
type diskStorage struct {
	// file is the open temporary file; its offset is the Read/Seek cursor.
	file *os.File
	// filePath is the location of file, kept so Close can remove it.
	filePath string

	// size is the number of body bytes written to file.
	size int64
	// closed becomes 1 once Close has run (accessed atomically, under mu).
	closed int32
	// mu serializes access to file and closed.
	mu sync.Mutex
}
  89. func newDiskStorage(data []byte, cachePath string) (*diskStorage, error) {
  90. // 确定缓存目录
  91. dir := cachePath
  92. if dir == "" {
  93. dir = os.TempDir()
  94. }
  95. dir = filepath.Join(dir, "new-api-body-cache")
  96. // 确保目录存在
  97. if err := os.MkdirAll(dir, 0755); err != nil {
  98. return nil, fmt.Errorf("failed to create cache directory: %w", err)
  99. }
  100. // 创建临时文件
  101. filename := fmt.Sprintf("body-%s-%d.tmp", uuid.New().String()[:8], time.Now().UnixNano())
  102. filePath := filepath.Join(dir, filename)
  103. file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600)
  104. if err != nil {
  105. return nil, fmt.Errorf("failed to create temp file: %w", err)
  106. }
  107. // 写入数据
  108. n, err := file.Write(data)
  109. if err != nil {
  110. file.Close()
  111. os.Remove(filePath)
  112. return nil, fmt.Errorf("failed to write to temp file: %w", err)
  113. }
  114. // 重置文件指针
  115. if _, err := file.Seek(0, io.SeekStart); err != nil {
  116. file.Close()
  117. os.Remove(filePath)
  118. return nil, fmt.Errorf("failed to seek temp file: %w", err)
  119. }
  120. size := int64(n)
  121. IncrementDiskFiles(size)
  122. return &diskStorage{
  123. file: file,
  124. filePath: filePath,
  125. size: size,
  126. }, nil
  127. }
  128. func newDiskStorageFromReader(reader io.Reader, maxBytes int64, cachePath string) (*diskStorage, error) {
  129. // 确定缓存目录
  130. dir := cachePath
  131. if dir == "" {
  132. dir = os.TempDir()
  133. }
  134. dir = filepath.Join(dir, "new-api-body-cache")
  135. // 确保目录存在
  136. if err := os.MkdirAll(dir, 0755); err != nil {
  137. return nil, fmt.Errorf("failed to create cache directory: %w", err)
  138. }
  139. // 创建临时文件
  140. filename := fmt.Sprintf("body-%s-%d.tmp", uuid.New().String()[:8], time.Now().UnixNano())
  141. filePath := filepath.Join(dir, filename)
  142. file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600)
  143. if err != nil {
  144. return nil, fmt.Errorf("failed to create temp file: %w", err)
  145. }
  146. // 从 reader 读取并写入文件
  147. written, err := io.Copy(file, io.LimitReader(reader, maxBytes+1))
  148. if err != nil {
  149. file.Close()
  150. os.Remove(filePath)
  151. return nil, fmt.Errorf("failed to write to temp file: %w", err)
  152. }
  153. if written > maxBytes {
  154. file.Close()
  155. os.Remove(filePath)
  156. return nil, ErrRequestBodyTooLarge
  157. }
  158. // 重置文件指针
  159. if _, err := file.Seek(0, io.SeekStart); err != nil {
  160. file.Close()
  161. os.Remove(filePath)
  162. return nil, fmt.Errorf("failed to seek temp file: %w", err)
  163. }
  164. IncrementDiskFiles(written)
  165. return &diskStorage{
  166. file: file,
  167. filePath: filePath,
  168. size: written,
  169. }, nil
  170. }
  171. func (d *diskStorage) Read(p []byte) (n int, err error) {
  172. d.mu.Lock()
  173. defer d.mu.Unlock()
  174. if atomic.LoadInt32(&d.closed) == 1 {
  175. return 0, ErrStorageClosed
  176. }
  177. return d.file.Read(p)
  178. }
  179. func (d *diskStorage) Seek(offset int64, whence int) (int64, error) {
  180. d.mu.Lock()
  181. defer d.mu.Unlock()
  182. if atomic.LoadInt32(&d.closed) == 1 {
  183. return 0, ErrStorageClosed
  184. }
  185. return d.file.Seek(offset, whence)
  186. }
  187. func (d *diskStorage) Close() error {
  188. d.mu.Lock()
  189. defer d.mu.Unlock()
  190. if atomic.CompareAndSwapInt32(&d.closed, 0, 1) {
  191. d.file.Close()
  192. os.Remove(d.filePath)
  193. DecrementDiskFiles(d.size)
  194. }
  195. return nil
  196. }
  197. func (d *diskStorage) Bytes() ([]byte, error) {
  198. d.mu.Lock()
  199. defer d.mu.Unlock()
  200. if atomic.LoadInt32(&d.closed) == 1 {
  201. return nil, ErrStorageClosed
  202. }
  203. // 保存当前位置
  204. currentPos, err := d.file.Seek(0, io.SeekCurrent)
  205. if err != nil {
  206. return nil, err
  207. }
  208. // 移动到开头
  209. if _, err := d.file.Seek(0, io.SeekStart); err != nil {
  210. return nil, err
  211. }
  212. // 读取全部内容
  213. data := make([]byte, d.size)
  214. _, err = io.ReadFull(d.file, data)
  215. if err != nil {
  216. return nil, err
  217. }
  218. // 恢复位置
  219. if _, err := d.file.Seek(currentPos, io.SeekStart); err != nil {
  220. return nil, err
  221. }
  222. return data, nil
  223. }
  224. func (d *diskStorage) Size() int64 {
  225. return d.size
  226. }
  227. func (d *diskStorage) IsDisk() bool {
  228. return true
  229. }
  230. // CreateBodyStorage 根据数据大小创建合适的存储
  231. func CreateBodyStorage(data []byte) (BodyStorage, error) {
  232. size := int64(len(data))
  233. threshold := GetDiskCacheThresholdBytes()
  234. // 检查是否应该使用磁盘缓存
  235. if IsDiskCacheEnabled() &&
  236. size >= threshold &&
  237. IsDiskCacheAvailable(size) {
  238. storage, err := newDiskStorage(data, GetDiskCachePath())
  239. if err != nil {
  240. // 如果磁盘存储失败,回退到内存存储
  241. SysError(fmt.Sprintf("failed to create disk storage, falling back to memory: %v", err))
  242. return newMemoryStorage(data), nil
  243. }
  244. return storage, nil
  245. }
  246. return newMemoryStorage(data), nil
  247. }
  248. // CreateBodyStorageFromReader 从 Reader 创建存储(用于大请求的流式处理)
  249. func CreateBodyStorageFromReader(reader io.Reader, contentLength int64, maxBytes int64) (BodyStorage, error) {
  250. threshold := GetDiskCacheThresholdBytes()
  251. // 如果启用了磁盘缓存且内容长度超过阈值,直接使用磁盘存储
  252. if IsDiskCacheEnabled() &&
  253. contentLength > 0 &&
  254. contentLength >= threshold &&
  255. IsDiskCacheAvailable(contentLength) {
  256. storage, err := newDiskStorageFromReader(reader, maxBytes, GetDiskCachePath())
  257. if err != nil {
  258. if IsRequestBodyTooLargeError(err) {
  259. return nil, err
  260. }
  261. // 磁盘存储失败,reader 已被消费,无法安全回退
  262. // 直接返回错误而非尝试回退(因为 reader 数据已丢失)
  263. return nil, fmt.Errorf("disk storage creation failed: %w", err)
  264. }
  265. IncrementDiskCacheHits()
  266. return storage, nil
  267. }
  268. // 使用内存读取
  269. data, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
  270. if err != nil {
  271. return nil, err
  272. }
  273. if int64(len(data)) > maxBytes {
  274. return nil, ErrRequestBodyTooLarge
  275. }
  276. storage, err := CreateBodyStorage(data)
  277. if err != nil {
  278. return nil, err
  279. }
  280. // 如果最终使用内存存储,记录内存缓存命中
  281. if !storage.IsDisk() {
  282. IncrementMemoryCacheHits()
  283. } else {
  284. IncrementDiskCacheHits()
  285. }
  286. return storage, nil
  287. }
  288. // CleanupOldCacheFiles 清理旧的缓存文件(用于启动时清理残留)
  289. func CleanupOldCacheFiles() {
  290. cachePath := GetDiskCachePath()
  291. if cachePath == "" {
  292. cachePath = os.TempDir()
  293. }
  294. dir := filepath.Join(cachePath, "new-api-body-cache")
  295. entries, err := os.ReadDir(dir)
  296. if err != nil {
  297. return // 目录不存在或无法读取
  298. }
  299. now := time.Now()
  300. for _, entry := range entries {
  301. if entry.IsDir() {
  302. continue
  303. }
  304. info, err := entry.Info()
  305. if err != nil {
  306. continue
  307. }
  308. // 删除超过 5 分钟的旧文件
  309. if now.Sub(info.ModTime()) > 5*time.Minute {
  310. os.Remove(filepath.Join(dir, entry.Name()))
  311. }
  312. }
  313. }