Просмотр исходного кода

Merge pull request #1767 from QuantumNous/copy-claude-header-from-upstream

fix: claude header was not set correctly
Calcium-Ion 5 месяцев назад
Родитель
Сommit
18077b6e87
2 измененных файлов с 12 добавлено и 7 удалено
  1. 6 1
      relay/channel/aws/relay-aws.go
  2. 6 6
      relay/channel/claude/relay-claude.go

+ 6 - 1
relay/channel/aws/relay-aws.go

@@ -130,7 +130,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
 		Usage:        &dto.Usage{},
 		Usage:        &dto.Usage{},
 	}
 	}
 
 
-	handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
+	// 复制上游 Content-Type 到客户端响应头
+	if awsResp.ContentType != nil && *awsResp.ContentType != "" {
+		c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
+	}
+
+	handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
 	if handlerErr != nil {
 	if handlerErr != nil {
 		return handlerErr, nil
 		return handlerErr, nil
 	}
 	}

+ 6 - 6
relay/channel/claude/relay-claude.go

@@ -276,7 +276,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
 	isFirstMessage := true
 	isFirstMessage := true
 	// 初始化system消息数组,用于累积多个system消息
 	// 初始化system消息数组,用于累积多个system消息
 	var systemMessages []dto.ClaudeMediaMessage
 	var systemMessages []dto.ClaudeMediaMessage
-	
+
 	for _, message := range formatMessages {
 	for _, message := range formatMessages {
 		if message.Role == "system" {
 		if message.Role == "system" {
 			// 根据Claude API规范,system字段使用数组格式更有通用性
 			// 根据Claude API规范,system字段使用数组格式更有通用性
@@ -401,12 +401,12 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
 			claudeMessages = append(claudeMessages, claudeMessage)
 			claudeMessages = append(claudeMessages, claudeMessage)
 		}
 		}
 	}
 	}
-	
+
 	// 设置累积的system消息
 	// 设置累积的system消息
 	if len(systemMessages) > 0 {
 	if len(systemMessages) > 0 {
 		claudeRequest.System = systemMessages
 		claudeRequest.System = systemMessages
 	}
 	}
-	
+
 	claudeRequest.Prompt = ""
 	claudeRequest.Prompt = ""
 	claudeRequest.Messages = claudeMessages
 	claudeRequest.Messages = claudeMessages
 	return &claudeRequest, nil
 	return &claudeRequest, nil
@@ -716,7 +716,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	return claudeInfo.Usage, nil
 	return claudeInfo.Usage, nil
 }
 }
 
 
-func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
+func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte, requestMode int) *types.NewAPIError {
 	var claudeResponse dto.ClaudeResponse
 	var claudeResponse dto.ClaudeResponse
 	err := common.Unmarshal(data, &claudeResponse)
 	err := common.Unmarshal(data, &claudeResponse)
 	if err != nil {
 	if err != nil {
@@ -754,7 +754,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
 		c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
 	}
 	}
 
 
-	service.IOCopyBytesGracefully(c, nil, responseData)
+	service.IOCopyBytesGracefully(c, httpResp, responseData)
 	return nil
 	return nil
 }
 }
 
 
@@ -775,7 +775,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	if common.DebugEnabled {
 	if common.DebugEnabled {
 		println("responseBody: ", string(responseBody))
 		println("responseBody: ", string(responseBody))
 	}
 	}
-	handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
+	handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody, requestMode)
 	if handleErr != nil {
 	if handleErr != nil {
 		return nil, handleErr
 		return nil, handleErr
 	}
 	}