Просмотр исходного кода

fix: claude to openai tools use

1808837298@qq.com 1 год назад
Родитель
Сommit
229738cda9
2 измененных файлов с 83 добавлено и 73 удалено
  1. 21 24
      relay/channel/aws/relay-aws.go
  2. 62 49
      relay/channel/claude/relay-claude.go

+ 21 - 24
relay/channel/aws/relay-aws.go

@@ -144,11 +144,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	defer stream.Close()
 
 	c.Writer.Header().Set("Content-Type", "text/event-stream")
-	var usage relaymodel.Usage
-	var id string
-	var model string
+	claudeInfo := &claude.ClaudeResponseInfo{
+		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Created:      common.GetTimestamp(),
+		Model:        info.UpstreamModelName,
+		ResponseText: strings.Builder{},
+		Usage:        &relaymodel.Usage{},
+	}
 	isFirst := true
-	createdTime := common.GetTimestamp()
 	c.Stream(func(w io.Writer) bool {
 		event, ok := <-stream.Events()
 		if !ok {
@@ -161,33 +164,19 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 				isFirst = false
 				info.FirstResponseTime = time.Now()
 			}
-			claudeResp := new(claude.ClaudeResponse)
-			err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
+			claudeResponse := new(claude.ClaudeResponse)
+			err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse)
 			if err != nil {
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				return false
 			}
 
-			response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
-			if claudeUsage != nil {
-				usage.PromptTokens += claudeUsage.InputTokens
-				usage.CompletionTokens += claudeUsage.OutputTokens
-			}
+			response := claude.StreamResponseClaude2OpenAI(requestMode, claudeResponse)
 
-			if response == nil {
+			if !claude.FormatClaudeResponseInfo(RequestModeMessage, claudeResponse, response, claudeInfo) {
 				return true
 			}
 
-			if response.Id != "" {
-				id = response.Id
-			}
-			if response.Model != "" {
-				model = response.Model
-			}
-			response.Created = createdTime
-			response.Id = id
-			response.Model = model
-
 			jsonStr, err := json.Marshal(response)
 			if err != nil {
 				common.SysError("error marshalling stream response: " + err.Error())
@@ -203,8 +192,16 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 			return false
 		}
 	})
+
+	if claudeInfo.Usage.PromptTokens == 0 {
+		//上游出错
+	}
+	if claudeInfo.Usage.CompletionTokens == 0 {
+		claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+	}
+
 	if info.ShouldIncludeUsage {
-		response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
+		response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
 		err := helper.ObjectData(c, response)
 		if err != nil {
 			common.SysError("send final response failed: " + err.Error())
@@ -217,5 +214,5 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 			return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
 		}
 	}
-	return nil, &usage
+	return nil, claudeInfo.Usage
 }

+ 62 - 49
relay/channel/claude/relay-claude.go

@@ -1,6 +1,7 @@
 package claude
 
 import (
+	"bytes"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -290,9 +291,8 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 	return &claudeRequest, nil
 }
 
-func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
+func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
 	var response dto.ChatCompletionsStreamResponse
-	var claudeUsage *ClaudeUsage
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
 	response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
@@ -308,7 +308,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 		if claudeResponse.Type == "message_start" {
 			response.Id = claudeResponse.Message.Id
 			response.Model = claudeResponse.Message.Model
-			claudeUsage = &claudeResponse.Message.Usage
+			//claudeUsage = &claudeResponse.Message.Usage
 			choice.Delta.SetContentString("")
 			choice.Delta.Role = "assistant"
 		} else if claudeResponse.Type == "content_block_start" {
@@ -325,7 +325,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 					})
 				}
 			} else {
-				return nil, nil
+				return nil
 			}
 		} else if claudeResponse.Type == "content_block_delta" {
 			if claudeResponse.Delta != nil {
@@ -352,23 +352,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 			if finishReason != "null" {
 				choice.FinishReason = &finishReason
 			}
-			claudeUsage = &claudeResponse.Usage
+			//claudeUsage = &claudeResponse.Usage
 		} else if claudeResponse.Type == "message_stop" {
-			return nil, nil
+			return nil
 		} else {
-			return nil, nil
+			return nil
 		}
 	}
-	if claudeUsage == nil {
-		claudeUsage = &ClaudeUsage{}
-	}
 	if len(tools) > 0 {
 		choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
 		choice.Delta.ToolCalls = tools
 	}
 	response.Choices = append(response.Choices, choice)
 
-	return &response, claudeUsage
+	return &response
 }
 
 func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
@@ -437,48 +434,65 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 	return &fullTextResponse
 }
 
+type ClaudeResponseInfo struct {
+	ResponseId   string
+	Created      int64
+	Model        string
+	ResponseText strings.Builder
+	Usage        *dto.Usage
+}
+
+func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
+	if oaiResponse == nil {
+		return false
+	}
+	if requestMode == RequestModeCompletion {
+		claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
+	} else {
+		if claudeResponse.Type == "message_start" {
+			// message_start, 获取usage
+			claudeInfo.ResponseId = claudeResponse.Message.Id
+			claudeInfo.Model = claudeResponse.Message.Model
+			claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
+		} else if claudeResponse.Type == "content_block_delta" {
+			claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Text)
+		} else if claudeResponse.Type == "message_delta" {
+			claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+			claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
+		} else if claudeResponse.Type == "content_block_start" {
+		} else {
+			return false
+		}
+	}
+	oaiResponse.Id = claudeInfo.ResponseId
+	oaiResponse.Created = claudeInfo.Created
+	oaiResponse.Model = claudeInfo.Model
+	return true
+}
+
 func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
-	var usage *dto.Usage
-	usage = &dto.Usage{}
-	responseText := ""
-	createdTime := common.GetTimestamp()
+	claudeInfo := &ClaudeResponseInfo{
+		ResponseId:   responseId,
+		Created:      common.GetTimestamp(),
+		Model:        info.UpstreamModelName,
+		ResponseText: strings.Builder{},
+		Usage:        &dto.Usage{},
+	}
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var claudeResponse ClaudeResponse
-		err := json.Unmarshal([]byte(data), &claudeResponse)
+		err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
 		if err != nil {
 			common.SysError("error unmarshalling stream response: " + err.Error())
 			return true
 		}
 
-		response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
-		if response == nil {
+		response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
+
+		if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
 			return true
 		}
-		if requestMode == RequestModeCompletion {
-			responseText += claudeResponse.Completion
-			responseId = response.Id
-		} else {
-			if claudeResponse.Type == "message_start" {
-				// message_start, 获取usage
-				responseId = claudeResponse.Message.Id
-				info.UpstreamModelName = claudeResponse.Message.Model
-				usage.PromptTokens = claudeUsage.InputTokens
-			} else if claudeResponse.Type == "content_block_delta" {
-				responseText += claudeResponse.Delta.Text
-			} else if claudeResponse.Type == "message_delta" {
-				usage.CompletionTokens = claudeUsage.OutputTokens
-				usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
-			} else if claudeResponse.Type == "content_block_start" {
-			} else {
-				return true
-			}
-		}
-		//response.Id = responseId
-		response.Id = responseId
-		response.Created = createdTime
-		response.Model = info.UpstreamModelName
 
 		err = helper.ObjectData(c, response)
 		if err != nil {
@@ -488,25 +502,24 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	})
 
 	if requestMode == RequestModeCompletion {
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
 	} else {
-		if usage.PromptTokens == 0 {
-			usage.PromptTokens = info.PromptTokens
+		if claudeInfo.Usage.PromptTokens == 0 {
+			//上游出错
 		}
-		if usage.CompletionTokens == 0 {
-			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
+		if claudeInfo.Usage.CompletionTokens == 0 {
+			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
 		}
 	}
 	if info.ShouldIncludeUsage {
-		response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
+		response := helper.GenerateFinalUsageResponse(responseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
 		err := helper.ObjectData(c, response)
 		if err != nil {
 			common.SysError("send final response failed: " + err.Error())
 		}
 	}
 	helper.Done(c)
-	//resp.Body.Close()
-	return nil, usage
+	return nil, claudeInfo.Usage
 }
 
 func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {