Просмотр исходного кода

refactor: update error handling to support dynamic error types

CaIon 7 месяцев назад
Родитель
Сommit
ce031f7d15

+ 40 - 8
dto/claude.go

@@ -2,6 +2,7 @@ package dto
 
 import (
 	"encoding/json"
+	"fmt"
 	"one-api/common"
 	"one-api/types"
 )
@@ -284,14 +285,9 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
 	return mediaContent
 }
 
-type ClaudeError struct {
-	Type    string `json:"type,omitempty"`
-	Message string `json:"message,omitempty"`
-}
-
 type ClaudeErrorWithStatusCode struct {
-	Error      ClaudeError `json:"error"`
-	StatusCode int         `json:"status_code"`
+	Error      types.ClaudeError `json:"error"`
+	StatusCode int               `json:"status_code"`
 	LocalError bool
 }
 
@@ -303,7 +299,7 @@ type ClaudeResponse struct {
 	Completion   string               `json:"completion,omitempty"`
 	StopReason   string               `json:"stop_reason,omitempty"`
 	Model        string               `json:"model,omitempty"`
-	Error        *types.ClaudeError   `json:"error,omitempty"`
+	Error        any                  `json:"error,omitempty"`
 	Usage        *ClaudeUsage         `json:"usage,omitempty"`
 	Index        *int                 `json:"index,omitempty"`
 	ContentBlock *ClaudeMediaMessage  `json:"content_block,omitempty"`
@@ -324,6 +320,42 @@ func (c *ClaudeResponse) GetIndex() int {
 	return *c.Index
 }
 
+// GetClaudeError 从动态错误类型中提取ClaudeError结构
+func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
+	if c.Error == nil {
+		return nil
+	}
+
+	switch err := c.Error.(type) {
+	case types.ClaudeError:
+		return &err
+	case *types.ClaudeError:
+		return err
+	case map[string]interface{}:
+		// 处理从JSON解析来的map结构
+		claudeErr := &types.ClaudeError{}
+		if errType, ok := err["type"].(string); ok {
+			claudeErr.Type = errType
+		}
+		if errMsg, ok := err["message"].(string); ok {
+			claudeErr.Message = errMsg
+		}
+		return claudeErr
+	case string:
+		// 处理简单字符串错误
+		return &types.ClaudeError{
+			Type:    "error",
+			Message: err,
+		}
+	default:
+		// 未知类型,尝试转换为字符串
+		return &types.ClaudeError{
+			Type:    "unknown_error",
+			Message: fmt.Sprintf("%v", err),
+		}
+	}
+}
+
 type ClaudeUsage struct {
 	InputTokens              int                  `json:"input_tokens"`
 	CacheCreationInputTokens int                  `json:"cache_creation_input_tokens"`

+ 61 - 3
dto/openai_response.go

@@ -2,12 +2,18 @@ package dto
 
 import (
 	"encoding/json"
+	"fmt"
 	"one-api/types"
 )
 
 type SimpleResponse struct {
 	Usage `json:"usage"`
-	Error *OpenAIError `json:"error"`
+	Error any `json:"error"`
+}
+
+// GetOpenAIError 从动态错误类型中提取OpenAIError结构
+func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
+	return GetOpenAIError(s.Error)
 }
 
 type TextResponse struct {
@@ -31,10 +37,15 @@ type OpenAITextResponse struct {
 	Object  string                     `json:"object"`
 	Created any                        `json:"created"`
 	Choices []OpenAITextResponseChoice `json:"choices"`
-	Error   *types.OpenAIError         `json:"error,omitempty"`
+	Error   any                        `json:"error,omitempty"`
 	Usage   `json:"usage"`
 }
 
+// GetOpenAIError 从动态错误类型中提取OpenAIError结构
+func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
+	return GetOpenAIError(o.Error)
+}
+
 type OpenAIEmbeddingResponseItem struct {
 	Object    string    `json:"object"`
 	Index     int       `json:"index"`
@@ -217,7 +228,7 @@ type OpenAIResponsesResponse struct {
 	Object             string             `json:"object"`
 	CreatedAt          int                `json:"created_at"`
 	Status             string             `json:"status"`
-	Error              *types.OpenAIError `json:"error,omitempty"`
+	Error              any                `json:"error,omitempty"`
 	IncompleteDetails  *IncompleteDetails `json:"incomplete_details,omitempty"`
 	Instructions       string             `json:"instructions"`
 	MaxOutputTokens    int                `json:"max_output_tokens"`
@@ -237,6 +248,11 @@ type OpenAIResponsesResponse struct {
 	Metadata           json.RawMessage    `json:"metadata"`
 }
 
+// GetOpenAIError 从动态错误类型中提取OpenAIError结构
+func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
+	return GetOpenAIError(o.Error)
+}
+
 type IncompleteDetails struct {
 	Reasoning string `json:"reasoning"`
 }
@@ -276,3 +292,45 @@ type ResponsesStreamResponse struct {
 	Delta    string                   `json:"delta,omitempty"`
 	Item     *ResponsesOutput         `json:"item,omitempty"`
 }
+
+// GetOpenAIError 从动态错误类型中提取OpenAIError结构
+func GetOpenAIError(errorField any) *types.OpenAIError {
+	if errorField == nil {
+		return nil
+	}
+
+	switch err := errorField.(type) {
+	case types.OpenAIError:
+		return &err
+	case *types.OpenAIError:
+		return err
+	case map[string]interface{}:
+		// 处理从JSON解析来的map结构
+		openaiErr := &types.OpenAIError{}
+		if errType, ok := err["type"].(string); ok {
+			openaiErr.Type = errType
+		}
+		if errMsg, ok := err["message"].(string); ok {
+			openaiErr.Message = errMsg
+		}
+		if errParam, ok := err["param"].(string); ok {
+			openaiErr.Param = errParam
+		}
+		if errCode, ok := err["code"]; ok {
+			openaiErr.Code = errCode
+		}
+		return openaiErr
+	case string:
+		// 处理简单字符串错误
+		return &types.OpenAIError{
+			Type:    "error",
+			Message: err,
+		}
+	default:
+		// 未知类型,尝试转换为字符串
+		return &types.OpenAIError{
+			Type:    "unknown_error",
+			Message: fmt.Sprintf("%v", err),
+		}
+	}
+}

+ 4 - 4
relay/channel/claude/relay-claude.go

@@ -612,8 +612,8 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		common.SysError("error unmarshalling stream response: " + err.Error())
 		return types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
-	if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
-		return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
+	if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
+		return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
 	}
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
 		FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
@@ -704,8 +704,8 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
-	if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
-		return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
+	if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
+		return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
 	}
 	if requestMode == RequestModeCompletion {
 		completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)

+ 6 - 3
relay/channel/gemini/dto.go

@@ -1,6 +1,9 @@
 package gemini
 
-import "encoding/json"
+import (
+	"encoding/json"
+	"one-api/common"
+)
 
 type GeminiChatRequest struct {
 	Contents           []GeminiChatContent        `json:"contents"`
@@ -32,7 +35,7 @@ func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
 		MimeTypeSnake string `json:"mime_type"`
 	}
 
-	if err := json.Unmarshal(data, &aux); err != nil {
+	if err := common.Unmarshal(data, &aux); err != nil {
 		return err
 	}
 
@@ -93,7 +96,7 @@ func (p *GeminiPart) UnmarshalJSON(data []byte) error {
 		InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
 	}
 
-	if err := json.Unmarshal(data, &aux); err != nil {
+	if err := common.Unmarshal(data, &aux); err != nil {
 		return err
 	}
 

+ 2 - 2
relay/channel/openai/relay-openai.go

@@ -184,8 +184,8 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
-	if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
-		return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
+	if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
+		return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
 	}
 
 	forceFormat := false

+ 2 - 2
relay/channel/openai/relay_responses.go

@@ -28,8 +28,8 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
-	if responsesResponse.Error != nil {
-		return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
+	if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
+		return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
 	}
 
 	// 写入新的 response body

+ 0 - 22
service/convert.go

@@ -188,28 +188,6 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
 	return &openAIRequest, nil
 }
 
-func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
-	claudeError := dto.ClaudeError{
-		Type:    "new_api_error",
-		Message: openAIError.Error.Message,
-	}
-	return &dto.ClaudeErrorWithStatusCode{
-		Error:      claudeError,
-		StatusCode: openAIError.StatusCode,
-	}
-}
-
-func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
-	openAIError := dto.OpenAIError{
-		Message: claudeError.Error.Message,
-		Type:    "new_api_error",
-	}
-	return &dto.OpenAIErrorWithStatusCode{
-		Error:      openAIError,
-		StatusCode: claudeError.StatusCode,
-	}
-}
-
 func generateStopBlock(index int) *dto.ClaudeResponse {
 	return &dto.ClaudeResponse{
 		Type:  "content_block_stop",

+ 1 - 1
service/error.go

@@ -62,7 +62,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError
 			text = "请求上游地址失败"
 		}
 	}
-	claudeError := dto.ClaudeError{
+	claudeError := types.ClaudeError{
 		Message: text,
 		Type:    "new_api_error",
 	}

+ 1 - 1
types/error.go

@@ -16,8 +16,8 @@ type OpenAIError struct {
 }
 
 type ClaudeError struct {
-	Message string `json:"message,omitempty"`
 	Type    string `json:"type,omitempty"`
+	Message string `json:"message,omitempty"`
 }
 
 type ErrorType string