| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535 |
- package claude
- import (
- "bufio"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "strings"
- "github.com/gin-gonic/gin"
- )
// stopReasonClaude2OpenAI maps a Claude stop_reason onto the finish_reason
// vocabulary used by OpenAI chat completions.
//
// "stop_sequence" and "end_turn" become "stop", "max_tokens" becomes "length"
// and "tool_use" becomes "tool_calls". Any other value, including the empty
// string, is returned unchanged.
func stopReasonClaude2OpenAI(reason string) string {
	switch reason {
	case "stop_sequence", "end_turn":
		return "stop"
	case "max_tokens":
		// OpenAI reports a token-limit stop as "length".
		return "length"
	case "tool_use":
		return "tool_calls"
	default:
		return reason
	}
}
- func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
- claudeRequest := ClaudeRequest{
- Model: textRequest.Model,
- Prompt: "",
- StopSequences: nil,
- Temperature: textRequest.Temperature,
- TopP: textRequest.TopP,
- TopK: textRequest.TopK,
- Stream: textRequest.Stream,
- }
- if claudeRequest.MaxTokensToSample == 0 {
- claudeRequest.MaxTokensToSample = 4096
- }
- prompt := ""
- for _, message := range textRequest.Messages {
- if message.Role == "user" {
- prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
- } else if message.Role == "assistant" {
- prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
- } else if message.Role == "system" {
- if prompt == "" {
- prompt = message.StringContent()
- }
- }
- }
- prompt += "\n\nAssistant:"
- claudeRequest.Prompt = prompt
- return &claudeRequest
- }
- func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
- claudeTools := make([]Tool, 0, len(textRequest.Tools))
- for _, tool := range textRequest.Tools {
- if params, ok := tool.Function.Parameters.(map[string]any); ok {
- claudeTool := Tool{
- Name: tool.Function.Name,
- Description: tool.Function.Description,
- }
- claudeTool.InputSchema = make(map[string]interface{})
- claudeTool.InputSchema["type"] = params["type"].(string)
- claudeTool.InputSchema["properties"] = params["properties"]
- claudeTool.InputSchema["required"] = params["required"]
- for s, a := range params {
- if s == "type" || s == "properties" || s == "required" {
- continue
- }
- claudeTool.InputSchema[s] = a
- }
- claudeTools = append(claudeTools, claudeTool)
- }
- }
- claudeRequest := ClaudeRequest{
- Model: textRequest.Model,
- MaxTokens: textRequest.MaxTokens,
- StopSequences: nil,
- Temperature: textRequest.Temperature,
- TopP: textRequest.TopP,
- TopK: textRequest.TopK,
- Stream: textRequest.Stream,
- Tools: claudeTools,
- }
- if claudeRequest.MaxTokens == 0 {
- claudeRequest.MaxTokens = 4096
- }
- if textRequest.Stop != nil {
- // stop maybe string/array string, convert to array string
- switch textRequest.Stop.(type) {
- case string:
- claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
- case []interface{}:
- stopSequences := make([]string, 0)
- for _, stop := range textRequest.Stop.([]interface{}) {
- stopSequences = append(stopSequences, stop.(string))
- }
- claudeRequest.StopSequences = stopSequences
- }
- }
- formatMessages := make([]dto.Message, 0)
- lastMessage := dto.Message{
- Role: "tool",
- }
- for i, message := range textRequest.Messages {
- if message.Role == "" {
- textRequest.Messages[i].Role = "user"
- }
- fmtMessage := dto.Message{
- Role: message.Role,
- Content: message.Content,
- }
- if message.Role == "tool" {
- fmtMessage.ToolCallId = message.ToolCallId
- }
- if message.Role == "assistant" && message.ToolCalls != nil {
- fmtMessage.ToolCalls = message.ToolCalls
- }
- if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
- if lastMessage.IsStringContent() && message.IsStringContent() {
- content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
- fmtMessage.Content = content
- // delete last message
- formatMessages = formatMessages[:len(formatMessages)-1]
- }
- }
- if fmtMessage.Content == nil {
- content, _ := json.Marshal("...")
- fmtMessage.Content = content
- }
- formatMessages = append(formatMessages, fmtMessage)
- lastMessage = fmtMessage
- }
- claudeMessages := make([]ClaudeMessage, 0)
- isFirstMessage := true
- for _, message := range formatMessages {
- if message.Role == "system" {
- if message.IsStringContent() {
- claudeRequest.System = message.StringContent()
- } else {
- contents := message.ParseContent()
- content := ""
- for _, ctx := range contents {
- if ctx.Type == "text" {
- content += ctx.Text
- }
- }
- claudeRequest.System = content
- }
- } else {
- if isFirstMessage {
- isFirstMessage = false
- if message.Role != "user" {
- // fix: first message is assistant, add user message
- claudeMessage := ClaudeMessage{
- Role: "user",
- Content: []ClaudeMediaMessage{
- {
- Type: "text",
- Text: "...",
- },
- },
- }
- claudeMessages = append(claudeMessages, claudeMessage)
- }
- }
- claudeMessage := ClaudeMessage{
- Role: message.Role,
- }
- if message.Role == "tool" {
- if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
- lastMessage := claudeMessages[len(claudeMessages)-1]
- if content, ok := lastMessage.Content.(string); ok {
- lastMessage.Content = []ClaudeMediaMessage{
- {
- Type: "text",
- Text: content,
- },
- }
- }
- lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
- Type: "tool_result",
- ToolUseId: message.ToolCallId,
- Content: message.StringContent(),
- })
- claudeMessages[len(claudeMessages)-1] = lastMessage
- continue
- } else {
- claudeMessage.Role = "user"
- claudeMessage.Content = []ClaudeMediaMessage{
- {
- Type: "tool_result",
- ToolUseId: message.ToolCallId,
- Content: message.StringContent(),
- },
- }
- }
- } else if message.IsStringContent() && message.ToolCalls == nil {
- claudeMessage.Content = message.StringContent()
- } else {
- claudeMediaMessages := make([]ClaudeMediaMessage, 0)
- for _, mediaMessage := range message.ParseContent() {
- claudeMediaMessage := ClaudeMediaMessage{
- Type: mediaMessage.Type,
- }
- if mediaMessage.Type == "text" {
- claudeMediaMessage.Text = mediaMessage.Text
- } else {
- imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
- claudeMediaMessage.Type = "image"
- claudeMediaMessage.Source = &ClaudeMessageSource{
- Type: "base64",
- }
- // 判断是否是url
- if strings.HasPrefix(imageUrl.Url, "http") {
- // 是url,获取图片的类型和base64编码的数据
- mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
- claudeMediaMessage.Source.MediaType = mimeType
- claudeMediaMessage.Source.Data = data
- } else {
- _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
- if err != nil {
- return nil, err
- }
- claudeMediaMessage.Source.MediaType = "image/" + format
- claudeMediaMessage.Source.Data = base64String
- }
- }
- claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
- }
- if message.ToolCalls != nil {
- for _, tc := range message.ToolCalls.([]interface{}) {
- toolCallJSON, _ := json.Marshal(tc)
- var toolCall dto.ToolCall
- err := json.Unmarshal(toolCallJSON, &toolCall)
- if err != nil {
- common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
- continue
- }
- inputObj := make(map[string]any)
- if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
- common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
- continue
- }
- claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
- Type: "tool_use",
- Id: toolCall.ID,
- Name: toolCall.Function.Name,
- Input: inputObj,
- })
- }
- }
- claudeMessage.Content = claudeMediaMessages
- }
- claudeMessages = append(claudeMessages, claudeMessage)
- }
- }
- claudeRequest.Prompt = ""
- claudeRequest.Messages = claudeMessages
- return &claudeRequest, nil
- }
- func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
- var response dto.ChatCompletionsStreamResponse
- var claudeUsage *ClaudeUsage
- response.Object = "chat.completion.chunk"
- response.Model = claudeResponse.Model
- response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
- tools := make([]dto.ToolCall, 0)
- var choice dto.ChatCompletionsStreamResponseChoice
- if reqMode == RequestModeCompletion {
- choice.Delta.SetContentString(claudeResponse.Completion)
- finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
- if finishReason != "null" {
- choice.FinishReason = &finishReason
- }
- } else {
- if claudeResponse.Type == "message_start" {
- response.Id = claudeResponse.Message.Id
- response.Model = claudeResponse.Message.Model
- claudeUsage = &claudeResponse.Message.Usage
- choice.Delta.SetContentString("")
- choice.Delta.Role = "assistant"
- } else if claudeResponse.Type == "content_block_start" {
- if claudeResponse.ContentBlock != nil {
- //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
- if claudeResponse.ContentBlock.Type == "tool_use" {
- tools = append(tools, dto.ToolCall{
- ID: claudeResponse.ContentBlock.Id,
- Type: "function",
- Function: dto.FunctionCall{
- Name: claudeResponse.ContentBlock.Name,
- Arguments: "",
- },
- })
- }
- } else {
- return nil, nil
- }
- } else if claudeResponse.Type == "content_block_delta" {
- if claudeResponse.Delta != nil {
- choice.Index = claudeResponse.Index
- choice.Delta.SetContentString(claudeResponse.Delta.Text)
- if claudeResponse.Delta.Type == "input_json_delta" {
- tools = append(tools, dto.ToolCall{
- Function: dto.FunctionCall{
- Arguments: claudeResponse.Delta.PartialJson,
- },
- })
- }
- }
- } else if claudeResponse.Type == "message_delta" {
- finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
- if finishReason != "null" {
- choice.FinishReason = &finishReason
- }
- claudeUsage = &claudeResponse.Usage
- } else if claudeResponse.Type == "message_stop" {
- return nil, nil
- } else {
- return nil, nil
- }
- }
- if claudeUsage == nil {
- claudeUsage = &ClaudeUsage{}
- }
- if len(tools) > 0 {
- choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
- choice.Delta.ToolCalls = tools
- }
- response.Choices = append(response.Choices, choice)
- return &response, claudeUsage
- }
- func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
- choices := make([]dto.OpenAITextResponseChoice, 0)
- fullTextResponse := dto.OpenAITextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- }
- var responseText string
- if len(claudeResponse.Content) > 0 {
- responseText = claudeResponse.Content[0].Text
- }
- tools := make([]dto.ToolCall, 0)
- if reqMode == RequestModeCompletion {
- content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
- choice := dto.OpenAITextResponseChoice{
- Index: 0,
- Message: dto.Message{
- Role: "assistant",
- Content: content,
- Name: nil,
- },
- FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
- }
- choices = append(choices, choice)
- } else {
- fullTextResponse.Id = claudeResponse.Id
- for _, message := range claudeResponse.Content {
- if message.Type == "tool_use" {
- args, _ := json.Marshal(message.Input)
- tools = append(tools, dto.ToolCall{
- ID: message.Id,
- Type: "function", // compatible with other OpenAI derivative applications
- Function: dto.FunctionCall{
- Name: message.Name,
- Arguments: string(args),
- },
- })
- }
- }
- }
- choice := dto.OpenAITextResponseChoice{
- Index: 0,
- Message: dto.Message{
- Role: "assistant",
- },
- FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
- }
- choice.SetStringContent(responseText)
- if len(tools) > 0 {
- choice.Message.ToolCalls = tools
- }
- fullTextResponse.Model = claudeResponse.Model
- choices = append(choices, choice)
- fullTextResponse.Choices = choices
- return &fullTextResponse
- }
- func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
- var usage *dto.Usage
- usage = &dto.Usage{}
- responseText := ""
- createdTime := common.GetTimestamp()
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(bufio.ScanLines)
- service.SetEventStreamHeaders(c)
- for scanner.Scan() {
- data := scanner.Text()
- info.SetFirstResponseTime()
- if len(data) < 6 || !strings.HasPrefix(data, "data:") {
- continue
- }
- data = strings.TrimPrefix(data, "data:")
- data = strings.TrimSpace(data)
- var claudeResponse ClaudeResponse
- err := json.Unmarshal([]byte(data), &claudeResponse)
- if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
- continue
- }
- response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
- if response == nil {
- continue
- }
- if requestMode == RequestModeCompletion {
- responseText += claudeResponse.Completion
- responseId = response.Id
- } else {
- if claudeResponse.Type == "message_start" {
- // message_start, 获取usage
- responseId = claudeResponse.Message.Id
- info.UpstreamModelName = claudeResponse.Message.Model
- usage.PromptTokens = claudeUsage.InputTokens
- } else if claudeResponse.Type == "content_block_delta" {
- responseText += claudeResponse.Delta.Text
- } else if claudeResponse.Type == "message_delta" {
- usage.CompletionTokens = claudeUsage.OutputTokens
- usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
- } else if claudeResponse.Type == "content_block_start" {
- } else {
- continue
- }
- }
- //response.Id = responseId
- response.Id = responseId
- response.Created = createdTime
- response.Model = info.UpstreamModelName
- err = service.ObjectData(c, response)
- if err != nil {
- common.LogError(c, "send_stream_response_failed: "+err.Error())
- }
- }
- if requestMode == RequestModeCompletion {
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
- } else {
- if usage.PromptTokens == 0 {
- usage.PromptTokens = info.PromptTokens
- }
- if usage.CompletionTokens == 0 {
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
- }
- }
- if info.ShouldIncludeUsage {
- response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
- err := service.ObjectData(c, response)
- if err != nil {
- common.SysError("send final response failed: " + err.Error())
- }
- }
- service.Done(c)
- resp.Body.Close()
- return nil, usage
- }
- func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- var claudeResponse ClaudeResponse
- err = json.Unmarshal(responseBody, &claudeResponse)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
- }
- if claudeResponse.Error.Type != "" {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: claudeResponse.Error.Message,
- Type: claudeResponse.Error.Type,
- Param: "",
- Code: claudeResponse.Error.Type,
- },
- StatusCode: resp.StatusCode,
- }, nil
- }
- fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
- completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
- }
- usage := dto.Usage{}
- if requestMode == RequestModeCompletion {
- usage.PromptTokens = info.PromptTokens
- usage.CompletionTokens = completionTokens
- usage.TotalTokens = info.PromptTokens + completionTokens
- } else {
- usage.PromptTokens = claudeResponse.Usage.InputTokens
- usage.CompletionTokens = claudeResponse.Usage.OutputTokens
- usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
- }
- fullTextResponse.Usage = usage
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &usage
- }
|