relay_utils.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. package common
  2. import (
  3. "encoding/base64"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/dto"
  13. "github.com/gin-gonic/gin"
  14. "github.com/samber/lo"
  15. )
  // Capability interfaces that task request payloads may implement.
type (
	// HasPrompt is implemented by request types that carry a text prompt,
	// returned by GetPrompt.
	HasPrompt interface {
		GetPrompt() string
	}

	// HasImage is implemented by request types that can report, via
	// HasImage, whether image input is present (the exact meaning is
	// defined by each implementer).
	HasImage interface {
		HasImage() bool
	}
)
  22. func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
  23. fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
  24. if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
  25. switch channelType {
  26. case constant.ChannelTypeOpenAI:
  27. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
  28. case constant.ChannelTypeAzure:
  29. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
  30. }
  31. }
  32. return fullRequestURL
  33. }
  34. func GetAPIVersion(c *gin.Context) string {
  35. query := c.Request.URL.Query()
  36. apiVersion := query.Get("api-version")
  37. if apiVersion == "" {
  38. apiVersion = c.GetString("api_version")
  39. }
  40. return apiVersion
  41. }
  42. func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
  43. return &dto.TaskError{
  44. Code: code,
  45. Message: err.Error(),
  46. StatusCode: statusCode,
  47. LocalError: localError,
  48. Error: err,
  49. }
  50. }
  51. func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
  52. info.Action = action
  53. c.Set("task_request", requestObj)
  54. }
  55. func GetTaskRequest(c *gin.Context) (TaskSubmitReq, error) {
  56. v, exists := c.Get("task_request")
  57. if !exists {
  58. return TaskSubmitReq{}, fmt.Errorf("request not found in context")
  59. }
  60. req, ok := v.(TaskSubmitReq)
  61. if !ok {
  62. return TaskSubmitReq{}, fmt.Errorf("invalid task request type")
  63. }
  64. return req, nil
  65. }
  66. func validatePrompt(prompt string) *dto.TaskError {
  67. if strings.TrimSpace(prompt) == "" {
  68. return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
  69. }
  70. return nil
  71. }
  72. func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
  73. var req TaskSubmitReq
  74. if _, err := c.MultipartForm(); err != nil {
  75. return req, err
  76. }
  77. formData := c.Request.PostForm
  78. req = TaskSubmitReq{
  79. Prompt: formData.Get("prompt"),
  80. Model: formData.Get("model"),
  81. Mode: formData.Get("mode"),
  82. Image: formData.Get("image"),
  83. Size: formData.Get("size"),
  84. Metadata: make(map[string]interface{}),
  85. }
  86. if durationStr := formData.Get("seconds"); durationStr != "" {
  87. if duration, err := strconv.Atoi(durationStr); err == nil {
  88. req.Duration = duration
  89. }
  90. }
  91. if images := formData["images"]; len(images) > 0 {
  92. req.Images = images
  93. }
  94. for key, values := range formData {
  95. if len(values) > 0 && !isKnownTaskField(key) {
  96. if intVal, err := strconv.Atoi(values[0]); err == nil {
  97. req.Metadata[key] = intVal
  98. } else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil {
  99. req.Metadata[key] = floatVal
  100. } else {
  101. req.Metadata[key] = values[0]
  102. }
  103. }
  104. }
  105. return req, nil
  106. }
  107. func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
  108. var prompt string
  109. var model string
  110. var seconds int
  111. var size string
  112. var hasInputReference bool
  113. var req TaskSubmitReq
  114. if err := common.UnmarshalBodyReusable(c, &req); err != nil {
  115. return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
  116. }
  117. prompt = req.Prompt
  118. model = req.Model
  119. size = req.Size
  120. seconds, _ = strconv.Atoi(req.Seconds)
  121. if seconds == 0 {
  122. seconds = req.Duration
  123. }
  124. if req.InputReference != "" {
  125. req.Images = []string{req.InputReference}
  126. }
  127. if strings.TrimSpace(req.Model) == "" {
  128. return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
  129. }
  130. if req.HasImage() {
  131. hasInputReference = true
  132. }
  133. if taskErr := validatePrompt(prompt); taskErr != nil {
  134. return taskErr
  135. }
  136. action := constant.TaskActionTextGenerate
  137. if hasInputReference {
  138. action = constant.TaskActionGenerate
  139. }
  140. if strings.HasPrefix(model, "sora-2") {
  141. if size == "" {
  142. size = "720x1280"
  143. }
  144. if seconds <= 0 {
  145. seconds = 4
  146. }
  147. if model == "sora-2" && !lo.Contains([]string{"720x1280", "1280x720"}, size) {
  148. return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
  149. }
  150. if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
  151. return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
  152. }
  153. info.PriceData.OtherRatios = map[string]float64{
  154. "seconds": float64(seconds),
  155. "size": 1,
  156. }
  157. if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
  158. info.PriceData.OtherRatios["size"] = 1.666667
  159. }
  160. }
  161. info.Action = action
  162. return nil
  163. }
  // isKnownTaskField reports whether field is one of the reserved request
// field names. validateMultipartTaskRequest excludes these from Metadata.
// Matching is exact and case-sensitive.
//
// A switch is used instead of a map literal so no map is allocated and
// populated on every call; the function runs once per submitted form key.
func isKnownTaskField(field string) bool {
	switch field {
	case "prompt", "model", "mode", "image", "images", "size", "duration",
		"input_reference": // input_reference is a Sora-specific field
		return true
	default:
		return false
	}
}
  177. func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
  178. var err error
  179. contentType := c.GetHeader("Content-Type")
  180. var req TaskSubmitReq
  181. if strings.HasPrefix(contentType, "multipart/form-data") {
  182. req, err = validateMultipartTaskRequest(c, info, action)
  183. if err != nil {
  184. return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
  185. }
  186. } else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
  187. return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
  188. }
  189. if taskErr := validatePrompt(req.Prompt); taskErr != nil {
  190. return taskErr
  191. }
  192. if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
  193. // 兼容单图上传
  194. req.Images = []string{req.Image}
  195. }
  196. storeTaskRequest(c, info, action, req)
  197. return nil
  198. }
  199. func GetImagesBase64sFromForm(c *gin.Context) ([]*Base64Data, error) {
  200. return GetBase64sFromForm(c, "image")
  201. }
  202. func GetImageBase64sFromForm(c *gin.Context) (*Base64Data, error) {
  203. base64s, err := GetImagesBase64sFromForm(c)
  204. if err != nil {
  205. return nil, err
  206. }
  207. return base64s[0], nil
  208. }
  // Base64Data is a base64-encoded payload paired with its MIME type.
type Base64Data struct {
	MimeType string // MIME type of the decoded payload, e.g. "image/png"
	Data     string // base64-encoded content, without any data-URL prefix
}

// String renders d as a data URL of the form "data:<MimeType>;base64,<Data>".
func (d Base64Data) String() string {
	return fmt.Sprintf("data:%s;base64,%s", d.MimeType, d.Data)
}
  216. func GetBase64sFromForm(c *gin.Context, fieldName string) ([]*Base64Data, error) {
  217. mf := c.Request.MultipartForm
  218. if mf == nil {
  219. if _, err := c.MultipartForm(); err != nil {
  220. return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
  221. }
  222. mf = c.Request.MultipartForm
  223. }
  224. imageFiles, exists := mf.File[fieldName]
  225. if !exists || len(imageFiles) == 0 {
  226. return nil, errors.New("field " + fieldName + "\" is not found or empty")
  227. }
  228. var imageBase64s []*Base64Data
  229. for _, file := range imageFiles {
  230. image, err := file.Open()
  231. defer image.Close()
  232. if err != nil {
  233. return nil, errors.New("failed to open image file")
  234. }
  235. imageData, err := io.ReadAll(image)
  236. if err != nil {
  237. return nil, errors.New("failed to read image file")
  238. }
  239. mimeType := http.DetectContentType(imageData)
  240. base64Data := base64.StdEncoding.EncodeToString(imageData)
  241. imageBase64s = append(imageBase64s, &Base64Data{
  242. MimeType: mimeType,
  243. Data: base64Data,
  244. })
  245. }
  246. return imageBase64s, nil
  247. }