Browse Source

Merge branch 'main' into ui/refactor

Apple\Apple 9 months ago
parent
commit
19cd98cb99
3 changed files with 36 additions and 69 deletions
  1. 1 1
      dto/openai_request.go
  2. 6 5
      relay/channel/api_request.go
  3. 29 63
      relay/relay-image.go

+ 1 - 1
dto/openai_request.go

@@ -43,7 +43,7 @@ type GeneralOpenAIRequest struct {
 	ResponseFormat   *ResponseFormat   `json:"response_format,omitempty"`
 	EncodingFormat   any               `json:"encoding_format,omitempty"`
 	Seed             float64           `json:"seed,omitempty"`
-	ParallelTooCalls bool              `json:"parallel_tool_calls,omitempty"`
+	ParallelTooCalls *bool             `json:"parallel_tool_calls,omitempty"`
 	Tools            []ToolCallRequest `json:"tools,omitempty"`
 	ToolChoice       any               `json:"tool_choice,omitempty"`
 	User             string            `json:"user,omitempty"`

+ 6 - 5
relay/channel/api_request.go

@@ -122,11 +122,13 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	var pingerWg sync.WaitGroup
 	if info.IsStream {
 		helper.SetEventStreamHeaders(c)
-		pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
-		var pingerCtx context.Context
-		pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
 
 		if pingEnabled {
+			pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+			var pingerCtx context.Context
+			pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
+			// 退出时清理 pingerCtx 防止泄露
+			defer stopPinger()
 			pingerWg.Add(1)
 			gopool.Go(func() {
 				defer pingerWg.Done()
@@ -166,9 +168,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	}
 
 	resp, err := client.Do(req)
-	// request结束后停止ping
+	// request结束后等待 ping goroutine 完成
 	if info.IsStream && pingEnabled {
-		stopPinger()
 		pingerWg.Wait()
 	}
 	if err != nil {

+ 29 - 63
relay/relay-image.go

@@ -46,11 +46,23 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 		if err != nil {
 			return nil, err
 		}
+
+		if imageRequest.Model == "" {
+			imageRequest.Model = "dall-e-3"
+		}
+
+		if strings.Contains(imageRequest.Size, "×") {
+			return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
+		}
+
 		// Not "256x256", "512x512", or "1024x1024"
 		if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
 			if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
 				return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
 			}
+			if imageRequest.Size == "" {
+				imageRequest.Size = "1024x1024"
+			}
 		} else if imageRequest.Model == "dall-e-3" {
 			if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
 				return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
@@ -58,74 +70,24 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 			if imageRequest.Quality == "" {
 				imageRequest.Quality = "standard"
 			}
-			// N should between 1 and 10
-			//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
-			//	return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
-			//}
+			if imageRequest.Size == "" {
+				imageRequest.Size = "1024x1024"
+			}
+		} else if imageRequest.Model == "gpt-image-1" {
+			if imageRequest.Quality == "" {
+				imageRequest.Quality = "auto"
+			}
 		}
-	}
-
-	if imageRequest.Prompt == "" {
-		return nil, errors.New("prompt is required")
-	}
 
-	if imageRequest.Model == "" {
-		imageRequest.Model = "dall-e-2"
-	}
-	if strings.Contains(imageRequest.Size, "×") {
-		return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
-	}
-	if imageRequest.N == 0 {
-		imageRequest.N = 1
-	}
-	if imageRequest.Size == "" {
-		imageRequest.Size = "1024x1024"
-	}
-
-	err := common.UnmarshalBodyReusable(c, imageRequest)
-	if err != nil {
-		return nil, err
-	}
-	if imageRequest.Prompt == "" {
-		return nil, errors.New("prompt is required")
-	}
-	if strings.Contains(imageRequest.Size, "×") {
-		return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
-	}
-	if imageRequest.N == 0 {
-		imageRequest.N = 1
-	}
-	if imageRequest.Size == "" {
-		imageRequest.Size = "1024x1024"
-	}
-	if imageRequest.Model == "" {
-		imageRequest.Model = "dall-e-2"
-	}
-	// x.ai grok-2-image not support size, quality or style
-	if imageRequest.Size == "empty" {
-		imageRequest.Size = ""
-	}
-
-	// Not "256x256", "512x512", or "1024x1024"
-	if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
-		if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
-			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
+		if imageRequest.Prompt == "" {
+			return nil, errors.New("prompt is required")
 		}
-	} else if imageRequest.Model == "dall-e-3" {
-		if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
-			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
-		}
-		if imageRequest.Quality == "" {
-			imageRequest.Quality = "standard"
+
+		if imageRequest.N == 0 {
+			imageRequest.N = 1
 		}
-		//if imageRequest.N != 1 {
-		//	return nil, errors.New("n must be 1")
-		//}
 	}
-	// N should between 1 and 10
-	//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
-	//	return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
-	//}
+
 	if setting.ShouldCheckPromptSensitive() {
 		words, err := service.CheckSensitiveInput(imageRequest.Prompt)
 		if err != nil {
@@ -229,6 +191,10 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		requestBody = bytes.NewBuffer(jsonData)
 	}
 
+	if common.DebugEnabled {
+		println(fmt.Sprintf("image request body: %s", requestBody))
+	}
+
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)