Quellcode durchsuchen

feat: support cohere rerank

CalciumIon vor 1 Jahr
Ursprung
Commit
8af4e28f75

+ 2 - 0
controller/relay.go

@@ -29,6 +29,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 		fallthrough
 	case relayconstant.RelayModeAudioTranscription:
 		err = relay.AudioHelper(c, relayMode)
+	case relayconstant.RelayModeRerank:
+		err = relay.RerankHelper(c, relayMode)
 	default:
 		err = relay.TextHelper(c)
 	}

+ 19 - 0
dto/rerank.go

@@ -0,0 +1,19 @@
+package dto
+
+type RerankRequest struct {
+	Documents []any  `json:"documents"`
+	Query     string `json:"query"`
+	Model     string `json:"model"`
+	TopN      int    `json:"top_n"`
+}
+
+type RerankResponseDocument struct {
+	Document       any     `json:"document"`
+	Index          int     `json:"index"`
+	RelevanceScore float64 `json:"relevance_score"`
+}
+
+type RerankResponse struct {
+	Results []RerankResponseDocument `json:"results"`
+	Usage   Usage                    `json:"usage"`
+}

+ 2 - 0
relay/channel/adapter.go

@@ -11,9 +11,11 @@ import (
 type Adaptor interface {
 	// Init IsStream bool
 	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
+	InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
 	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
 	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
 	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
+	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
 	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
 	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
 	GetModelList() []string

+ 7 - 0
relay/channel/ali/adaptor.go

@@ -15,6 +15,9 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 
 }
@@ -53,6 +56,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 9 - 0
relay/channel/aws/adaptor.go

@@ -20,6 +20,11 @@ type Adaptor struct {
 	RequestMode int
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 		a.RequestMode = RequestModeMessage
@@ -53,6 +58,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return claudeReq, err
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return nil, nil
 }

+ 9 - 0
relay/channel/baidu/adaptor.go

@@ -16,6 +16,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 
 }
@@ -108,6 +113,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 9 - 0
relay/channel/claude/adaptor.go

@@ -21,6 +21,11 @@ type Adaptor struct {
 	RequestMode int
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 		a.RequestMode = RequestModeMessage
@@ -59,6 +64,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 20 - 4
relay/channel/cohere/adaptor.go

@@ -8,16 +8,24 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 )
 
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
+	if info.RelayMode == constant.RelayModeRerank {
+		return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+	} else {
+		return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -34,11 +42,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return requestConvertRerank2Cohere(request), nil
+}
+
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
-	if info.IsStream {
-		err, usage = cohereStreamHandler(c, resp, info)
+	if info.RelayMode == constant.RelayModeRerank {
+		err, usage = cohereRerankHandler(c, resp, info)
 	} else {
-		err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
+		if info.IsStream {
+			err, usage = cohereStreamHandler(c, resp, info)
+		} else {
+			err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
+		}
 	}
 	return
 }

+ 1 - 0
relay/channel/cohere/constant.go

@@ -2,6 +2,7 @@ package cohere
 
 var ModelList = []string{
 	"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
+	"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
 }
 
 var ChannelName = "cohere"

+ 15 - 0
relay/channel/cohere/dto.go

@@ -1,5 +1,7 @@
 package cohere
 
+import "one-api/dto"
+
 type CohereRequest struct {
 	Model       string        `json:"model"`
 	ChatHistory []ChatHistory `json:"chat_history"`
@@ -28,6 +30,19 @@ type CohereResponseResult struct {
 	Meta         CohereMeta `json:"meta"`
 }
 
+type CohereRerankRequest struct {
+	Documents       []any  `json:"documents"`
+	Query           string `json:"query"`
+	Model           string `json:"model"`
+	TopN            int    `json:"top_n"`
+	ReturnDocuments bool   `json:"return_documents"`
+}
+
+type CohereRerankResponseResult struct {
+	Results []dto.RerankResponseDocument `json:"results"`
+	Meta    CohereMeta                   `json:"meta"`
+}
+
 type CohereMeta struct {
 	//Tokens CohereTokens `json:"tokens"`
 	BilledUnits CohereBilledUnits `json:"billed_units"`

+ 53 - 0
relay/channel/cohere/relay-cohere.go

@@ -47,6 +47,20 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
 	return &cohereReq
 }
 
+func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
+	cohereReq := CohereRerankRequest{
+		Query:           rerankRequest.Query,
+		Documents:       rerankRequest.Documents,
+		Model:           rerankRequest.Model,
+		TopN:            rerankRequest.TopN,
+		ReturnDocuments: true,
+	}
+	for _, doc := range rerankRequest.Documents {
+		cohereReq.Documents = append(cohereReq.Documents, doc)
+	}
+	return &cohereReq
+}
+
 func stopReasonCohere2OpenAI(reason string) string {
 	switch reason {
 	case "COMPLETE":
@@ -194,3 +208,42 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &usage
 }
+
+func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var cohereResp CohereRerankResponseResult
+	err = json.Unmarshal(responseBody, &cohereResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	usage := dto.Usage{}
+	if cohereResp.Meta.BilledUnits.InputTokens == 0 {
+		usage.PromptTokens = info.PromptTokens
+		usage.CompletionTokens = 0
+		usage.TotalTokens = info.PromptTokens
+	} else {
+		usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
+		usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
+		usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
+	}
+
+	var rerankResp dto.RerankResponse
+	rerankResp.Results = cohereResp.Results
+	rerankResp.Usage = usage
+
+	jsonResponse, err := json.Marshal(rerankResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &usage
+}

+ 9 - 0
relay/channel/dify/adaptor.go

@@ -14,6 +14,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -34,6 +39,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Dify(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 7 - 0
relay/channel/gemini/adaptor.go

@@ -15,6 +15,9 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -56,6 +59,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return CovertGemini2OpenAI(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 7 - 0
relay/channel/ollama/adaptor.go

@@ -16,6 +16,9 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -45,6 +48,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 7 - 0
relay/channel/openai/adaptor.go

@@ -22,6 +22,13 @@ type Adaptor struct {
 	ChannelType int
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	a.ChannelType = info.ChannelType
 }

+ 9 - 0
relay/channel/palm/adaptor.go

@@ -15,6 +15,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -35,6 +40,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return request, nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 9 - 0
relay/channel/perplexity/adaptor.go

@@ -16,6 +16,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -39,6 +44,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Perplexity(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 9 - 0
relay/channel/tencent/adaptor.go

@@ -22,6 +22,11 @@ type Adaptor struct {
 	Timestamp int64
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	a.Action = "ChatCompletions"
 	a.Version = "2023-09-01"
@@ -57,6 +62,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return tencentRequest, nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 9 - 0
relay/channel/xunfei/adaptor.go

@@ -16,6 +16,11 @@ type Adaptor struct {
 	request *dto.GeneralOpenAIRequest
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -36,6 +41,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return request, nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	// xunfei's request is not http request, so we don't need to do anything here
 	dummyResp := &http.Response{}

+ 9 - 0
relay/channel/zhipu/adaptor.go

@@ -14,6 +14,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -42,6 +47,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Zhipu(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 9 - 0
relay/channel/zhipu_4v/adaptor.go

@@ -16,6 +16,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -40,6 +45,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Zhipu(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 3 - 0
relay/constant/relay_mode.go

@@ -32,6 +32,7 @@ const (
 	RelayModeSunoFetch
 	RelayModeSunoFetchByID
 	RelayModeSunoSubmit
+	RelayModeRerank
 )
 
 func Path2RelayMode(path string) int {
@@ -56,6 +57,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeAudioTranscription
 	} else if strings.HasPrefix(path, "/v1/audio/translations") {
 		relayMode = RelayModeAudioTranslation
+	} else if strings.HasPrefix(path, "/v1/rerank") {
+		relayMode = RelayModeRerank
 	}
 	return relayMode
 }

+ 9 - 7
relay/relay-text.go

@@ -182,7 +182,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
 	return nil
 }
 
@@ -272,7 +272,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
 	}
 }
 
-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
 	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
 	modelPrice float64, usePrice bool) {
 
@@ -281,7 +281,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 	completionTokens := usage.CompletionTokens
 
 	tokenName := ctx.GetString("token_name")
-	completionRatio := common.GetCompletionRatio(textRequest.Model)
+	completionRatio := common.GetCompletionRatio(modelName)
 
 	quota := 0
 	if !usePrice {
@@ -307,7 +307,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 		// we cannot just return, because we may have to return the pre-consumed quota
 		quota = 0
 		logContent += fmt.Sprintf("(可能是上游超时)")
-		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
+		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
 	} else {
 		//if sensitiveResp != nil {
 		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
@@ -327,13 +328,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 
-	logModel := textRequest.Model
+	logModel := modelName
 	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
 		logModel = "gpt-4-gizmo-*"
-		logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
+		logContent += fmt.Sprintf(",模型 %s", modelName)
 	}
 	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
-	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
+	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
+		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
 
 	//if quota != 0 {
 	//

+ 104 - 0
relay/relay_rerank.go

@@ -0,0 +1,104 @@
+package relay
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
+	token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
+	for _, document := range rerankRequest.Documents {
+		tkm, err := service.CountTokenInput(document, rerankRequest.Model)
+		if err == nil {
+			token += tkm
+		}
+	}
+	return token
+}
+
+func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+	relayInfo := relaycommon.GenRelayInfo(c)
+
+	var rerankRequest *dto.RerankRequest
+	err := common.UnmarshalBodyReusable(c, &rerankRequest)
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
+	}
+	if rerankRequest.Query == "" {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
+	}
+	if len(rerankRequest.Documents) == 0 {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
+	}
+	relayInfo.UpstreamModelName = rerankRequest.Model
+	modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+
+	var preConsumedQuota int
+	var ratio float64
+	var modelRatio float64
+
+	promptToken := getRerankPromptToken(*rerankRequest)
+	if !success {
+		preConsumedTokens := promptToken
+		modelRatio = common.GetModelRatio(rerankRequest.Model)
+		ratio = modelRatio * groupRatio
+		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+	} else {
+		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+	relayInfo.PromptTokens = promptToken
+
+	// pre-consume quota 预消耗配额
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+	adaptor.InitRerank(relayInfo, *rerankRequest)
+
+	convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
+	}
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+	}
+	requestBody := bytes.NewBuffer(jsonData)
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+	if resp != nil {
+		if resp.StatusCode != http.StatusOK {
+			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+			openaiErr := service.RelayErrorHandler(resp)
+			// reset status code 重置状态码
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	if openaiErr != nil {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		// reset status code 重置状态码
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	return nil
+}

+ 1 - 0
router/relay-router.go

@@ -42,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
 		relayV1Router.POST("/moderations", controller.Relay)
+		relayV1Router.POST("/rerank", controller.Relay)
 	}
 
 	relayMjRouter := router.Group("/mj")