ソースを参照

refactor: refactor claude related code

JustSong 2 年 前
コミット
675847bf98
2 ファイル変更71 行追加50 行削除
  1. 64 0
      controller/relay-claude.go
  2. 7 50
      controller/relay-text.go

+ 64 - 0
controller/relay-claude.go

@@ -1,5 +1,11 @@
 package controller
 package controller
 
 
+import (
+	"fmt"
+	"one-api/common"
+	"strings"
+)
+
 type ClaudeMetadata struct {
 type ClaudeMetadata struct {
 	UserId string `json:"user_id"`
 	UserId string `json:"user_id"`
 }
 }
@@ -38,3 +44,61 @@ func stopReasonClaude2OpenAI(reason string) string {
 		return reason
 		return reason
 	}
 	}
 }
 }
+
+// requestOpenAI2Claude converts an OpenAI-style chat completion request into
+// the prompt-based request body expected by the Claude text completion API.
+// The returned request reuses the caller's model / sampling parameters as-is.
+func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
+	claudeRequest := ClaudeRequest{
+		Model:             textRequest.Model,
+		Prompt:            "",
+		MaxTokensToSample: textRequest.MaxTokens,
+		StopSequences:     nil,
+		Temperature:       textRequest.Temperature,
+		TopP:              textRequest.TopP,
+		Stream:            textRequest.Stream,
+	}
+	if claudeRequest.MaxTokensToSample == 0 {
+		// max_tokens_to_sample is required by Claude; fall back to a large
+		// default when the OpenAI request left max_tokens unset.
+		claudeRequest.MaxTokensToSample = 1000000
+	}
+	prompt := ""
+	for _, message := range textRequest.Messages {
+		if message.Role == "user" {
+			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
+		} else if message.Role == "assistant" {
+			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
+		}
+		// other roles (e.g. "system") are ignored
+	}
+	// BUG FIX: the trailing "\n\nAssistant:" turn marker must be appended
+	// exactly once, after the whole conversation — not after every message.
+	// Appending it per-message (as before) yields a malformed Claude prompt.
+	prompt += "\n\nAssistant:"
+	claudeRequest.Prompt = prompt
+	return &claudeRequest
+}
+
+func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = claudeResponse.Completion
+	choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason)
+	var response ChatCompletionsStreamResponse
+	response.Object = "chat.completion.chunk"
+	response.Model = claudeResponse.Model
+	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	return &response
+}
+
+// responseClaude2OpenAI converts a complete (non-streaming) Claude response
+// into an OpenAI-compatible "chat.completion" response. Usage is not set
+// here; the caller populates it once token counts are known.
+func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
+	// Claude completions typically start with a leading space; strip it so
+	// the OpenAI-style message content is clean.
+	content := strings.TrimPrefix(claudeResponse.Completion, " ")
+	message := Message{
+		Role:    "assistant",
+		Content: content,
+		Name:    nil,
+	}
+	return &OpenAITextResponse{
+		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Choices: []OpenAITextResponseChoice{{
+			Index:        0,
+			Message:      message,
+			FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+		}},
+	}
+}

+ 7 - 50
controller/relay-text.go

@@ -159,30 +159,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	}
 	}
 	switch apiType {
 	switch apiType {
 	case APITypeClaude:
 	case APITypeClaude:
-		claudeRequest := ClaudeRequest{
-			Model:             textRequest.Model,
-			Prompt:            "",
-			MaxTokensToSample: textRequest.MaxTokens,
-			StopSequences:     nil,
-			Temperature:       textRequest.Temperature,
-			TopP:              textRequest.TopP,
-			Stream:            textRequest.Stream,
-		}
-		if claudeRequest.MaxTokensToSample == 0 {
-			claudeRequest.MaxTokensToSample = 1000000
-		}
-		prompt := ""
-		for _, message := range textRequest.Messages {
-			if message.Role == "user" {
-				prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
-			} else if message.Role == "assistant" {
-				prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
-			} else {
-				// ignore other roles
-			}
-			prompt += "\n\nAssistant:"
-		}
-		claudeRequest.Prompt = prompt
+		claudeRequest := requestOpenAI2Claude(textRequest)
 		jsonStr, err := json.Marshal(claudeRequest)
 		jsonStr, err := json.Marshal(claudeRequest)
 		if err != nil {
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
@@ -441,15 +418,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 						return true
 						return true
 					}
 					}
 					streamResponseText += claudeResponse.Completion
 					streamResponseText += claudeResponse.Completion
-					var choice ChatCompletionsStreamResponseChoice
-					choice.Delta.Content = claudeResponse.Completion
-					choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason)
-					var response ChatCompletionsStreamResponse
+					response := streamResponseClaude2OpenAI(&claudeResponse)
 					response.Id = responseId
 					response.Id = responseId
 					response.Created = createdTime
 					response.Created = createdTime
-					response.Object = "chat.completion.chunk"
-					response.Model = textRequest.Model
-					response.Choices = []ChatCompletionsStreamResponseChoice{choice}
 					jsonStr, err := json.Marshal(response)
 					jsonStr, err := json.Marshal(response)
 					if err != nil {
 					if err != nil {
 						common.SysError("error marshalling stream response: " + err.Error())
 						common.SysError("error marshalling stream response: " + err.Error())
@@ -492,26 +463,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 					StatusCode: resp.StatusCode,
 					StatusCode: resp.StatusCode,
 				}
 				}
 			}
 			}
-			choice := OpenAITextResponseChoice{
-				Index: 0,
-				Message: Message{
-					Role:    "assistant",
-					Content: strings.TrimPrefix(claudeResponse.Completion, " "),
-					Name:    nil,
-				},
-				FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
-			}
+			fullTextResponse := responseClaude2OpenAI(&claudeResponse)
 			completionTokens := countTokenText(claudeResponse.Completion, textRequest.Model)
 			completionTokens := countTokenText(claudeResponse.Completion, textRequest.Model)
-			fullTextResponse := OpenAITextResponse{
-				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
-				Object:  "chat.completion",
-				Created: common.GetTimestamp(),
-				Choices: []OpenAITextResponseChoice{choice},
-				Usage: Usage{
-					PromptTokens:     promptTokens,
-					CompletionTokens: completionTokens,
-					TotalTokens:      promptTokens + promptTokens,
-				},
+			fullTextResponse.Usage = Usage{
+				PromptTokens:     promptTokens,
+				CompletionTokens: completionTokens,
+				TotalTokens:      promptTokens + completionTokens,
 			}
 			}
 			textResponse.Usage = fullTextResponse.Usage
 			textResponse.Usage = fullTextResponse.Usage
 			jsonResponse, err := json.Marshal(fullTextResponse)
 			jsonResponse, err := json.Marshal(fullTextResponse)