فهرست منبع

✨ feat: add CloseResponseBodyGracefully function to handle HTTP response body closure

CaIon 8 ماه پیش
والد
کامیت
3002659f47

+ 13 - 0
common/http.go

@@ -0,0 +1,13 @@
+package common
+
+import "net/http"
+
+func CloseResponseBodyGracefully(httpResponse *http.Response) {
+	if httpResponse == nil || httpResponse.Body == nil {
+		return
+	}
+	err := httpResponse.Body.Close()
+	if err != nil {
+		SysError("failed to close response body: " + err.Error())
+	}
+}

+ 1 - 4
relay/channel/ali/image.go

@@ -132,10 +132,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &aliTaskResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

+ 2 - 4
relay/channel/ali/rerank.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
@@ -35,10 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 
 	var aliResponse AliRerankResponse
 	err = json.Unmarshal(responseBody, &aliResponse)

+ 2 - 8
relay/channel/ali/text.go

@@ -45,10 +45,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 
 	if aliResponse.Code != "" {
 		return &dto.OpenAIErrorWithStatusCode{
@@ -199,10 +196,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &aliResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

+ 2 - 8
relay/channel/baidu/relay-baidu.go

@@ -179,10 +179,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -215,10 +212,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

+ 2 - 8
relay/channel/cloudflare/relay_cloudflare.go

@@ -94,10 +94,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	var response dto.TextResponse
 	err = json.Unmarshal(responseBody, &response)
 	if err != nil {
@@ -127,10 +124,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &cfResp)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

+ 2 - 8
relay/channel/cohere/relay-cohere.go

@@ -173,10 +173,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	var cohereResp CohereResponseResult
 	err = json.Unmarshal(responseBody, &cohereResp)
 	if err != nil {
@@ -217,10 +214,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	var cohereResp CohereRerankResponseResult
 	err = json.Unmarshal(responseBody, &cohereResp)
 	if err != nil {

+ 1 - 4
relay/channel/coze/relay-coze.go

@@ -48,10 +48,7 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	// convert coze response to openai response
 	var response dto.TextResponse
 	var cozeResponse CozeChatDetailResponse

+ 1 - 4
relay/channel/dify/relay-dify.go

@@ -257,10 +257,7 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &difyResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

+ 1 - 4
relay/channel/gemini/relay-gemini-native.go

@@ -20,10 +20,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	if err != nil {
 		return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-	}
+	common.CloseResponseBodyGracefully(resp)
 
 	if common.DebugEnabled {
 		println(string(responseBody))

+ 1 - 4
relay/channel/gemini/relay-gemini.go

@@ -866,10 +866,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	if common.DebugEnabled {
 		println(string(responseBody))
 	}

+ 3 - 6
relay/channel/mokaai/relay-mokaai.go

@@ -5,6 +5,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	"one-api/service"
 )
@@ -26,7 +27,7 @@ func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.Embeddin
 	}
 	return &dto.EmbeddingRequest{
 		Input: input,
-		Model:  request.Model,
+		Model: request.Model,
 	}
 }
 
@@ -53,10 +54,7 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -80,4 +78,3 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &fullTextResponse.Usage
 }
-

+ 3 - 8
relay/channel/ollama/relay-ollama.go

@@ -7,6 +7,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	"one-api/service"
 	"strings"
@@ -88,10 +89,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -141,10 +139,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	return nil, usage
 }
 

+ 7 - 13
relay/channel/openai/relay-openai.go

@@ -202,10 +202,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = common.DecodeJson(responseBody, &simpleResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -269,7 +266,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 		//return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 		common.SysError("error copying response body: " + err.Error())
 	}
-	resp.Body.Close()
+	common.CloseResponseBodyGracefully(resp)
 	return nil, &simpleResponse.Usage
 }
 
@@ -280,7 +277,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	// if the upstream returns a specific status code, once the upstream has already written the header,
 	// the subsequent failure of the response body should be regarded as a non-recoverable error,
 	// and can be terminated directly.
-	defer resp.Body.Close()
+	defer common.CloseResponseBodyGracefully(resp)
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
 	usage.TotalTokens = info.PromptTokens
@@ -306,7 +303,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
+	common.CloseResponseBodyGracefully(resp)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
@@ -324,7 +321,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}
-	resp.Body.Close()
+	common.CloseResponseBodyGracefully(resp)
 
 	usage := &dto.Usage{}
 	usage.PromptTokens = audioTokens
@@ -605,10 +602,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	// Reset response body
 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 	// We shouldn't set the header before we parse the response body, because the parse part may fail.
@@ -625,7 +619,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
 	if err != nil {
 		common.SysError("error copying response body: " + err.Error())
 	}
-	_ = resp.Body.Close()
+	common.CloseResponseBodyGracefully(resp)
 
 	// Once we've written to the client, we should not return errors anymore
 	// because the upstream has already consumed resources and returned content

+ 1 - 4
relay/channel/openai/relay_responses.go

@@ -22,10 +22,7 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = common.DecodeJson(responseBody, &responsesResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

+ 2 - 10
relay/channel/palm/relay-palm.go

@@ -83,12 +83,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
 			stopChan <- true
 			return
 		}
-		err = resp.Body.Close()
-		if err != nil {
-			common.SysError("error closing stream response: " + err.Error())
-			stopChan <- true
-			return
-		}
+		common.CloseResponseBodyGracefully(resp)
 		var palmResponse PaLMChatResponse
 		err = json.Unmarshal(responseBody, &palmResponse)
 		if err != nil {
@@ -134,10 +129,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	var palmResponse PaLMChatResponse
 	err = json.Unmarshal(responseBody, &palmResponse)
 	if err != nil {

+ 2 - 4
relay/channel/siliconflow/relay-siliconflow.go

@@ -5,6 +5,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	"one-api/service"
 )
@@ -14,10 +15,7 @@ func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIE
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	var siliconflowResp SFRerankResponse
 	err = json.Unmarshal(responseBody, &siliconflowResp)
 	if err != nil {

+ 1 - 4
relay/channel/tencent/relay-tencent.go

@@ -138,10 +138,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &tencentSb)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

+ 1 - 4
relay/channel/xai/text.go

@@ -110,10 +110,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 
 	return nil, &response.Usage
 }

+ 1 - 4
relay/channel/zhipu/relay-zhipu.go

@@ -223,10 +223,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &zhipuResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

+ 1 - 4
relay/common_handler/rerank.go

@@ -16,10 +16,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
+	common.CloseResponseBodyGracefully(resp)
 	if common.DebugEnabled {
 		println("reranker response body: ", string(responseBody))
 	}

+ 1 - 4
service/error.go

@@ -90,10 +90,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatu
 	if err != nil {
 		return
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return
-	}
+	common.CloseResponseBodyGracefully(resp)
 	var errResponse dto.GeneralErrorResponse
 	err = json.Unmarshal(responseBody, &errResponse)
 	if err != nil {

+ 1 - 4
service/midjourney.go

@@ -228,10 +228,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
 	if err != nil {
 		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
 	}
-	err = resp.Body.Close()
-	if err != nil {
-		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
-	}
+	common.CloseResponseBodyGracefully(resp)
 	respStr := string(responseBody)
 	log.Printf("respStr: %s", respStr)
 	if respStr == "" {