Преглед изворни кода

Merge remote-tracking branch 'origin/alpha' into refactor/model-pricing

t0ng7u пре 7 месеци
родитељ
комит
26f44b8d4b

+ 2 - 2
controller/channel-test.go

@@ -161,7 +161,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	logInfo.ApiKey = ""
 	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
 
-	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
+	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens()))
 	if err != nil {
 		return testResult{
 			context:     c,
@@ -275,7 +275,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 		Quota:            quota,
 		Content:          "模型测试",
 		UseTimeSeconds:   int(consumedTime),
-		IsStream:         false,
+		IsStream:         info.IsStream,
 		Group:            info.UsingGroup,
 		Other:            other,
 	})

+ 10 - 50
controller/channel.go

@@ -36,30 +36,11 @@ type OpenAIModel struct {
 	Parent string `json:"parent"`
 }
 
-type GoogleOpenAICompatibleModels []struct {
-	Name                       string   `json:"name"`
-	Version                    string   `json:"version"`
-	DisplayName                string   `json:"displayName"`
-	Description                string   `json:"description,omitempty"`
-	InputTokenLimit            int      `json:"inputTokenLimit"`
-	OutputTokenLimit           int      `json:"outputTokenLimit"`
-	SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
-	Temperature                float64  `json:"temperature,omitempty"`
-	TopP                       float64  `json:"topP,omitempty"`
-	TopK                       int      `json:"topK,omitempty"`
-	MaxTemperature             int      `json:"maxTemperature,omitempty"`
-}
-
 type OpenAIModelsResponse struct {
 	Data    []OpenAIModel `json:"data"`
 	Success bool          `json:"success"`
 }
 
-type GoogleOpenAICompatibleResponse struct {
-	Models        []GoogleOpenAICompatibleModels `json:"models"`
-	NextPageToken string                         `json:"nextPageToken"`
-}
-
 func parseStatusFilter(statusParam string) int {
 	switch strings.ToLower(statusParam) {
 	case "enabled", "1":
@@ -203,7 +184,7 @@ func FetchUpstreamModels(c *gin.Context) {
 	switch channel.Type {
 	case constant.ChannelTypeGemini:
 		// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
-		url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key)
+		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
 	case constant.ChannelTypeAli:
 		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
 	default:
@@ -212,10 +193,11 @@ func FetchUpstreamModels(c *gin.Context) {
 
 	// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
 	var body []byte
+	key := strings.Split(channel.Key, "\n")[0]
 	if channel.Type == constant.ChannelTypeGemini {
-		body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader
+		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it
 	} else {
-		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
 	}
 	if err != nil {
 		common.ApiError(c, err)
@@ -223,34 +205,12 @@ func FetchUpstreamModels(c *gin.Context) {
 	}
 
 	var result OpenAIModelsResponse
-	var parseSuccess bool
-
-	// 适配特殊格式
-	switch channel.Type {
-	case constant.ChannelTypeGemini:
-		var googleResult GoogleOpenAICompatibleResponse
-		if err = json.Unmarshal(body, &googleResult); err == nil {
-			// 转换Google格式到OpenAI格式
-			for _, model := range googleResult.Models {
-				for _, gModel := range model {
-					result.Data = append(result.Data, OpenAIModel{
-						ID: gModel.Name,
-					})
-				}
-			}
-			parseSuccess = true
-		}
-	}
-
-	// 如果解析失败,尝试OpenAI格式
-	if !parseSuccess {
-		if err = json.Unmarshal(body, &result); err != nil {
-			c.JSON(http.StatusOK, gin.H{
-				"success": false,
-				"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
-			})
-			return
-		}
+	if err = json.Unmarshal(body, &result); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
+		})
+		return
 	}
 
 	var ids []string

+ 1 - 1
dto/claude.go

@@ -361,7 +361,7 @@ type ClaudeUsage struct {
 	CacheCreationInputTokens int                  `json:"cache_creation_input_tokens"`
 	CacheReadInputTokens     int                  `json:"cache_read_input_tokens"`
 	OutputTokens             int                  `json:"output_tokens"`
-	ServerToolUse            *ClaudeServerToolUse `json:"server_tool_use"`
+	ServerToolUse            *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
 }
 
 type ClaudeServerToolUse struct {

+ 5 - 2
dto/openai_request.go

@@ -99,8 +99,11 @@ type StreamOptions struct {
 	IncludeUsage bool `json:"include_usage,omitempty"`
 }
 
-func (r *GeneralOpenAIRequest) GetMaxTokens() int {
-	return int(r.MaxTokens)
+func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
+	if r.MaxCompletionTokens != 0 {
+		return r.MaxCompletionTokens
+	}
+	return r.MaxTokens
 }
 
 func (r *GeneralOpenAIRequest) ParseInput() []string {

+ 47 - 27
relay/channel/ali/adaptor.go

@@ -3,16 +3,17 @@ package ali
 import (
 	"errors"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
+	"one-api/relay/channel/claude"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 	"one-api/types"
-
-	"github.com/gin-gonic/gin"
+	"strings"
 )
 
 type Adaptor struct {
@@ -23,10 +24,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
 	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
-	//TODO implement me
-	panic("implement me")
-	return nil, nil
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+	return req, nil
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -34,18 +33,24 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	var fullRequestURL string
-	switch info.RelayMode {
-	case constant.RelayModeEmbeddings:
-		fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
-	case constant.RelayModeRerank:
-		fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
-	case constant.RelayModeImagesGenerations:
-		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
-	case constant.RelayModeCompletions:
-		fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
+	switch info.RelayFormat {
+	case relaycommon.RelayFormatClaude:
+		fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl)
 	default:
-		fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
+		switch info.RelayMode {
+		case constant.RelayModeEmbeddings:
+			fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
+		case constant.RelayModeRerank:
+			fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
+		case constant.RelayModeImagesGenerations:
+			fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
+		case constant.RelayModeCompletions:
+			fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
+		default:
+			fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
+		}
 	}
+
 	return fullRequestURL, nil
 }
 
@@ -65,7 +70,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-
+	// docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216
+	// fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True.
+	if strings.Contains(request.Model, "thinking") {
+		request.EnableThinking = true
+		request.Stream = true
+		info.IsStream = true
+	}
 	// fix: ali parameter.enable_thinking must be set to false for non-streaming calls
 	if !info.IsStream {
 		request.EnableThinking = false
@@ -106,18 +117,27 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
-	switch info.RelayMode {
-	case constant.RelayModeImagesGenerations:
-		err, usage = aliImageHandler(c, resp, info)
-	case constant.RelayModeEmbeddings:
-		err, usage = aliEmbeddingHandler(c, resp)
-	case constant.RelayModeRerank:
-		err, usage = RerankHandler(c, resp, info)
-	default:
+	switch info.RelayFormat {
+	case relaycommon.RelayFormatClaude:
 		if info.IsStream {
-			usage, err = openai.OaiStreamHandler(c, info, resp)
+			err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
 		} else {
-			usage, err = openai.OpenaiHandler(c, info, resp)
+			err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
+		}
+	default:
+		switch info.RelayMode {
+		case constant.RelayModeImagesGenerations:
+			err, usage = aliImageHandler(c, resp, info)
+		case constant.RelayModeEmbeddings:
+			err, usage = aliEmbeddingHandler(c, resp)
+		case constant.RelayModeRerank:
+			err, usage = RerankHandler(c, resp, info)
+		default:
+			if info.IsStream {
+				usage, err = openai.OaiStreamHandler(c, info, resp)
+			} else {
+				usage, err = openai.OpenaiHandler(c, info, resp)
+			}
 		}
 	}
 	return

+ 3 - 3
relay/channel/baidu/relay-baidu.go

@@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
 		EnableCitation: false,
 		UserId:         request.User,
 	}
-	if request.MaxTokens != 0 {
-		maxTokens := int(request.MaxTokens)
-		if request.MaxTokens == 1 {
+	if request.GetMaxTokens() != 0 {
+		maxTokens := int(request.GetMaxTokens())
+		if request.GetMaxTokens() == 1 {
 			maxTokens = 2
 		}
 		baiduRequest.MaxOutputTokens = &maxTokens

+ 20 - 10
relay/channel/baidu_v2/adaptor.go

@@ -9,6 +9,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/types"
 	"strings"
 
@@ -23,10 +24,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
 	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
-	//TODO implement me
-	panic("implement me")
-	return nil, nil
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+	adaptor := openai.Adaptor{}
+	return adaptor.ConvertClaudeRequest(c, info, req)
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -43,7 +43,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
+	switch info.RelayMode {
+	case constant.RelayModeChatCompletions:
+		return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
+	case constant.RelayModeEmbeddings:
+		return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil
+	case constant.RelayModeImagesGenerations:
+		return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil
+	case constant.RelayModeImagesEdits:
+		return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil
+	case constant.RelayModeRerank:
+		return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil
+	default:
+	}
+	return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -99,11 +112,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
-	if info.IsStream {
-		usage, err = openai.OaiStreamHandler(c, info, resp)
-	} else {
-		usage, err = openai.OpenaiHandler(c, info, resp)
-	}
+	adaptor := openai.Adaptor{}
+	usage, err = adaptor.DoResponse(c, resp, info)
 	return
 }
 

+ 1 - 1
relay/channel/claude/adaptor.go

@@ -104,7 +104,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
 	} else {
-		err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
+		err, usage = ClaudeHandler(c, resp, info, a.RequestMode)
 	}
 	return
 }

+ 2 - 2
relay/channel/claude/relay-claude.go

@@ -149,7 +149,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 
 	claudeRequest := dto.ClaudeRequest{
 		Model:         textRequest.Model,
-		MaxTokens:     textRequest.MaxTokens,
+		MaxTokens:     textRequest.GetMaxTokens(),
 		StopSequences: nil,
 		Temperature:   textRequest.Temperature,
 		TopP:          textRequest.TopP,
@@ -740,7 +740,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 	return nil
 }
 
-func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
+func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
 	defer common.CloseResponseBodyGracefully(resp)
 
 	claudeInfo := &ClaudeResponseInfo{

+ 1 - 1
relay/channel/cloudflare/dto.go

@@ -5,7 +5,7 @@ import "one-api/dto"
 type CfRequest struct {
 	Messages    []dto.Message `json:"messages,omitempty"`
 	Lora        string        `json:"lora,omitempty"`
-	MaxTokens   int           `json:"max_tokens,omitempty"`
+	MaxTokens   uint          `json:"max_tokens,omitempty"`
 	Prompt      string        `json:"prompt,omitempty"`
 	Raw         bool          `json:"raw,omitempty"`
 	Stream      bool          `json:"stream,omitempty"`

+ 1 - 1
relay/channel/cohere/dto.go

@@ -7,7 +7,7 @@ type CohereRequest struct {
 	ChatHistory []ChatHistory `json:"chat_history"`
 	Message     string        `json:"message"`
 	Stream      bool          `json:"stream"`
-	MaxTokens   int           `json:"max_tokens"`
+	MaxTokens   uint          `json:"max_tokens"`
 	SafetyMode  string        `json:"safety_mode,omitempty"`
 }
 

+ 3 - 4
relay/channel/deepseek/adaptor.go

@@ -24,10 +24,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
 	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
-	//TODO implement me
-	panic("implement me")
-	return nil, nil
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+	adaptor := openai.Adaptor{}
+	return adaptor.ConvertClaudeRequest(c, info, req)
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {

+ 3 - 1
relay/channel/gemini/adaptor.go

@@ -120,6 +120,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	action := "generateContent"
 	if info.IsStream {
 		action = "streamGenerateContent?alt=sse"
+		if info.RelayMode == constant.RelayModeGemini {
+			info.DisablePing = true
+		}
 	}
 	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 }
@@ -193,7 +196,6 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	if info.RelayMode == constant.RelayModeGemini {
 		if info.IsStream {
-			info.DisablePing = true
 			return GeminiTextGenerationStreamHandler(c, info, resp)
 		} else {
 			return GeminiTextGenerationHandler(c, info, resp)

+ 1 - 1
relay/channel/gemini/relay-gemini.go

@@ -184,7 +184,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 		GenerationConfig: dto.GeminiChatGenerationConfig{
 			Temperature:     textRequest.Temperature,
 			TopP:            textRequest.TopP,
-			MaxOutputTokens: textRequest.MaxTokens,
+			MaxOutputTokens: textRequest.GetMaxTokens(),
 			Seed:            int64(textRequest.Seed),
 		},
 	}

+ 1 - 1
relay/channel/mistral/text.go

@@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
 		Messages:    messages,
 		Temperature: request.Temperature,
 		TopP:        request.TopP,
-		MaxTokens:   request.MaxTokens,
+		MaxTokens:   request.GetMaxTokens(),
 		Tools:       request.Tools,
 		ToolChoice:  request.ToolChoice,
 	}

+ 1 - 1
relay/channel/ollama/relay-ollama.go

@@ -60,7 +60,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
 		TopK:             request.TopK,
 		Stop:             Stop,
 		Tools:            request.Tools,
-		MaxTokens:        request.MaxTokens,
+		MaxTokens:        request.GetMaxTokens(),
 		ResponseFormat:   request.ResponseFormat,
 		FrequencyPenalty: request.FrequencyPenalty,
 		PresencePenalty:  request.PresencePenalty,

+ 0 - 24
relay/channel/palm/relay-palm.go

@@ -18,30 +18,6 @@ import (
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
 
-func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
-	palmRequest := PaLMChatRequest{
-		Prompt: PaLMPrompt{
-			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
-		},
-		Temperature:    textRequest.Temperature,
-		CandidateCount: textRequest.N,
-		TopP:           textRequest.TopP,
-		TopK:           textRequest.MaxTokens,
-	}
-	for _, message := range textRequest.Messages {
-		palmMessage := PaLMChatMessage{
-			Content: message.StringContent(),
-		}
-		if message.Role == "user" {
-			palmMessage.Author = "0"
-		} else {
-			palmMessage.Author = "1"
-		}
-		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
-	}
-	return &palmRequest
-}
-
 func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),

+ 1 - 1
relay/channel/perplexity/relay-perplexity.go

@@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen
 		Messages:    messages,
 		Temperature: request.Temperature,
 		TopP:        request.TopP,
-		MaxTokens:   request.MaxTokens,
+		MaxTokens:   request.GetMaxTokens(),
 	}
 }

+ 1 - 1
relay/channel/vertex/adaptor.go

@@ -238,7 +238,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	} else {
 		switch a.RequestMode {
 		case RequestModeClaude:
-			err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
+			err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
 		case RequestModeGemini:
 			if info.RelayMode == constant.RelayModeGemini {
 				usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)

+ 9 - 16
relay/channel/volcengine/adaptor.go

@@ -28,10 +28,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
 	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
-	//TODO implement me
-	panic("implement me")
-	return nil, nil
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+	adaptor := openai.Adaptor{}
+	return adaptor.ConvertClaudeRequest(c, info, req)
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -196,6 +195,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
 	case constant.RelayModeImagesGenerations:
 		return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
+	case constant.RelayModeImagesEdits:
+		return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil
+	case constant.RelayModeRerank:
+		return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil
 	default:
 	}
 	return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
@@ -232,18 +235,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
-	switch info.RelayMode {
-	case constant.RelayModeChatCompletions:
-		if info.IsStream {
-			usage, err = openai.OaiStreamHandler(c, info, resp)
-		} else {
-			usage, err = openai.OpenaiHandler(c, info, resp)
-		}
-	case constant.RelayModeEmbeddings:
-		usage, err = openai.OpenaiHandler(c, info, resp)
-	case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
-		usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
-	}
+	adaptor := openai.Adaptor{}
+	usage, err = adaptor.DoResponse(c, resp, info)
 	return
 }
 

+ 1 - 1
relay/channel/xunfei/relay-xunfei.go

@@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
 	xunfeiRequest.Parameter.Chat.Domain = domain
 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
 	xunfeiRequest.Parameter.Chat.TopK = request.N
-	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
+	xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
 	xunfeiRequest.Payload.Message.Text = messages
 	return &xunfeiRequest
 }

+ 1 - 1
relay/channel/zhipu_4v/relay-zhipu_v4.go

@@ -105,7 +105,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
 		Messages:    messages,
 		Temperature: request.Temperature,
 		TopP:        request.TopP,
-		MaxTokens:   request.MaxTokens,
+		MaxTokens:   request.GetMaxTokens(),
 		Stop:        Stop,
 		Tools:       request.Tools,
 		ToolChoice:  request.ToolChoice,

+ 3 - 0
relay/common/relay_info.go

@@ -225,6 +225,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
 	tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
 	startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
+	if startTime.IsZero() {
+		startTime = time.Now()
+	}
 	// firstResponseTime = time.Now() - 1 second
 
 	apiType, _ := common.ChannelType2APIType(channelType)

+ 5 - 1
service/convert.go

@@ -283,7 +283,9 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 		if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
 			// should be done
 			info.FinishReason = *chosenChoice.FinishReason
-			return claudeResponses
+			if !info.Done {
+				return claudeResponses
+			}
 		}
 		if info.Done {
 			claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
@@ -432,6 +434,8 @@ func stopReasonOpenAI2Claude(reason string) string {
 		return "end_turn"
 	case "stop_sequence":
 		return "stop_sequence"
+	case "length":
+		fallthrough
 	case "max_tokens":
 		return "max_tokens"
 	case "tool_calls":

+ 3 - 0
service/error.go

@@ -93,6 +93,9 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
 		if showBodyWhenFail {
 			newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
 		} else {
+			if common.DebugEnabled {
+				println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
+			}
 			newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
 		}
 		return