Просмотр исходного кода

refactor: replace DeepCopy with Copy for request handling consistency

CaIon 6 месяцев назад
Родитель
Сommit
872f7a9648

+ 11 - 4
common/copy.go

@@ -6,14 +6,21 @@ import (
 	"github.com/jinzhu/copier"
 )
 
-func DeepCopy[T any](src *T) (*T, error) {
+func Copy[T any](src *T, deepCopy bool) (*T, error) {
 	if src == nil {
 		return nil, fmt.Errorf("copy source cannot be nil")
 	}
 	var dst T
-	err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true})
-	if err != nil {
-		return nil, err
+	if deepCopy {
+		err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true})
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		err := copier.Copy(&dst, src)
+		if err != nil {
+			return nil, err
+		}
 	}
 	return &dst, nil
 }

+ 2 - 1
dto/gemini.go

@@ -2,11 +2,12 @@ package dto
 
 import (
 	"encoding/json"
-	"github.com/gin-gonic/gin"
 	"one-api/common"
 	"one-api/logger"
 	"one-api/types"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 type GeminiChatRequest struct {

+ 36 - 36
dto/openai_request.go

@@ -265,7 +265,7 @@ type Message struct {
 	Reasoning        string          `json:"reasoning,omitempty"`
 	ToolCalls        json.RawMessage `json:"tool_calls,omitempty"`
 	ToolCallId       string          `json:"tool_call_id,omitempty"`
-	parsedContent    []MediaContent
+	parsedContent    *[]MediaContent
 	//parsedStringContent *string
 }
 
@@ -441,7 +441,7 @@ func (m *Message) SetStringContent(content string) {
 
 func (m *Message) SetMediaContent(content []MediaContent) {
 	m.Content = content
-	m.parsedContent = content
+	m.parsedContent = &content
 }
 
 func (m *Message) IsStringContent() bool {
@@ -456,8 +456,8 @@ func (m *Message) ParseContent() []MediaContent {
 	if m.Content == nil {
 		return nil
 	}
-	if len(m.parsedContent) > 0 {
-		return m.parsedContent
+	if m.parsedContent != nil && len(*m.parsedContent) > 0 {
+		return *m.parsedContent
 	}
 
 	var contentList []MediaContent
@@ -468,7 +468,7 @@ func (m *Message) ParseContent() []MediaContent {
 			Type: ContentTypeText,
 			Text: content,
 		}}
-		m.parsedContent = contentList
+		m.parsedContent = &contentList
 		return contentList
 	}
 
@@ -580,7 +580,7 @@ func (m *Message) ParseContent() []MediaContent {
 	}
 
 	if len(contentList) > 0 {
-		m.parsedContent = contentList
+		m.parsedContent = &contentList
 	}
 	return contentList
 }
@@ -766,27 +766,27 @@ type WebSearchOptions struct {
 
 // https://platform.openai.com/docs/api-reference/responses/create
 type OpenAIResponsesRequest struct {
-	Model              string          `json:"model"`
-	Input              json.RawMessage `json:"input,omitempty"`
-	Include            json.RawMessage `json:"include,omitempty"`
-	Instructions       json.RawMessage `json:"instructions,omitempty"`
-	MaxOutputTokens    uint            `json:"max_output_tokens,omitempty"`
-	Metadata           json.RawMessage `json:"metadata,omitempty"`
-	ParallelToolCalls  bool            `json:"parallel_tool_calls,omitempty"`
-	PreviousResponseID string          `json:"previous_response_id,omitempty"`
-	Reasoning          *Reasoning      `json:"reasoning,omitempty"`
-	ServiceTier        string          `json:"service_tier,omitempty"`
-	Store              bool            `json:"store,omitempty"`
-	Stream             bool            `json:"stream,omitempty"`
-	Temperature        float64         `json:"temperature,omitempty"`
-	Text               json.RawMessage `json:"text,omitempty"`
-	ToolChoice         json.RawMessage `json:"tool_choice,omitempty"`
-	Tools              json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
-	TopP               float64         `json:"top_p,omitempty"`
-	Truncation         string          `json:"truncation,omitempty"`
-	User               string          `json:"user,omitempty"`
-	MaxToolCalls       uint            `json:"max_tool_calls,omitempty"`
-	Prompt             json.RawMessage `json:"prompt,omitempty"`
+	Model              string           `json:"model"`
+	Input              *json.RawMessage `json:"input,omitempty"`
+	Include            json.RawMessage  `json:"include,omitempty"`
+	Instructions       json.RawMessage  `json:"instructions,omitempty"`
+	MaxOutputTokens    uint             `json:"max_output_tokens,omitempty"`
+	Metadata           json.RawMessage  `json:"metadata,omitempty"`
+	ParallelToolCalls  bool             `json:"parallel_tool_calls,omitempty"`
+	PreviousResponseID string           `json:"previous_response_id,omitempty"`
+	Reasoning          *Reasoning       `json:"reasoning,omitempty"`
+	ServiceTier        string           `json:"service_tier,omitempty"`
+	Store              bool             `json:"store,omitempty"`
+	Stream             bool             `json:"stream,omitempty"`
+	Temperature        float64          `json:"temperature,omitempty"`
+	Text               json.RawMessage  `json:"text,omitempty"`
+	ToolChoice         json.RawMessage  `json:"tool_choice,omitempty"`
+	Tools              *json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
+	TopP               float64          `json:"top_p,omitempty"`
+	Truncation         string           `json:"truncation,omitempty"`
+	User               string           `json:"user,omitempty"`
+	MaxToolCalls       uint             `json:"max_tool_calls,omitempty"`
+	Prompt             json.RawMessage  `json:"prompt,omitempty"`
 }
 
 func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
@@ -837,8 +837,8 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
 		texts = append(texts, string(r.Prompt))
 	}
 
-	if len(r.Tools) > 0 {
-		texts = append(texts, string(r.Tools))
+	if r.Tools != nil && len(*r.Tools) > 0 {
+		texts = append(texts, string(*r.Tools))
 	}
 
 	return &types.TokenCountMeta{
@@ -859,9 +859,9 @@ func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
 }
 
 func (r *OpenAIResponsesRequest) GetToolsMap() []map[string]any {
-	var toolsMap []map[string]any
-	if len(r.Tools) > 0 {
-		_ = common.Unmarshal(r.Tools, &toolsMap)
+	var toolsMap = make([]map[string]any, 0)
+	if r.Tools != nil && len(*r.Tools) > 0 {
+		_ = common.Unmarshal(*r.Tools, &toolsMap)
 	}
 	return toolsMap
 }
@@ -896,17 +896,17 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
 	// 	inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
 	// 	return inputs
 	// }
-	if common.GetJsonType(r.Input) == "string" {
+	if common.GetJsonType(*r.Input) == "string" {
 		var str string
-		_ = common.Unmarshal(r.Input, &str)
+		_ = common.Unmarshal(*r.Input, &str)
 		inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
 		return inputs
 	}
 
 	// Try array of parts
-	if common.GetJsonType(r.Input) == "array" {
+	if common.GetJsonType(*r.Input) == "array" {
 		var array []any
-		_ = common.Unmarshal(r.Input, &array)
+		_ = common.Unmarshal(*r.Input, &array)
 		for _, itemAny := range array {
 			// Already parsed MediaInput
 			if media, ok := itemAny.(MediaInput); ok {

+ 1 - 1
relay/audio_handler.go

@@ -22,7 +22,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 		return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
-	request, err := common.DeepCopy(audioReq)
+	request, err := common.Copy(audioReq, false)
 	if err != nil {
 		return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}

+ 1 - 1
relay/claude_handler.go

@@ -27,7 +27,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 	}
 
-	request, err := common.DeepCopy(claudeReq)
+	request, err := common.Copy(claudeReq, false)
 	if err != nil {
 		return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}

+ 1 - 1
relay/common/relay_info.go

@@ -313,7 +313,7 @@ func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest)
 	info.ResponsesUsageInfo = &ResponsesUsageInfo{
 		BuiltInTools: make(map[string]*BuildInToolInfo),
 	}
-	if len(request.Tools) > 0 {
+	if request.Tools != nil && len(*request.Tools) > 0 {
 		for _, tool := range request.GetToolsMap() {
 			toolType := common.Interface2String(tool["type"])
 			info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{

+ 1 - 1
relay/compatible_handler.go

@@ -32,7 +32,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 	}
 
-	request, err := common.DeepCopy(textReq)
+	request, err := common.Copy(textReq, false)
 	if err != nil {
 		return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}

+ 1 - 1
relay/embedding_handler.go

@@ -23,7 +23,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 	}
 
-	request, err := common.DeepCopy(embeddingReq)
+	request, err := common.Copy(embeddingReq, false)
 	if err != nil {
 		return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}

+ 1 - 1
relay/gemini_handler.go

@@ -58,7 +58,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 	}
 
-	request, err := common.DeepCopy(geminiReq)
+	request, err := common.Copy(geminiReq, false)
 	if err != nil {
 		return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}

+ 1 - 1
relay/image_handler.go

@@ -26,7 +26,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 	}
 
-	request, err := common.DeepCopy(imageReq)
+	request, err := common.Copy(imageReq, false)
 	if err != nil {
 		return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}

+ 1 - 1
relay/rerank_handler.go

@@ -24,7 +24,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.RerankRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 	}
 
-	request, err := common.DeepCopy(rerankReq)
+	request, err := common.Copy(rerankReq, false)
 	if err != nil {
 		return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}

+ 1 - 1
relay/responses_handler.go

@@ -25,7 +25,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 	}
 
-	request, err := common.DeepCopy(responsesReq)
+	request, err := common.Copy(responsesReq, false)
 	if err != nil {
 		return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}