Przeglądaj źródła

feat: 上游渠道为OpenAI渠道类型时,透传请求 (close #532)

1808837298@qq.com 1 rok temu
rodzic
commit
8b67664995
1 zmienionych plików z 18 dodań i 8 usunięć
  1. 18 8
      relay/relay-text.go

+ 18 - 8
relay/relay-text.go

@@ -76,6 +76,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	}
 
 	// map model name
+	isModelMapped := false
 	modelMapping := c.GetString("model_mapping")
 	//isModelMapped := false
 	if modelMapping != "" && modelMapping != "{}" {
@@ -85,6 +86,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 			return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 		}
 		if modelMap[textRequest.Model] != "" {
+			isModelMapped = true
 			textRequest.Model = modelMap[textRequest.Model]
 			// set upstream model name
 			//isModelMapped = true
@@ -159,15 +161,23 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	adaptor.Init(relayInfo)
 	var requestBody io.Reader
 
-	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
-	if err != nil {
-		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
-	}
-	jsonData, err := json.Marshal(convertedRequest)
-	if err != nil {
-		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+	if relayInfo.ChannelType == common.ChannelTypeOpenAI && !isModelMapped {
+		body, err := common.GetRequestBody(c)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(body)
+	} else {
+		convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
+		}
+		jsonData, err := json.Marshal(convertedRequest)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonData)
 	}
-	requestBody = bytes.NewBuffer(jsonData)
 
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)