Просмотр исходного кода

feat: support xinference rerank to jina format

1808837298@qq.com 11 месяцев назад
Родитель
Сommit
d1c62a583d

+ 7 - 3
dto/rerank.go

@@ -10,13 +10,17 @@ type RerankRequest struct {
 	OverLapTokens   int    `json:"overlap_tokens,omitempty"`
 }
 
-type RerankResponseDocument struct {
+type RerankResponseResult struct {
 	Document       any     `json:"document,omitempty"`
 	Index          int     `json:"index"`
 	RelevanceScore float64 `json:"relevance_score"`
 }
 
+type RerankDocument struct {
+	Text any `json:"text"`
+}
+
 type RerankResponse struct {
-	Results []RerankResponseDocument `json:"results"`
-	Usage   Usage                    `json:"usage"`
+	Results []RerankResponseResult `json:"results"`
+	Usage   Usage                  `json:"usage"`
 }

+ 2 - 2
relay/channel/cohere/dto.go

@@ -40,8 +40,8 @@ type CohereRerankRequest struct {
 }
 
 type CohereRerankResponseResult struct {
-	Results []dto.RerankResponseDocument `json:"results"`
-	Meta    CohereMeta                   `json:"meta"`
+	Results []dto.RerankResponseResult `json:"results"`
+	Meta    CohereMeta                 `json:"meta"`
 }
 
 type CohereMeta struct {

+ 1 - 1
relay/channel/jina/adaptor.go

@@ -69,7 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeRerank {
-		err, usage = common_handler.RerankHandler(c, resp)
+		err, usage = common_handler.RerankHandler(c, info, resp)
 	} else if info.RelayMode == constant.RelayModeEmbeddings {
 		err, usage = openai.OpenaiHandler(c, resp, info)
 	}

+ 1 - 1
relay/channel/openai/adaptor.go

@@ -262,7 +262,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	case constant.RelayModeImagesGenerations:
 		err, usage = OpenaiTTSHandler(c, resp, info)
 	case constant.RelayModeRerank:
-		err, usage = common_handler.RerankHandler(c, resp)
+		err, usage = common_handler.RerankHandler(c, info, resp)
 	default:
 		if info.IsStream {
 			err, usage = OaiStreamHandler(c, resp, info)

+ 2 - 2
relay/channel/siliconflow/dto.go

@@ -12,6 +12,6 @@ type SFMeta struct {
 }
 
 type SFRerankResponse struct {
-	Results []dto.RerankResponseDocument `json:"results"`
-	Meta    SFMeta                       `json:"meta"`
+	Results []dto.RerankResponseResult `json:"results"`
+	Meta    SFMeta                     `json:"meta"`
 }

+ 11 - 0
relay/channel/xinference/dto.go

@@ -0,0 +1,11 @@
+package xinference
+
+type XinRerankResponseDocument struct {
+	Document       string  `json:"document,omitempty"`
+	Index          int     `json:"index"`
+	RelevanceScore float64 `json:"relevance_score"`
+}
+
+type XinRerankResponse struct {
+	Results []XinRerankResponseDocument `json:"results"`
+}

+ 14 - 0
relay/common/relay_info.go

@@ -33,6 +33,10 @@ const (
 	RelayFormatClaude = "claude"
 )
 
+type RerankerInfo struct {
+	Documents []any
+}
+
 type RelayInfo struct {
 	ChannelType       int
 	ChannelId         int
@@ -78,6 +82,7 @@ type RelayInfo struct {
 	SendResponseCount    int
 	ThinkingContentInfo
 	ClaudeConvertInfo
+	*RerankerInfo
 }
 
 // 定义支持流式选项的通道类型
@@ -111,6 +116,15 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
 	return info
 }
 
+func GenRelayInfoRerank(c *gin.Context, documents []any) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.RelayMode = relayconstant.RelayModeRerank
+	info.RerankerInfo = &RerankerInfo{
+		Documents: documents,
+	}
+	return info
+}
+
 func GenRelayInfo(c *gin.Context) *RelayInfo {
 	channelType := c.GetInt("channel_type")
 	channelId := c.GetInt("channel_id")

+ 43 - 11
relay/common_handler/rerank.go

@@ -1,15 +1,17 @@
 package common_handler
 
 import (
-	"encoding/json"
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
+	"one-api/relay/channel/xinference"
+	relaycommon "one-api/relay/common"
 	"one-api/service"
 )
 
-func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -18,18 +20,48 @@ func RerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSta
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
+	if common.DebugEnabled {
+		println("reranker response body: ", string(responseBody))
+	}
 	var jinaResp dto.RerankResponse
-	err = json.Unmarshal(responseBody, &jinaResp)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	if info.ChannelType == common.ChannelTypeXinference {
+		var xinRerankResponse xinference.XinRerankResponse
+		err = common.DecodeJson(responseBody, &xinRerankResponse)
+		if err != nil {
+			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		}
+		jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
+		for i, result := range xinRerankResponse.Results {
+			var document any
+			if result.Document == "" {
+				document = info.Documents[result.Index]
+			} else {
+				document = result.Document
+			}
+			jinaRespResults[i] = dto.RerankResponseResult{
+				Index:          result.Index,
+				RelevanceScore: result.RelevanceScore,
+				Document: dto.RerankDocument{
+					Text: document,
+				},
+			}
+		}
+		jinaResp = dto.RerankResponse{
+			Results: jinaRespResults,
+			Usage: dto.Usage{
+				PromptTokens: info.PromptTokens,
+				TotalTokens:  info.PromptTokens,
+			},
+		}
+	} else {
+		err = common.DecodeJson(responseBody, &jinaResp)
+		if err != nil {
+			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		}
+		jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
 	}
 
-	jsonResponse, err := json.Marshal(jinaResp)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
-	}
 	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
+	c.JSON(http.StatusOK, jinaResp)
 	return nil, &jinaResp.Usage
 }

+ 3 - 1
relay/relay_rerank.go

@@ -25,7 +25,6 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
 }
 
 func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
-	relayInfo := relaycommon.GenRelayInfo(c)
 
 	var rerankRequest *dto.RerankRequest
 	err := common.UnmarshalBodyReusable(c, &rerankRequest)
@@ -33,6 +32,9 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
 		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
 		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
 	}
+
+	relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest.Documents)
+
 	if rerankRequest.Query == "" {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
 	}