Просмотр исходного кода

feat: support baidu's embedding model (close #324)

JustSong 2 лет назад
Родитель
Сommit
130e6bfd83
5 измененных файлов с 129 добавлено и 4 удалено
  1. 1 0
      common/model-ratio.go
  2. 9 0
      controller/model.go
  3. 85 0
      controller/relay-baidu.go
  4. 21 4
      controller/relay-text.go
  5. 13 0
      controller/relay.go

+ 1 - 0
common/model-ratio.go

@@ -42,6 +42,7 @@ var ModelRatio = map[string]float64{
 	"claude-2":                30,
 	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens
 	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens
+	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":                  1,
 	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens

+ 9 - 0
controller/model.go

@@ -288,6 +288,15 @@ func init() {
 			Root:       "ERNIE-Bot-turbo",
 			Parent:     nil,
 		},
+		{
+			Id:         "Embedding-V1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "baidu",
+			Permission: permission,
+			Root:       "Embedding-V1",
+			Parent:     nil,
+		},
 		{
 			Id:         "PaLM-2",
 			Object:     "model",

+ 85 - 0
controller/relay-baidu.go

@@ -54,6 +54,25 @@ type BaiduChatStreamResponse struct {
 	IsEnd      bool `json:"is_end"`
 }
 
+type BaiduEmbeddingRequest struct {
+	Input []string `json:"input"`
+}
+
+type BaiduEmbeddingData struct {
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
+type BaiduEmbeddingResponse struct {
+	Id      string               `json:"id"`
+	Object  string               `json:"object"`
+	Created int64                `json:"created"`
+	Data    []BaiduEmbeddingData `json:"data"`
+	Usage   Usage                `json:"usage"`
+	BaiduError
+}
+
 func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 	messages := make([]BaiduMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
@@ -112,6 +131,36 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 	return &response
 }
 
+func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
+	baiduEmbeddingRequest := BaiduEmbeddingRequest{
+		Input: nil,
+	}
+	switch request.Input.(type) {
+	case string:
+		baiduEmbeddingRequest.Input = []string{request.Input.(string)}
+	case []string:
+		baiduEmbeddingRequest.Input = request.Input.([]string)
+	}
+	return &baiduEmbeddingRequest
+}
+
+func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
+		Model:  "baidu-embedding",
+		Usage:  response.Usage,
+	}
+	for _, item := range response.Data {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+			Object:    item.Object,
+			Index:     item.Index,
+			Embedding: item.Embedding,
+		})
+	}
+	return &openAIEmbeddingResponse
+}
+
 func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 	var usage Usage
 	scanner := bufio.NewScanner(resp.Body)
@@ -212,3 +261,39 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &fullTextResponse.Usage
 }
+
+func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var baiduResponse BaiduEmbeddingResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &baiduResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if baiduResponse.ErrorMsg != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: baiduResponse.ErrorMsg,
+				Type:    "baidu_error",
+				Param:   "",
+				Code:    baiduResponse.ErrorCode,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}

+ 21 - 4
controller/relay-text.go

@@ -139,6 +139,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
 		case "BLOOMZ-7B":
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+		case "Embedding-V1":
+			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
 		}
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@@ -212,12 +214,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 	case APITypeBaidu:
-		baiduRequest := requestOpenAI2Baidu(textRequest)
-		jsonStr, err := json.Marshal(baiduRequest)
+		var jsonData []byte
+		var err error
+		switch relayMode {
+		case RelayModeEmbeddings:
+			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
+			jsonData, err = json.Marshal(baiduEmbeddingRequest)
+		default:
+			baiduRequest := requestOpenAI2Baidu(textRequest)
+			jsonData, err = json.Marshal(baiduRequest)
+		}
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
-		requestBody = bytes.NewBuffer(jsonStr)
+		requestBody = bytes.NewBuffer(jsonData)
 	case APITypePaLM:
 		palmRequest := requestOpenAI2PaLM(textRequest)
 		jsonStr, err := json.Marshal(palmRequest)
@@ -386,7 +396,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		} else {
-			err, usage := baiduHandler(c, resp)
+			var err *OpenAIErrorWithStatusCode
+			var usage *Usage
+			switch relayMode {
+			case RelayModeEmbeddings:
+				err, usage = baiduEmbeddingHandler(c, resp)
+			default:
+				err, usage = baiduHandler(c, resp)
+			}
 			if err != nil {
 				return err
 			}

+ 13 - 0
controller/relay.go

@@ -99,6 +99,19 @@ type OpenAITextResponse struct {
 	Usage   `json:"usage"`
 }
 
+type OpenAIEmbeddingResponseItem struct {
+	Object    string    `json:"object"`
+	Index     int       `json:"index"`
+	Embedding []float64 `json:"embedding"`
+}
+
+type OpenAIEmbeddingResponse struct {
+	Object string                        `json:"object"`
+	Data   []OpenAIEmbeddingResponseItem `json:"data"`
+	Model  string                        `json:"model"`
+	Usage  `json:"usage"`
+}
+
 type ImageResponse struct {
 	Created int `json:"created"`
 	Data    []struct {