Просмотр исходного кода

fix: support ali's embedding model (#481, close #469)

* feat:支持阿里的 embedding 模型

* fix: add to model list

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
igophper 2 года назад
Родитель
Commit
d0a0e871e1

+ 4 - 3
common/model-ratio.go

@@ -50,9 +50,10 @@ var ModelRatio = map[string]float64{
 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
 	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens
 	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens
-	"qwen-v1":                   0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
-	"qwen-plus-v1":              0.5715, // Same as above
-	"SparkDesk":                 0.8572, // TBD
+	"qwen-v1":                   0.8572, // ¥0.012 / 1k tokens
+	"qwen-plus-v1":              1,      // ¥0.014 / 1k tokens
+	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
+	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens

+ 9 - 0
controller/model.go

@@ -360,6 +360,15 @@ func init() {
 			Root:       "qwen-plus-v1",
 			Root:       "qwen-plus-v1",
 			Parent:     nil,
 			Parent:     nil,
 		},
 		},
+		{
+			Id:         "text-embedding-v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "ali",
+			Permission: permission,
+			Root:       "text-embedding-v1",
+			Parent:     nil,
+		},
 		{
 		{
 			Id:         "SparkDesk",
 			Id:         "SparkDesk",
 			Object:     "model",
 			Object:     "model",

+ 88 - 0
controller/relay-ali.go

@@ -35,6 +35,29 @@ type AliChatRequest struct {
 	Parameters AliParameters `json:"parameters,omitempty"`
 	Parameters AliParameters `json:"parameters,omitempty"`
 }
 }
 
 
+type AliEmbeddingRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Texts []string `json:"texts"`
+	} `json:"input"`
+	Parameters *struct {
+		TextType string `json:"text_type,omitempty"`
+	} `json:"parameters,omitempty"`
+}
+
+type AliEmbedding struct {
+	Embedding []float64 `json:"embedding"`
+	TextIndex int       `json:"text_index"`
+}
+
+type AliEmbeddingResponse struct {
+	Output struct {
+		Embeddings []AliEmbedding `json:"embeddings"`
+	} `json:"output"`
+	Usage AliUsage `json:"usage"`
+	AliError
+}
+
 type AliError struct {
 type AliError struct {
 	Code      string `json:"code"`
 	Code      string `json:"code"`
 	Message   string `json:"message"`
 	Message   string `json:"message"`
@@ -44,6 +67,7 @@ type AliError struct {
 type AliUsage struct {
 type AliUsage struct {
 	InputTokens  int `json:"input_tokens"`
 	InputTokens  int `json:"input_tokens"`
 	OutputTokens int `json:"output_tokens"`
 	OutputTokens int `json:"output_tokens"`
+	TotalTokens  int `json:"total_tokens"`
 }
 }
 
 
 type AliOutput struct {
 type AliOutput struct {
@@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 	}
 	}
 }
 }
 
 
+func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
+	return &AliEmbeddingRequest{
+		Model: "text-embedding-v1",
+		Input: struct {
+			Texts []string `json:"texts"`
+		}{
+			Texts: request.ParseInput(),
+		},
+	}
+}
+
+func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var aliResponse AliEmbeddingResponse
+	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliResponse.Code != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: aliResponse.Message,
+				Type:    aliResponse.Code,
+				Param:   aliResponse.RequestId,
+				Code:    aliResponse.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
+		Model:  "text-embedding-v1",
+		Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
+	}
+
+	for _, item := range response.Output.Embeddings {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+			Object:    `embedding`,
+			Index:     item.TextIndex,
+			Embedding: item.Embedding,
+		})
+	}
+	return &openAIEmbeddingResponse
+}
+
 func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 	choice := OpenAITextResponseChoice{
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Index: 0,

+ 2 - 13
controller/relay-baidu.go

@@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 }
 }
 
 
 func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
 func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
-	baiduEmbeddingRequest := BaiduEmbeddingRequest{
-		Input: nil,
+	return &BaiduEmbeddingRequest{
+		Input: request.ParseInput(),
 	}
 	}
-	switch request.Input.(type) {
-	case string:
-		baiduEmbeddingRequest.Input = []string{request.Input.(string)}
-	case []any:
-		for _, item := range request.Input.([]any) {
-			if str, ok := item.(string); ok {
-				baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
-			}
-		}
-	}
-	return &baiduEmbeddingRequest
 }
 }
 
 
 func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
 func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {

+ 21 - 3
controller/relay-text.go

@@ -174,6 +174,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 	case APITypeAli:
 	case APITypeAli:
 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
+		if relayMode == RelayModeEmbeddings {
+			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
+		}
 	case APITypeAIProxyLibrary:
 	case APITypeAIProxyLibrary:
 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
 	}
 	}
@@ -262,8 +265,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 		requestBody = bytes.NewBuffer(jsonStr)
 	case APITypeAli:
 	case APITypeAli:
-		aliRequest := requestOpenAI2Ali(textRequest)
-		jsonStr, err := json.Marshal(aliRequest)
+		var jsonStr []byte
+		var err error
+		switch relayMode {
+		case RelayModeEmbeddings:
+			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
+			jsonStr, err = json.Marshal(aliEmbeddingRequest)
+		default:
+			aliRequest := requestOpenAI2Ali(textRequest)
+			jsonStr, err = json.Marshal(aliRequest)
+		}
 		if err != nil {
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		}
@@ -502,7 +513,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			}
 			return nil
 			return nil
 		} else {
 		} else {
-			err, usage := aliHandler(c, resp)
+			var err *OpenAIErrorWithStatusCode
+			var usage *Usage
+			switch relayMode {
+			case RelayModeEmbeddings:
+				err, usage = aliEmbeddingHandler(c, resp)
+			default:
+				err, usage = aliHandler(c, resp)
+			}
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}

+ 19 - 0
controller/relay.go

@@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct {
 	Functions   any       `json:"functions,omitempty"`
 	Functions   any       `json:"functions,omitempty"`
 }
 }
 
 
+func (r GeneralOpenAIRequest) ParseInput() []string {
+	if r.Input == nil {
+		return nil
+	}
+	var input []string
+	switch r.Input.(type) {
+	case string:
+		input = []string{r.Input.(string)}
+	case []any:
+		input = make([]string, 0, len(r.Input.([]any)))
+		for _, item := range r.Input.([]any) {
+			if str, ok := item.(string); ok {
+				input = append(input, str)
+			}
+		}
+	}
+	return input
+}
+
 type ChatRequest struct {
 type ChatRequest struct {
 	Model     string    `json:"model"`
 	Model     string    `json:"model"`
 	Messages  []Message `json:"messages"`
 	Messages  []Message `json:"messages"`

+ 1 - 1
web/src/pages/Channel/EditChannel.js

@@ -67,7 +67,7 @@ const EditChannel = () => {
           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
           break;
           break;
         case 17:
         case 17:
-          localModels = ['qwen-v1', 'qwen-plus-v1'];
+          localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
           break;
           break;
         case 16:
         case 16:
           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];