Bläddra i källkod

feat(channel): enhance Claude response handling with new Done flag and improved usage tracking

CaIon 8 månader sedan
förälder
incheckning
b7c3328d43
1 ändrade filer med 34 tillägg och 35 borttagningar
  1. 34 35
      relay/channel/claude/relay-claude.go

+ 34 - 35
relay/channel/claude/relay-claude.go

@@ -454,6 +454,7 @@ type ClaudeResponseInfo struct {
 	Model        string
 	ResponseText strings.Builder
 	Usage        *dto.Usage
+	Done         bool
 }
 
 func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
@@ -461,20 +462,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
 		claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
 	} else {
 		if claudeResponse.Type == "message_start" {
-			// message_start, 获取usage
 			claudeInfo.ResponseId = claudeResponse.Message.Id
 			claudeInfo.Model = claudeResponse.Message.Model
+
+			// message_start, 获取usage
 			claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
+			claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
+			claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
+			claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
 		} else if claudeResponse.Type == "content_block_delta" {
 			if claudeResponse.Delta.Text != nil {
 				claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
 			}
+			if claudeResponse.Delta.Thinking != "" {
+				claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
+			}
 		} else if claudeResponse.Type == "message_delta" {
-			claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+			// 最终的usage获取
 			if claudeResponse.Usage.InputTokens > 0 {
+				// 不叠加,只取最新的
 				claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
 			}
-			claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
+			claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+			claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
+
+			// 判断是否完整
+			claudeInfo.Done = true
 		} else if claudeResponse.Type == "content_block_start" {
 		} else {
 			return false
@@ -506,25 +519,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		}
 	}
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
+		FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
+
 		if requestMode == RequestModeCompletion {
-			claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
 		} else {
 			if claudeResponse.Type == "message_start" {
 				// message_start, 获取usage
 				info.UpstreamModelName = claudeResponse.Message.Model
-				claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
-				claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
-				claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
-				claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
 			} else if claudeResponse.Type == "content_block_delta" {
-				claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
 			} else if claudeResponse.Type == "message_delta" {
-				if claudeResponse.Usage.InputTokens > 0 {
-					// 不叠加,只取最新的
-					claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
-				}
-				claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
-				claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
 			}
 		}
 		helper.ClaudeChunkData(c, claudeResponse, data)
@@ -544,29 +547,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 }
 
 func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
-	if info.RelayFormat == relaycommon.RelayFormatClaude {
-		if requestMode == RequestModeCompletion {
-			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
-		} else {
-			// 说明流模式建立失败,可能为官方出错
-			if claudeInfo.Usage.PromptTokens == 0 {
-				//usage.PromptTokens = info.PromptTokens
-			}
-			if claudeInfo.Usage.CompletionTokens == 0 {
-				claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
-			}
+
+	if requestMode == RequestModeCompletion {
+		claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+	} else {
+		if claudeInfo.Usage.PromptTokens == 0 {
+			//上游出错
 		}
-	} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
-		if requestMode == RequestModeCompletion {
-			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
-		} else {
-			if claudeInfo.Usage.PromptTokens == 0 {
-				//上游出错
-			}
-			if claudeInfo.Usage.CompletionTokens == 0 {
-				claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+		if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
+			if common.DebugEnabled {
+				common.SysError("claude response usage is not complete, maybe upstream error")
 			}
+			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
 		}
+	}
+
+	if info.RelayFormat == relaycommon.RelayFormatClaude {
+		//
+	} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
+
 		if info.ShouldIncludeUsage {
 			response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
 			err := helper.ObjectData(c, response)