relay_utils.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package common
  import (
  	"fmt"
  	"math"
  	"net/http"
  	"strconv"
  	"strings"

  	"one-api/common"
  	"one-api/constant"
  	"one-api/dto"

  	"github.com/gin-gonic/gin"
  )
  // HasPrompt is implemented by request types that can report their prompt
  // text through GetPrompt.
  type HasPrompt interface {
  	// GetPrompt returns the prompt carried by the request.
  	GetPrompt() string
  }
  // HasImage is implemented by request types that can report whether they
  // carry image input.
  type HasImage interface {
  	// HasImage reports whether the request carries image input.
  	HasImage() bool
  }
  18. func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
  19. fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
  20. if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
  21. switch channelType {
  22. case constant.ChannelTypeOpenAI:
  23. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
  24. case constant.ChannelTypeAzure:
  25. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
  26. }
  27. }
  28. return fullRequestURL
  29. }
  30. func GetAPIVersion(c *gin.Context) string {
  31. query := c.Request.URL.Query()
  32. apiVersion := query.Get("api-version")
  33. if apiVersion == "" {
  34. apiVersion = c.GetString("api_version")
  35. }
  36. return apiVersion
  37. }
  38. func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
  39. return &dto.TaskError{
  40. Code: code,
  41. Message: err.Error(),
  42. StatusCode: statusCode,
  43. LocalError: localError,
  44. Error: err,
  45. }
  46. }
  47. func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
  48. info.Action = action
  49. c.Set("task_request", requestObj)
  50. }
  51. func validatePrompt(prompt string) *dto.TaskError {
  52. if strings.TrimSpace(prompt) == "" {
  53. return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
  54. }
  55. return nil
  56. }
  57. func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
  58. var req TaskSubmitReq
  59. if _, err := c.MultipartForm(); err != nil {
  60. return req, err
  61. }
  62. formData := c.Request.PostForm
  63. req = TaskSubmitReq{
  64. Prompt: formData.Get("prompt"),
  65. Model: formData.Get("model"),
  66. Mode: formData.Get("mode"),
  67. Image: formData.Get("image"),
  68. Size: formData.Get("size"),
  69. Metadata: make(map[string]interface{}),
  70. }
  71. if durationStr := formData.Get("seconds"); durationStr != "" {
  72. if duration, err := strconv.Atoi(durationStr); err == nil {
  73. req.Duration = duration
  74. }
  75. }
  76. if images := formData["images"]; len(images) > 0 {
  77. req.Images = images
  78. }
  79. for key, values := range formData {
  80. if len(values) > 0 && !isKnownTaskField(key) {
  81. if intVal, err := strconv.Atoi(values[0]); err == nil {
  82. req.Metadata[key] = intVal
  83. } else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil {
  84. req.Metadata[key] = floatVal
  85. } else {
  86. req.Metadata[key] = values[0]
  87. }
  88. }
  89. }
  90. return req, nil
  91. }
  92. func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
  93. form, err := common.ParseMultipartFormReusable(c)
  94. if err != nil {
  95. return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
  96. }
  97. defer form.RemoveAll()
  98. prompts, ok := form.Value["prompt"]
  99. if !ok || len(prompts) == 0 {
  100. return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
  101. }
  102. if taskErr := validatePrompt(prompts[0]); taskErr != nil {
  103. return taskErr
  104. }
  105. if _, ok := form.Value["model"]; !ok {
  106. return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
  107. }
  108. action := constant.TaskActionTextGenerate
  109. if _, ok := form.File["input_reference"]; ok {
  110. action = constant.TaskActionGenerate
  111. }
  112. info.Action = action
  113. return nil
  114. }
  // isKnownTaskField reports whether field is one of the form field names the
  // task request parser treats as built-in. Matching is exact and
  // case-sensitive.
  //
  // A constant switch is used instead of building a lookup map on every call.
  func isKnownTaskField(field string) bool {
  	switch field {
  	case "prompt",
  		"model",
  		"mode",
  		"image",
  		"images",
  		"size",
  		"duration",
  		"input_reference": // Sora-specific field
  		return true
  	default:
  		return false
  	}
  }
  128. func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
  129. var err error
  130. contentType := c.GetHeader("Content-Type")
  131. var req TaskSubmitReq
  132. if strings.HasPrefix(contentType, "multipart/form-data") {
  133. req, err = validateMultipartTaskRequest(c, info, action)
  134. if err != nil {
  135. return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
  136. }
  137. } else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
  138. return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
  139. }
  140. if taskErr := validatePrompt(req.Prompt); taskErr != nil {
  141. return taskErr
  142. }
  143. if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
  144. // 兼容单图上传
  145. req.Images = []string{req.Image}
  146. }
  147. if req.HasImage() {
  148. action = constant.TaskActionGenerate
  149. if info.ChannelType == constant.ChannelTypeVidu {
  150. // vidu 增加 首尾帧生视频和参考图生视频
  151. if len(req.Images) == 2 {
  152. action = constant.TaskActionFirstTailGenerate
  153. } else if len(req.Images) > 2 {
  154. action = constant.TaskActionReferenceGenerate
  155. }
  156. }
  157. }
  158. storeTaskRequest(c, info, action, req)
  159. return nil
  160. }