|
|
@@ -1,6 +1,7 @@
|
|
|
package claude
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
@@ -290,9 +291,8 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
|
|
return &claudeRequest, nil
|
|
|
}
|
|
|
|
|
|
-func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
|
|
|
+func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
|
|
|
var response dto.ChatCompletionsStreamResponse
|
|
|
- var claudeUsage *ClaudeUsage
|
|
|
response.Object = "chat.completion.chunk"
|
|
|
response.Model = claudeResponse.Model
|
|
|
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
|
|
|
@@ -308,7 +308,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
|
|
if claudeResponse.Type == "message_start" {
|
|
|
response.Id = claudeResponse.Message.Id
|
|
|
response.Model = claudeResponse.Message.Model
|
|
|
- claudeUsage = &claudeResponse.Message.Usage
|
|
|
+ //claudeUsage = &claudeResponse.Message.Usage
|
|
|
choice.Delta.SetContentString("")
|
|
|
choice.Delta.Role = "assistant"
|
|
|
} else if claudeResponse.Type == "content_block_start" {
|
|
|
@@ -325,7 +325,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
|
|
})
|
|
|
}
|
|
|
} else {
|
|
|
- return nil, nil
|
|
|
+ return nil
|
|
|
}
|
|
|
} else if claudeResponse.Type == "content_block_delta" {
|
|
|
if claudeResponse.Delta != nil {
|
|
|
@@ -352,23 +352,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
|
|
if finishReason != "null" {
|
|
|
choice.FinishReason = &finishReason
|
|
|
}
|
|
|
- claudeUsage = &claudeResponse.Usage
|
|
|
+ //claudeUsage = &claudeResponse.Usage
|
|
|
} else if claudeResponse.Type == "message_stop" {
|
|
|
- return nil, nil
|
|
|
+ return nil
|
|
|
} else {
|
|
|
- return nil, nil
|
|
|
+ return nil
|
|
|
}
|
|
|
}
|
|
|
- if claudeUsage == nil {
|
|
|
- claudeUsage = &ClaudeUsage{}
|
|
|
- }
|
|
|
if len(tools) > 0 {
|
|
|
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
|
|
|
choice.Delta.ToolCalls = tools
|
|
|
}
|
|
|
response.Choices = append(response.Choices, choice)
|
|
|
|
|
|
- return &response, claudeUsage
|
|
|
+ return &response
|
|
|
}
|
|
|
|
|
|
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
|
|
@@ -437,48 +434,65 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
|
|
return &fullTextResponse
|
|
|
}
|
|
|
|
|
|
+type ClaudeResponseInfo struct { // ClaudeResponseInfo carries the state accumulated across the events of one Claude stream.
|
|
|
+ ResponseId string // id stamped on every outgoing chunk; replaced by the upstream message id when a message_start event is seen
|
|
|
+ Created int64 // creation timestamp copied onto every outgoing chunk
|
|
|
+ Model string // model name copied onto every outgoing chunk; replaced by the upstream model on message_start
|
|
|
+ ResponseText strings.Builder // concatenation of all completion / delta text received so far
|
|
|
+ Usage *dto.Usage // token counts gathered from message_start and message_delta events
|
|
|
+}
|
|
|
+
|
|
|
+// FormatClaudeResponseInfo folds one Claude stream event into claudeInfo and
+// then stamps the accumulated id, creation time and model onto oaiResponse.
+//
+// requestMode selects the legacy completion protocol (RequestModeCompletion)
+// or the messages protocol. claudeResponse is the decoded upstream event,
+// oaiResponse is the OpenAI-style chunk already built for that event, and
+// claudeInfo is the per-stream accumulator that is mutated in place.
+//
+// It returns true when oaiResponse should be forwarded to the client, and
+// false when there is nothing to forward: oaiResponse (or another argument)
+// is nil, or the event type is not one this function tracks.
+func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
+	if oaiResponse == nil || claudeResponse == nil || claudeInfo == nil {
+		return false
+	}
+	// Callers may build the accumulator without a Usage; allocate lazily so
+	// the token bookkeeping below can never dereference nil.
+	if claudeInfo.Usage == nil {
+		claudeInfo.Usage = &dto.Usage{}
+	}
+	if requestMode == RequestModeCompletion {
+		claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
+	} else {
+		switch claudeResponse.Type {
+		case "message_start":
+			// message_start carries the upstream id, model and prompt usage.
+			claudeInfo.ResponseId = claudeResponse.Message.Id
+			claudeInfo.Model = claudeResponse.Message.Model
+			claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
+		case "content_block_delta":
+			// Delta is a pointer and may be absent on a malformed event.
+			if claudeResponse.Delta != nil {
+				claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Text)
+			}
+		case "message_delta":
+			claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+			// The prompt token count was recorded at message_start; summing the
+			// accumulated values keeps the total correct when this event's own
+			// usage carries output tokens only.
+			claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
+		case "content_block_start":
+			// Nothing to accumulate; the chunk is still forwarded.
+		default:
+			return false
+		}
+	}
+	oaiResponse.Id = claudeInfo.ResponseId
+	oaiResponse.Created = claudeInfo.Created
+	oaiResponse.Model = claudeInfo.Model
+	return true
+}
|
|
|
+
|
|
|
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
|
|
- var usage *dto.Usage
|
|
|
- usage = &dto.Usage{}
|
|
|
- responseText := ""
|
|
|
- createdTime := common.GetTimestamp()
|
|
|
+ claudeInfo := &ClaudeResponseInfo{
|
|
|
+ ResponseId: responseId,
|
|
|
+ Created: common.GetTimestamp(),
|
|
|
+ Model: info.UpstreamModelName,
|
|
|
+ ResponseText: strings.Builder{},
|
|
|
+ Usage: &dto.Usage{},
|
|
|
+ }
|
|
|
|
|
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
var claudeResponse ClaudeResponse
|
|
|
- err := json.Unmarshal([]byte(data), &claudeResponse)
|
|
|
+ err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
|
|
|
if err != nil {
|
|
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
- response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
|
|
- if response == nil {
|
|
|
+ response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
|
|
+
|
|
|
+ if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
|
|
|
return true
|
|
|
}
|
|
|
- if requestMode == RequestModeCompletion {
|
|
|
- responseText += claudeResponse.Completion
|
|
|
- responseId = response.Id
|
|
|
- } else {
|
|
|
- if claudeResponse.Type == "message_start" {
|
|
|
- // message_start, 获取usage
|
|
|
- responseId = claudeResponse.Message.Id
|
|
|
- info.UpstreamModelName = claudeResponse.Message.Model
|
|
|
- usage.PromptTokens = claudeUsage.InputTokens
|
|
|
- } else if claudeResponse.Type == "content_block_delta" {
|
|
|
- responseText += claudeResponse.Delta.Text
|
|
|
- } else if claudeResponse.Type == "message_delta" {
|
|
|
- usage.CompletionTokens = claudeUsage.OutputTokens
|
|
|
- usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
|
|
|
- } else if claudeResponse.Type == "content_block_start" {
|
|
|
- } else {
|
|
|
- return true
|
|
|
- }
|
|
|
- }
|
|
|
- //response.Id = responseId
|
|
|
- response.Id = responseId
|
|
|
- response.Created = createdTime
|
|
|
- response.Model = info.UpstreamModelName
|
|
|
|
|
|
err = helper.ObjectData(c, response)
|
|
|
if err != nil {
|
|
|
@@ -488,25 +502,24 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|
|
})
|
|
|
|
|
|
if requestMode == RequestModeCompletion {
|
|
|
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
|
|
+ claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
|
|
} else {
|
|
|
- if usage.PromptTokens == 0 {
|
|
|
- usage.PromptTokens = info.PromptTokens
|
|
|
+ if claudeInfo.Usage.PromptTokens == 0 {
|
|
|
+ // upstream error: no prompt token count was reported
|
|
|
}
|
|
|
- if usage.CompletionTokens == 0 {
|
|
|
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
|
|
|
+ if claudeInfo.Usage.CompletionTokens == 0 {
|
|
|
+ claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
|
|
}
|
|
|
}
|
|
|
if info.ShouldIncludeUsage {
|
|
|
- response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
|
|
|
+ response := helper.GenerateFinalUsageResponse(responseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
|
|
err := helper.ObjectData(c, response)
|
|
|
if err != nil {
|
|
|
common.SysError("send final response failed: " + err.Error())
|
|
|
}
|
|
|
}
|
|
|
helper.Done(c)
|
|
|
- //resp.Body.Close()
|
|
|
- return nil, usage
|
|
|
+ return nil, claudeInfo.Usage
|
|
|
}
|
|
|
|
|
|
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|