Просмотр исходного кода

refactor: Enhance Claude response handling

1808837298@qq.com 11 месяцев назад
Родитель
Коммит
53b3599827

+ 1 - 1
dto/openai_response.go

@@ -26,7 +26,7 @@ type OpenAITextResponse struct {
 	Object  string                     `json:"object"`
 	Created int64                      `json:"created"`
 	Choices []OpenAITextResponseChoice `json:"choices"`
-	Error   *OpenAIError               `json:"error"`
+	Error   *OpenAIError               `json:"error,omitempty"`
 	Usage   `json:"usage"`
 }
 

+ 12 - 22
relay/channel/aws/relay-aws.go

@@ -84,22 +84,16 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
 		return wrapErr(errors.Wrap(err, "InvokeModel")), nil
 	}
 
-	claudeResponse := new(dto.ClaudeResponse)
-	err = json.Unmarshal(awsResp.Body, claudeResponse)
-	if err != nil {
-		return wrapErr(errors.Wrap(err, "unmarshal response")), nil
-	}
-
-	openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
-	usage := dto.Usage{
-		PromptTokens:     claudeResponse.Usage.InputTokens,
-		CompletionTokens: claudeResponse.Usage.OutputTokens,
-		TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
+	claudeInfo := &claude.ClaudeResponseInfo{
+		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Created:      common.GetTimestamp(),
+		Model:        info.UpstreamModelName,
+		ResponseText: strings.Builder{},
+		Usage:        &dto.Usage{},
 	}
-	openaiResp.Usage = usage
 
-	c.JSON(http.StatusOK, openaiResp)
-	return nil, &usage
+	claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
+	return nil, claudeInfo.Usage
 }
 
 func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -150,9 +144,9 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		switch v := event.(type) {
 		case *types.ResponseStreamMemberChunk:
 			info.SetFirstResponseTime()
-			err = claude.HandleResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
-			if err != nil {
-				return wrapErr(err), nil
+			respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
+			if respErr != nil {
+				return respErr, nil
 			}
 		case *types.UnknownUnionMember:
 			fmt.Println("unknown tag:", v.Tag)
@@ -163,10 +157,6 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		}
 	}
 
-	claude.HandleFinalResponse(c, info, claudeInfo, RequestModeMessage)
-
-	if resp != nil {
-		resp.Body.Close()
-	}
+	claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
 	return nil, claudeInfo.Usage
 }

+ 59 - 40
relay/channel/claude/relay-claude.go

@@ -478,12 +478,22 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
 	return true
 }
 
-func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) error {
+func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
 	var claudeResponse dto.ClaudeResponse
 	err := common.DecodeJsonStr(data, &claudeResponse)
 	if err != nil {
 		common.SysError("error unmarshalling stream response: " + err.Error())
-		return fmt.Errorf("error unmarshalling stream aws response: %w", err)
+		return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
+	}
+	if claudeResponse.Error.Type != "" {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Code:    "stream_response_error",
+				Type:    claudeResponse.Error.Type,
+				Message: claudeResponse.Error.Message,
+			},
+			StatusCode: http.StatusInternalServerError,
+		}
 	}
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
 		if requestMode == RequestModeCompletion {
@@ -523,7 +533,7 @@ func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo
 	return nil
 }
 
-func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
+func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
 		if requestMode == RequestModeCompletion {
 			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
@@ -566,81 +576,90 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		ResponseText: strings.Builder{},
 		Usage:        &dto.Usage{},
 	}
-	var err error
+	var err *dto.OpenAIErrorWithStatusCode
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-		err = HandleResponseData(c, info, claudeInfo, data, requestMode)
+		err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
 		if err != nil {
 			return false
 		}
 		return true
 	})
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError), nil
+		return err, nil
 	}
 
-	HandleFinalResponse(c, info, claudeInfo, requestMode)
-
+	HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
 	return nil, claudeInfo.Usage
 }
 
-func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	if common.DebugEnabled {
-		println("responseBody: ", string(responseBody))
-	}
+func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
 	var claudeResponse dto.ClaudeResponse
-	err = json.Unmarshal(responseBody, &claudeResponse)
+	err := common.DecodeJson(data, &claudeResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
 	}
 	if claudeResponse.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{
 				Message: claudeResponse.Error.Message,
 				Type:    claudeResponse.Error.Type,
-				Param:   "",
 				Code:    claudeResponse.Error.Type,
 			},
-			StatusCode: resp.StatusCode,
-		}, nil
+			StatusCode: http.StatusInternalServerError,
+		}
 	}
-	usage := dto.Usage{}
 	if requestMode == RequestModeCompletion {
 		completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
+			return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
 		}
-		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens = completionTokens
-		usage.TotalTokens = info.PromptTokens + completionTokens
+		claudeInfo.Usage.PromptTokens = info.PromptTokens
+		claudeInfo.Usage.CompletionTokens = completionTokens
+		claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
 	} else {
-		usage.PromptTokens = claudeResponse.Usage.InputTokens
-		usage.CompletionTokens = claudeResponse.Usage.OutputTokens
-		usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
-		usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
-		usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
+		claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
+		claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+		claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
+		claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
+		claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
 	}
 	var responseData []byte
 	switch info.RelayFormat {
 	case relaycommon.RelayFormatOpenAI:
 		openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
-		openaiResponse.Usage = usage
+		openaiResponse.Usage = *claudeInfo.Usage
 		responseData, err = json.Marshal(openaiResponse)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+			return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
 		}
 	case relaycommon.RelayFormatClaude:
-		responseData = responseBody
+		responseData = data
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
+	c.Writer.WriteHeader(http.StatusOK)
 	_, err = c.Writer.Write(responseData)
-	return nil, &usage
+	return nil
+}
+
+func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	claudeInfo := &ClaudeResponseInfo{
+		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Created:      common.GetTimestamp(),
+		Model:        info.UpstreamModelName,
+		ResponseText: strings.Builder{},
+		Usage:        &dto.Usage{},
+	}
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	resp.Body.Close()
+	if common.DebugEnabled {
+		println("responseBody: ", string(responseBody))
+	}
+	handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
+	if handleErr != nil {
+		return handleErr, nil
+	}
+	return nil, claudeInfo.Usage
 }

+ 2 - 1
relay/channel/openai/relay-openai.go

@@ -240,7 +240,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+		//return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+		common.SysError("error copying response body: " + err.Error())
 	}
 	resp.Body.Close()
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {