relay_utils.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. package common
  2. import (
  3. "fmt"
  4. "net/http"
  5. "one-api/common"
  6. "one-api/constant"
  7. "one-api/dto"
  8. "strconv"
  9. "strings"
  10. "github.com/gin-gonic/gin"
  11. )
  12. type HasPrompt interface {
  13. GetPrompt() string
  14. }
  15. type HasImage interface {
  16. HasImage() bool
  17. }
  18. func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
  19. fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
  20. if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
  21. switch channelType {
  22. case constant.ChannelTypeOpenAI:
  23. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
  24. case constant.ChannelTypeAzure:
  25. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
  26. }
  27. }
  28. return fullRequestURL
  29. }
  30. func GetAPIVersion(c *gin.Context) string {
  31. query := c.Request.URL.Query()
  32. apiVersion := query.Get("api-version")
  33. if apiVersion == "" {
  34. apiVersion = c.GetString("api_version")
  35. }
  36. return apiVersion
  37. }
  38. func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
  39. return &dto.TaskError{
  40. Code: code,
  41. Message: err.Error(),
  42. StatusCode: statusCode,
  43. LocalError: localError,
  44. Error: err,
  45. }
  46. }
  47. func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
  48. info.Action = action
  49. c.Set("task_request", requestObj)
  50. }
  51. func validatePrompt(prompt string) *dto.TaskError {
  52. if strings.TrimSpace(prompt) == "" {
  53. return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
  54. }
  55. return nil
  56. }
  57. func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
  58. var req TaskSubmitReq
  59. if _, err := c.MultipartForm(); err != nil {
  60. return req, err
  61. }
  62. formData := c.Request.PostForm
  63. req = TaskSubmitReq{
  64. Prompt: formData.Get("prompt"),
  65. Model: formData.Get("model"),
  66. Mode: formData.Get("mode"),
  67. Image: formData.Get("image"),
  68. Size: formData.Get("size"),
  69. Metadata: make(map[string]interface{}),
  70. }
  71. if durationStr := formData.Get("seconds"); durationStr != "" {
  72. if duration, err := strconv.Atoi(durationStr); err == nil {
  73. req.Duration = duration
  74. }
  75. }
  76. if images := formData["images"]; len(images) > 0 {
  77. req.Images = images
  78. }
  79. for key, values := range formData {
  80. if len(values) > 0 && !isKnownTaskField(key) {
  81. if intVal, err := strconv.Atoi(values[0]); err == nil {
  82. req.Metadata[key] = intVal
  83. } else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil {
  84. req.Metadata[key] = floatVal
  85. } else {
  86. req.Metadata[key] = values[0]
  87. }
  88. }
  89. }
  90. return req, nil
  91. }
  92. func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
  93. contentType := c.GetHeader("Content-Type")
  94. var prompt string
  95. var hasInputReference bool
  96. if strings.HasPrefix(contentType, "multipart/form-data") {
  97. form, err := common.ParseMultipartFormReusable(c)
  98. if err != nil {
  99. return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
  100. }
  101. defer form.RemoveAll()
  102. prompts, ok := form.Value["prompt"]
  103. if !ok || len(prompts) == 0 {
  104. return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
  105. }
  106. prompt = prompts[0]
  107. if _, ok := form.Value["model"]; !ok {
  108. return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
  109. }
  110. if _, ok := form.File["input_reference"]; ok {
  111. hasInputReference = true
  112. }
  113. } else {
  114. var req TaskSubmitReq
  115. if err := common.UnmarshalBodyReusable(c, &req); err != nil {
  116. return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
  117. }
  118. prompt = req.Prompt
  119. if strings.TrimSpace(req.Model) == "" {
  120. return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
  121. }
  122. if req.HasImage() {
  123. hasInputReference = true
  124. }
  125. }
  126. if taskErr := validatePrompt(prompt); taskErr != nil {
  127. return taskErr
  128. }
  129. action := constant.TaskActionTextGenerate
  130. if hasInputReference {
  131. action = constant.TaskActionGenerate
  132. }
  133. info.Action = action
  134. return nil
  135. }
  136. func isKnownTaskField(field string) bool {
  137. knownFields := map[string]bool{
  138. "prompt": true,
  139. "model": true,
  140. "mode": true,
  141. "image": true,
  142. "images": true,
  143. "size": true,
  144. "duration": true,
  145. "input_reference": true, // Sora 特有字段
  146. }
  147. return knownFields[field]
  148. }
  149. func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
  150. var err error
  151. contentType := c.GetHeader("Content-Type")
  152. var req TaskSubmitReq
  153. if strings.HasPrefix(contentType, "multipart/form-data") {
  154. req, err = validateMultipartTaskRequest(c, info, action)
  155. if err != nil {
  156. return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
  157. }
  158. } else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
  159. return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
  160. }
  161. if taskErr := validatePrompt(req.Prompt); taskErr != nil {
  162. return taskErr
  163. }
  164. if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
  165. // 兼容单图上传
  166. req.Images = []string{req.Image}
  167. }
  168. if req.HasImage() {
  169. action = constant.TaskActionGenerate
  170. if info.ChannelType == constant.ChannelTypeVidu {
  171. // vidu 增加 首尾帧生视频和参考图生视频
  172. if len(req.Images) == 2 {
  173. action = constant.TaskActionFirstTailGenerate
  174. } else if len(req.Images) > 2 {
  175. action = constant.TaskActionReferenceGenerate
  176. }
  177. }
  178. }
  179. storeTaskRequest(c, info, action, req)
  180. return nil
  181. }