Ver código fonte

feat: refactor request body handling to use BodyStorage for improved efficiency

CaIon 3 semanas atrás
pai
commit
197b89ea58

+ 6 - 0
common/body_storage.go

@@ -302,6 +302,12 @@ func CreateBodyStorageFromReader(reader io.Reader, contentLength int64, maxBytes
 	return storage, nil
 }
 
+// ReaderOnly wraps an io.Reader to hide io.Closer, preventing http.NewRequest
+// from type-asserting io.ReadCloser and closing the underlying BodyStorage.
+func ReaderOnly(r io.Reader) io.Reader {
+	return struct{ io.Reader }{r}
+}
+
 // CleanupOldCacheFiles 清理旧的缓存文件(用于启动时清理残留)
 func CleanupOldCacheFiles() {
 	// 使用统一的缓存管理

+ 32 - 43
common/gin.go

@@ -33,14 +33,14 @@ func IsRequestBodyTooLargeError(err error) bool {
 	return errors.As(err, &mbe)
 }
 
-func GetRequestBody(c *gin.Context) ([]byte, error) {
+func GetRequestBody(c *gin.Context) (io.Seeker, error) {
 	// 首先检查是否有 BodyStorage 缓存
 	if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
 		if bs, ok := storage.(BodyStorage); ok {
 			if _, err := bs.Seek(0, io.SeekStart); err != nil {
 				return nil, fmt.Errorf("failed to seek body storage: %w", err)
 			}
-			return bs.Bytes()
+			return bs, nil
 		}
 	}
 
@@ -48,7 +48,12 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
 	cached, exists := c.Get(KeyRequestBody)
 	if exists && cached != nil {
 		if b, ok := cached.([]byte); ok {
-			return b, nil
+			bs, err := CreateBodyStorage(b)
+			if err != nil {
+				return nil, err
+			}
+			c.Set(KeyBodyStorage, bs)
+			return bs, nil
 		}
 	}
 
@@ -74,47 +79,20 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
 	// 缓存存储对象
 	c.Set(KeyBodyStorage, storage)
 
-	// 获取字节数据
-	body, err := storage.Bytes()
-	if err != nil {
-		return nil, err
-	}
-
-	// 同时设置旧的缓存键以保持兼容性
-	c.Set(KeyRequestBody, body)
-
-	return body, nil
+	return storage, nil
 }
 
 // GetBodyStorage 获取请求体存储对象(用于需要多次读取的场景)
 func GetBodyStorage(c *gin.Context) (BodyStorage, error) {
-	// 检查是否已有存储
-	if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
-		if bs, ok := storage.(BodyStorage); ok {
-			if _, err := bs.Seek(0, io.SeekStart); err != nil {
-				return nil, fmt.Errorf("failed to seek body storage: %w", err)
-			}
-			return bs, nil
-		}
-	}
-
-	// 如果没有,调用 GetRequestBody 创建存储
-	_, err := GetRequestBody(c)
+	seeker, err := GetRequestBody(c)
 	if err != nil {
 		return nil, err
 	}
-
-	// 再次获取存储
-	if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
-		if bs, ok := storage.(BodyStorage); ok {
-			if _, err := bs.Seek(0, io.SeekStart); err != nil {
-				return nil, fmt.Errorf("failed to seek body storage: %w", err)
-			}
-			return bs, nil
-		}
+	bs, ok := seeker.(BodyStorage)
+	if !ok {
+		return nil, errors.New("unexpected body storage type")
 	}
-
-	return nil, errors.New("failed to get body storage")
+	return bs, nil
 }
 
 // CleanupBodyStorage 清理请求体存储(应在请求结束时调用)
@@ -128,13 +106,14 @@ func CleanupBodyStorage(c *gin.Context) {
 }
 
 func UnmarshalBodyReusable(c *gin.Context, v any) error {
-	requestBody, err := GetRequestBody(c)
+	storage, err := GetBodyStorage(c)
+	if err != nil {
+		return err
+	}
+	requestBody, err := storage.Bytes()
 	if err != nil {
 		return err
 	}
-	//if DebugEnabled {
-	//	println("UnmarshalBodyReusable request body:", string(requestBody))
-	//}
 	contentType := c.Request.Header.Get("Content-Type")
 	if strings.HasPrefix(contentType, "application/json") {
 		err = Unmarshal(requestBody, v)
@@ -150,7 +129,10 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 		return err
 	}
 	// Reset request body
-	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
+		return seekErr
+	}
+	c.Request.Body = io.NopCloser(storage)
 	return nil
 }
 
@@ -252,7 +234,11 @@ func init() {
 }
 
 func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
-	requestBody, err := GetRequestBody(c)
+	storage, err := GetBodyStorage(c)
+	if err != nil {
+		return nil, err
+	}
+	requestBody, err := storage.Bytes()
 	if err != nil {
 		return nil, err
 	}
@@ -270,7 +256,10 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
 	}
 
 	// Reset request body
-	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
+		return nil, seekErr
+	}
+	c.Request.Body = io.NopCloser(storage)
 	return form, nil
 }
 

+ 4 - 5
controller/relay.go

@@ -1,7 +1,6 @@
 package controller
 
 import (
-	"bytes"
 	"errors"
 	"fmt"
 	"io"
@@ -193,7 +192,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 		}
 
 		addUsedChannel(c, channel.Id)
-		requestBody, bodyErr := common.GetRequestBody(c)
+		bodyStorage, bodyErr := common.GetBodyStorage(c)
 		if bodyErr != nil {
 			// Ensure consistent 413 for oversized bodies even when error occurs later (e.g., retry path)
 			if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
@@ -203,7 +202,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 			}
 			break
 		}
-		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+		c.Request.Body = io.NopCloser(bodyStorage)
 
 		switch relayFormat {
 		case types.RelayFormatOpenAIRealtime:
@@ -483,7 +482,7 @@ func RelayTask(c *gin.Context) {
 		logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
 		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 
-		requestBody, err := common.GetRequestBody(c)
+		bodyStorage, err := common.GetBodyStorage(c)
 		if err != nil {
 			if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
 				taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge)
@@ -492,7 +491,7 @@ func RelayTask(c *gin.Context) {
 			}
 			break
 		}
-		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+		c.Request.Body = io.NopCloser(bodyStorage)
 		taskErr = taskRelayHandler(c, relayInfo)
 	}
 	useChannel := c.GetStringSlice("use_channel")

+ 5 - 1
relay/channel/aws/relay-aws.go

@@ -165,10 +165,14 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
 // buildAwsRequestBody prepares the payload for AWS requests, applying passthrough rules when enabled.
 func buildAwsRequestBody(c *gin.Context, info *relaycommon.RelayInfo, awsClaudeReq any) ([]byte, error) {
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
-		body, err := common.GetRequestBody(c)
+		storage, err := common.GetBodyStorage(c)
 		if err != nil {
 			return nil, errors.Wrap(err, "get request body for pass-through fail")
 		}
+		body, err := storage.Bytes()
+		if err != nil {
+			return nil, errors.Wrap(err, "get request body bytes fail")
+		}
 		var data map[string]interface{}
 		if err := common.Unmarshal(body, &data); err != nil {
 			return nil, errors.Wrap(err, "pass-through unmarshal request body fail")

+ 2 - 3
relay/channel/task/sora/adaptor.go

@@ -1,7 +1,6 @@
 package sora
 
 import (
-	"bytes"
 	"fmt"
 	"io"
 	"net/http"
@@ -104,11 +103,11 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 }
 
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
-	cachedBody, err := common.GetRequestBody(c)
+	storage, err := common.GetBodyStorage(c)
 	if err != nil {
 		return nil, errors.Wrap(err, "get_request_body_failed")
 	}
-	return bytes.NewReader(cachedBody), nil
+	return common.ReaderOnly(storage), nil
 }
 
 // DoRequest delegates to common helper.

+ 2 - 2
relay/claude_handler.go

@@ -129,11 +129,11 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 	var requestBody io.Reader
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
-		body, err := common.GetRequestBody(c)
+		storage, err := common.GetBodyStorage(c)
 		if err != nil {
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
-		requestBody = bytes.NewBuffer(body)
+		requestBody = common.ReaderOnly(storage)
 	} else {
 		convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
 		if err != nil {

+ 5 - 3
relay/compatible_handler.go

@@ -100,14 +100,16 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 	var requestBody io.Reader
 
 	if passThroughGlobal || info.ChannelSetting.PassThroughBodyEnabled {
-		body, err := common.GetRequestBody(c)
+		storage, err := common.GetBodyStorage(c)
 		if err != nil {
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
 		if common.DebugEnabled {
-			println("requestBody: ", string(body))
+			if debugBytes, bErr := storage.Bytes(); bErr == nil {
+				println("requestBody: ", string(debugBytes))
+			}
 		}
-		requestBody = bytes.NewBuffer(body)
+		requestBody = common.ReaderOnly(storage)
 	} else {
 		convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
 		if err != nil {

+ 2 - 2
relay/gemini_handler.go

@@ -138,11 +138,11 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 	var requestBody io.Reader
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
-		body, err := common.GetRequestBody(c)
+		storage, err := common.GetBodyStorage(c)
 		if err != nil {
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
-		requestBody = bytes.NewReader(body)
+		requestBody = common.ReaderOnly(storage)
 	} else {
 		// 使用 ConvertGeminiRequest 转换请求格式
 		convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request)

+ 2 - 2
relay/image_handler.go

@@ -47,11 +47,11 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 	var requestBody io.Reader
 
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
-		body, err := common.GetRequestBody(c)
+		storage, err := common.GetBodyStorage(c)
 		if err != nil {
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
-		requestBody = bytes.NewBuffer(body)
+		requestBody = common.ReaderOnly(storage)
 	} else {
 		convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
 		if err != nil {

+ 2 - 2
relay/rerank_handler.go

@@ -43,11 +43,11 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 	var requestBody io.Reader
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
-		body, err := common.GetRequestBody(c)
+		storage, err := common.GetBodyStorage(c)
 		if err != nil {
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
-		requestBody = bytes.NewBuffer(body)
+		requestBody = common.ReaderOnly(storage)
 	} else {
 		convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request)
 		if err != nil {

+ 2 - 2
relay/responses_handler.go

@@ -72,11 +72,11 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 	adaptor.Init(info)
 	var requestBody io.Reader
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
-		body, err := common.GetRequestBody(c)
+		storage, err := common.GetBodyStorage(c)
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
 		}
-		requestBody = bytes.NewBuffer(body)
+		requestBody = common.ReaderOnly(storage)
 	} else {
 		convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request)
 		if err != nil {

+ 5 - 1
service/channel_affinity.go

@@ -288,7 +288,11 @@ func extractChannelAffinityValue(c *gin.Context, src operation_setting.ChannelAf
 		if src.Path == "" {
 			return ""
 		}
-		body, err := common.GetRequestBody(c)
+		storage, err := common.GetBodyStorage(c)
+		if err != nil {
+			return ""
+		}
+		body, err := storage.Bytes()
 		if err != nil || len(body) == 0 {
 			return ""
 		}