gin.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. package common
  2. import (
  3. "bytes"
  4. "errors"
  5. "io"
  6. "mime"
  7. "mime/multipart"
  8. "net/http"
  9. "net/url"
  10. "strings"
  11. "time"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/gin-gonic/gin"
  14. )
  15. const KeyRequestBody = "key_request_body"
  16. var ErrRequestBodyTooLarge = errors.New("request body too large")
  17. func IsRequestBodyTooLargeError(err error) bool {
  18. if err == nil {
  19. return false
  20. }
  21. if errors.Is(err, ErrRequestBodyTooLarge) {
  22. return true
  23. }
  24. var mbe *http.MaxBytesError
  25. return errors.As(err, &mbe)
  26. }
  27. func GetRequestBody(c *gin.Context) ([]byte, error) {
  28. cached, exists := c.Get(KeyRequestBody)
  29. if exists && cached != nil {
  30. if b, ok := cached.([]byte); ok {
  31. return b, nil
  32. }
  33. }
  34. maxMB := constant.MaxRequestBodyMB
  35. if maxMB <= 0 {
  36. maxMB = 32
  37. }
  38. maxBytes := int64(maxMB) << 20
  39. limited := io.LimitReader(c.Request.Body, maxBytes+1)
  40. body, err := io.ReadAll(limited)
  41. if err != nil {
  42. _ = c.Request.Body.Close()
  43. if IsRequestBodyTooLargeError(err) {
  44. return nil, ErrRequestBodyTooLarge
  45. }
  46. return nil, err
  47. }
  48. _ = c.Request.Body.Close()
  49. if int64(len(body)) > maxBytes {
  50. return nil, ErrRequestBodyTooLarge
  51. }
  52. c.Set(KeyRequestBody, body)
  53. return body, nil
  54. }
  55. func UnmarshalBodyReusable(c *gin.Context, v any) error {
  56. requestBody, err := GetRequestBody(c)
  57. if err != nil {
  58. return err
  59. }
  60. //if DebugEnabled {
  61. // println("UnmarshalBodyReusable request body:", string(requestBody))
  62. //}
  63. contentType := c.Request.Header.Get("Content-Type")
  64. if strings.HasPrefix(contentType, "application/json") {
  65. err = Unmarshal(requestBody, v)
  66. } else if strings.Contains(contentType, gin.MIMEPOSTForm) {
  67. err = parseFormData(requestBody, v)
  68. } else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
  69. err = parseMultipartFormData(c, requestBody, v)
  70. } else {
  71. // skip for now
  72. // TODO: someday non json request have variant model, we will need to implementation this
  73. }
  74. if err != nil {
  75. return err
  76. }
  77. // Reset request body
  78. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  79. return nil
  80. }
  81. func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
  82. c.Set(string(key), value)
  83. }
  84. func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
  85. return c.Get(string(key))
  86. }
  87. func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
  88. return c.GetString(string(key))
  89. }
  90. func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
  91. return c.GetInt(string(key))
  92. }
  93. func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
  94. return c.GetBool(string(key))
  95. }
  96. func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
  97. return c.GetStringSlice(string(key))
  98. }
  99. func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
  100. return c.GetStringMap(string(key))
  101. }
  102. func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
  103. return c.GetTime(string(key))
  104. }
  105. func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
  106. if value, ok := c.Get(string(key)); ok {
  107. if v, ok := value.(T); ok {
  108. return v, true
  109. }
  110. }
  111. var t T
  112. return t, false
  113. }
  114. func ApiError(c *gin.Context, err error) {
  115. c.JSON(http.StatusOK, gin.H{
  116. "success": false,
  117. "message": err.Error(),
  118. })
  119. }
  120. func ApiErrorMsg(c *gin.Context, msg string) {
  121. c.JSON(http.StatusOK, gin.H{
  122. "success": false,
  123. "message": msg,
  124. })
  125. }
  126. func ApiSuccess(c *gin.Context, data any) {
  127. c.JSON(http.StatusOK, gin.H{
  128. "success": true,
  129. "message": "",
  130. "data": data,
  131. })
  132. }
  133. func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
  134. requestBody, err := GetRequestBody(c)
  135. if err != nil {
  136. return nil, err
  137. }
  138. contentType := c.Request.Header.Get("Content-Type")
  139. boundary, err := parseBoundary(contentType)
  140. if err != nil {
  141. return nil, err
  142. }
  143. reader := multipart.NewReader(bytes.NewReader(requestBody), boundary)
  144. form, err := reader.ReadForm(multipartMemoryLimit())
  145. if err != nil {
  146. return nil, err
  147. }
  148. // Reset request body
  149. c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  150. return form, nil
  151. }
  152. func processFormMap(formMap map[string]any, v any) error {
  153. jsonData, err := Marshal(formMap)
  154. if err != nil {
  155. return err
  156. }
  157. err = Unmarshal(jsonData, v)
  158. if err != nil {
  159. return err
  160. }
  161. return nil
  162. }
  163. func parseFormData(data []byte, v any) error {
  164. values, err := url.ParseQuery(string(data))
  165. if err != nil {
  166. return err
  167. }
  168. formMap := make(map[string]any)
  169. for key, vals := range values {
  170. if len(vals) == 1 {
  171. formMap[key] = vals[0]
  172. } else {
  173. formMap[key] = vals
  174. }
  175. }
  176. return processFormMap(formMap, v)
  177. }
  178. func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
  179. contentType := c.Request.Header.Get("Content-Type")
  180. boundary, err := parseBoundary(contentType)
  181. if err != nil {
  182. if errors.Is(err, errBoundaryNotFound) {
  183. return Unmarshal(data, v) // Fallback to JSON
  184. }
  185. return err
  186. }
  187. reader := multipart.NewReader(bytes.NewReader(data), boundary)
  188. form, err := reader.ReadForm(multipartMemoryLimit())
  189. if err != nil {
  190. return err
  191. }
  192. defer form.RemoveAll()
  193. formMap := make(map[string]any)
  194. for key, vals := range form.Value {
  195. if len(vals) == 1 {
  196. formMap[key] = vals[0]
  197. } else {
  198. formMap[key] = vals
  199. }
  200. }
  201. return processFormMap(formMap, v)
  202. }
  203. var errBoundaryNotFound = errors.New("multipart boundary not found")
  204. // parseBoundary extracts the multipart boundary from the Content-Type header using mime.ParseMediaType
  205. func parseBoundary(contentType string) (string, error) {
  206. if contentType == "" {
  207. return "", errBoundaryNotFound
  208. }
  209. // Boundary-UUID / boundary-------xxxxxx
  210. _, params, err := mime.ParseMediaType(contentType)
  211. if err != nil {
  212. return "", err
  213. }
  214. boundary, ok := params["boundary"]
  215. if !ok || boundary == "" {
  216. return "", errBoundaryNotFound
  217. }
  218. return boundary, nil
  219. }
  220. // multipartMemoryLimit returns the configured multipart memory limit in bytes
  221. func multipartMemoryLimit() int64 {
  222. limitMB := constant.MaxFileDownloadMB
  223. if limitMB <= 0 {
  224. limitMB = 32
  225. }
  226. return int64(limitMB) << 20
  227. }