Quellcode durchsuchen

feat: add openai sdk create

feitianbubu vor 4 Monaten
Ursprung
Commit
5f36e32821
2 geänderte Dateien mit 57 neuen und 21 gelöschten Zeilen
  1. 16 8
      middleware/distributor.go
  2. 41 13
      relay/common/relay_utils.go

+ 16 - 8
middleware/distributor.go

@@ -174,14 +174,22 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		relayMode := relayconstant.RelayModeUnknown
 		if c.Request.Method == http.MethodPost {
 			relayMode = relayconstant.RelayModeVideoSubmit
-			form, err := common.ParseMultipartFormReusable(c)
-			if err != nil {
-				return nil, false, errors.New("无效的video请求, " + err.Error())
-			}
-			defer form.RemoveAll()
-			if form != nil {
-				if values, ok := form.Value["model"]; ok && len(values) > 0 {
-					modelRequest.Model = values[0]
+			contentType := c.Request.Header.Get("Content-Type")
+			if strings.HasPrefix(contentType, "multipart/form-data") {
+				form, err := common.ParseMultipartFormReusable(c)
+				if err != nil {
+					return nil, false, errors.New("无效的video请求, " + err.Error())
+				}
+				defer form.RemoveAll()
+				if form != nil {
+					if values, ok := form.Value["model"]; ok && len(values) > 0 {
+						modelRequest.Model = values[0]
+					}
+				}
+			} else if strings.HasPrefix(contentType, "application/json") {
+				err = common.UnmarshalBodyReusable(c, &modelRequest)
+				if err != nil {
+					return nil, false, errors.New("无效的video请求, " + err.Error())
 				}
 			}
 		} else if c.Request.Method == http.MethodGet {

+ 41 - 13
relay/common/relay_utils.go

@@ -106,25 +106,53 @@ func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string
 }
 
 func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
-	form, err := common.ParseMultipartFormReusable(c)
-	if err != nil {
-		return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
-	}
-	defer form.RemoveAll()
+	contentType := c.GetHeader("Content-Type")
+	var prompt string
+	var hasInputReference bool
+
+	if strings.HasPrefix(contentType, "multipart/form-data") {
+		form, err := common.ParseMultipartFormReusable(c)
+		if err != nil {
+			return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
+		}
+		defer form.RemoveAll()
+
+		prompts, ok := form.Value["prompt"]
+		if !ok || len(prompts) == 0 {
+			return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
+		}
+		prompt = prompts[0]
+
+		if _, ok := form.Value["model"]; !ok {
+			return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
+		}
+
+		if _, ok := form.File["input_reference"]; ok {
+			hasInputReference = true
+		}
+	} else {
+		var req TaskSubmitReq
+		if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+			return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
+		}
+
+		prompt = req.Prompt
 
-	prompts, ok := form.Value["prompt"]
-	if !ok || len(prompts) == 0 {
-		return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
+		if strings.TrimSpace(req.Model) == "" {
+			return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
+		}
+
+		if req.HasImage() {
+			hasInputReference = true
+		}
 	}
-	if taskErr := validatePrompt(prompts[0]); taskErr != nil {
+
+	if taskErr := validatePrompt(prompt); taskErr != nil {
 		return taskErr
 	}
 
-	if _, ok := form.Value["model"]; !ok {
-		return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
-	}
 	action := constant.TaskActionTextGenerate
-	if _, ok := form.File["input_reference"]; ok {
+	if hasInputReference {
 		action = constant.TaskActionGenerate
 	}
 	info.Action = action