فهرست منبع

Fix M3E not working

Jerry 1 سال پیش
والد
کامیت
7588c42b42

+ 8 - 0
controller/relay.go

@@ -33,6 +33,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 		err = relay.AudioHelper(c)
 	case relayconstant.RelayModeRerank:
 		err = relay.RerankHelper(c, relayMode)
+	case relayconstant.RelayModeEmbeddings:
+		err = relay.EmbeddingHelper(c,relayMode)
 	default:
 		err = relay.TextHelper(c)
 	}
@@ -55,6 +57,11 @@ func Relay(c *gin.Context) {
 	originalModel := c.GetString("original_model")
 	var openaiErr *dto.OpenAIErrorWithStatusCode
 
+	// Read the request body and write it to the log.
+	requestBody, _ := common.GetRequestBody(c)
+	common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody)))
+	
+
 	for i := 0; i <= common.RetryTimes; i++ {
 		channel, err := getChannel(c, group, originalModel, i)
 		if err != nil {
@@ -154,6 +161,7 @@ func WssRelay(c *gin.Context) {
 }
 
 func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
+	common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id)))
 	addUsedChannel(c, channel.Id)
 	requestBody, _ := common.GetRequestBody(c)
 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

+ 15 - 15
relay/channel/mokaai/dto.go → dto/embedding.go

@@ -1,19 +1,6 @@
-package mokaai
+package dto
 
-import "one-api/dto"
-
-
-type Request struct {
-	Messages    []dto.Message `json:"messages,omitempty"`
-	Lora        string        `json:"lora,omitempty"`
-	MaxTokens   int           `json:"max_tokens,omitempty"`
-	Prompt      string        `json:"prompt,omitempty"`
-	Raw         bool          `json:"raw,omitempty"`
-	Stream      bool          `json:"stream,omitempty"`
-	Temperature float64       `json:"temperature,omitempty"`
-}
-
-type Options struct {
+type EmbeddingOptions struct {
 	Seed             int      `json:"seed,omitempty"`
 	Temperature      *float64 `json:"temperature,omitempty"`
 	TopK             int      `json:"top_k,omitempty"`
@@ -27,4 +14,17 @@ type Options struct {
 type EmbeddingRequest struct {
 	Model string   `json:"model"`
 	Input []string `json:"input"`
+}
+
+type EmbeddingResponseItem struct {
+	Object    string    `json:"object"`
+	Index     int       `json:"index"`
+	Embedding []float64 `json:"embedding"`
+}
+
+type EmbeddingResponse struct {
+	Object string                        `json:"object"`
+	Data   []EmbeddingResponseItem `json:"data"`
+	Model  string                        `json:"model"`
+	Usage  `json:"usage"`
 }

+ 1 - 0
relay/channel/adapter.go

@@ -15,6 +15,7 @@ type Adaptor interface {
 	SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
 	ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
 	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
+	ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
 	ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
 	ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
 	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)

+ 5 - 0
relay/channel/ali/adaptor.go

@@ -67,6 +67,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, errors.New("not implemented")
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
 	return nil, errors.New("not implemented")

+ 6 - 0
relay/channel/aws/adaptor.go

@@ -59,6 +59,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return nil, nil
 }

+ 5 - 0
relay/channel/baidu/adaptor.go

@@ -122,6 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 5 - 0
relay/channel/claude/adaptor.go

@@ -73,6 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 6 - 0
relay/channel/cloudflare/adaptor.go

@@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return request, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	// 添加文件字段
 	file, _, err := c.Request.FormFile("file")

+ 6 - 0
relay/channel/cohere/adaptor.go

@@ -54,6 +54,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return requestConvertRerank2Cohere(request), nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeRerank {
 		err, usage = cohereRerankHandler(c, resp, info)

+ 6 - 0
relay/channel/deepseek/adaptor.go

@@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 6 - 0
relay/channel/dify/adaptor.go

@@ -48,6 +48,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 6 - 0
relay/channel/gemini/adaptor.go

@@ -68,6 +68,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 6 - 0
relay/channel/jina/adaptor.go

@@ -55,6 +55,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return request, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeRerank {
 		err, usage = jinaRerankHandler(c, resp)

+ 6 - 0
relay/channel/mistral/adaptor.go

@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 27 - 38
relay/channel/mokaai/adaptor.go

@@ -3,54 +3,46 @@ package mokaai
 import (
 	"errors"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
-
-	"github.com/gin-gonic/gin"
-	// "one-api/relay/adaptor"
-	// "one-api/relay/meta"
-	// "one-api/relay/model"
-	// "one-api/relay/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
+	"strings"
 )
 
 type Adaptor struct {
 }
 
-// ConvertImageRequest implements adaptor.Adaptor.
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	//TODO implement me
 	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
 	//TODO implement me
 	return nil, errors.New("not implemented")
 }
 
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
 	//TODO implement me
-	return nil, errors.New("not implemented")
+	return request, nil
 }
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
 
+}
 
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo)  (string, error) {
-	
-	var urlPrefix = info.BaseUrl
-	
-	switch info.RelayMode {
-	case constant.RelayModeChatCompletions:
-		return fmt.Sprintf("%s/chat/completions", urlPrefix), nil
-	case constant.RelayModeEmbeddings:
-		return fmt.Sprintf("%s/embeddings", urlPrefix), nil
-	default:
-		return fmt.Sprintf("%s/run/%s", urlPrefix, info.UpstreamModelName), nil
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
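+	// NOTE(review): this link is Baidu's documentation (https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t), apparently carried over from the baidu adaptor — confirm the MokaAI reference.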
+	suffix := "chat/"
+	if strings.HasPrefix(info.UpstreamModelName, "m3e") {
+		suffix = "embeddings"
 	}
+	fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
+	return fullRequestURL, nil
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -64,33 +56,30 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 		return nil, errors.New("request is nil")
 	}
 	switch info.RelayMode {
-	case constant.RelayModeChatCompletions:
-		return nil, errors.New("not implemented")
-	case  constant.RelayModeEmbeddings:
-		// return ConvertCompletionsRequest(*request), nil
-		return ConvertEmbeddingRequest(*request), nil
+	case constant.RelayModeEmbeddings:
+		baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request)
+		return baiduEmbeddingRequest, nil
 	default:
 		return nil, errors.New("not implemented")
 	}
 }
 
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
-	switch info.RelayMode {
 		
-	case constant.RelayModeAudioTranscription:
-	case constant.RelayModeAudioTranslation:
-	case constant.RelayModeChatCompletions:
-		fallthrough
+	switch info.RelayMode {
 	case constant.RelayModeEmbeddings:
-		if info.IsStream {
-			err, usage = StreamHandler(c, resp, info)
-		} else {
-			err, usage = Handler(c, resp, info)
-		}
+		err, usage = mokaEmbeddingHandler(c, resp)
+	default:
+		// err, usage = mokaHandler(c, resp)
+		
 	}
 	return
 }

+ 36 - 107
relay/channel/mokaai/relay-mokaai.go

@@ -1,41 +1,15 @@
 package mokaai
 
 import (
-	"bufio"
 	"encoding/json"
+	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
-	"strings"
-
-	// "one-api/common/ctxkey"
-	// "one-api/common/render"
-
-	// "github.com/gin-gonic/gin"
-	// "one-api/common"
-	// "one-api/common/helper"
-	// "one-api/common/logger"
-	// "one-api/relay/adaptor/openai"
-	// "one-api/relay/model"
-
-	"github.com/gin-gonic/gin"
-	"one-api/common"
 	"one-api/dto"
-	relaycommon "one-api/relay/common"
 	"one-api/service"
-	"time"
 )
 
-func ConvertCompletionsRequest(textRequest dto.GeneralOpenAIRequest) *Request {
-	p, _ := textRequest.Prompt.(string)
-	return &Request{
-		Prompt:      p,
-		MaxTokens:   textRequest.GetMaxTokens(),
-		Stream:      textRequest.Stream,
-		Temperature: textRequest.Temperature,
-	}
-}
-
-func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest {
+func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
 	var input []string // Change input to []string
 
 	switch v := request.Input.(type) {
@@ -50,105 +24,60 @@ func ConvertEmbeddingRequest(request dto.GeneralOpenAIRequest) *EmbeddingRequest
 			}
 		}
 	}
-
-	return &EmbeddingRequest{
-		Model: request.Model,
-		Input: input, // Assign []string to Input
+	return &dto.EmbeddingRequest{
+		Input: input,
+		Model:  request.Model,
 	}
 }
 
-func StreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(bufio.ScanLines)
-
-	service.SetEventStreamHeaders(c)
-	id := service.GetResponseID(c)
-	var responseText string
-	isFirst := true
-
-	for scanner.Scan() {
-		data := scanner.Text()
-		if len(data) < len("data: ") {
-			continue
-		}
-		data = strings.TrimPrefix(data, "data: ")
-		data = strings.TrimSuffix(data, "\r")
-
-		if data == "[DONE]" {
-			break
-		}
-
-		var response dto.ChatCompletionsStreamResponse
-		err := json.Unmarshal([]byte(data), &response)
-		if err != nil {
-			common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
-			continue
-		}
-		for _, choice := range response.Choices {
-			choice.Delta.Role = "assistant"
-			responseText += choice.Delta.GetContentString()
-		}
-		response.Id = id
-		response.Model = info.UpstreamModelName
-		err = service.ObjectData(c, response)
-		if isFirst {
-			isFirst = false
-			info.FirstResponseTime = time.Now()
-		}
-		if err != nil {
-			common.LogError(c, "error_rendering_stream_response: "+err.Error())
-		}
-	}
-
-	if err := scanner.Err(); err != nil {
-		common.LogError(c, "error_scanning_stream_response: "+err.Error())
+func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
+		Model:  "baidu-embedding",
+		Usage:  response.Usage,
 	}
-	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-	if info.ShouldIncludeUsage {
-		response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
-		err := service.ObjectData(c, response)
-		if err != nil {
-			common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
-		}
-	}
-	service.Done(c)
-
-	err := resp.Body.Close()
-	if err != nil {
-		common.LogError(c, "close_response_body_failed: "+err.Error())
+	for _, item := range response.Data {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
+			Object:    item.Object,
+			Index:     item.Index,
+			Embedding: item.Embedding,
+		})
 	}
-
-	return nil, usage
+	return &openAIEmbeddingResponse
 }
 
-func Handler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var baiduResponse dto.EmbeddingResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	var response dto.TextResponse
-	err = json.Unmarshal(responseBody, &response)
+	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	response.Model = info.UpstreamModelName
-	var responseText string
-	for _, choice := range response.Choices {
-		responseText += choice.Message.StringContent()
-	}
-	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-	response.Usage = *usage
-	response.Id = service.GetResponseID(c)
-	jsonResponse, err := json.Marshal(response)
+	// if baiduResponse.ErrorMsg != "" {
+	// 	return &dto.OpenAIErrorWithStatusCode{
+	// 		Error: dto.OpenAIError{
+	// 			Type:    "baidu_error",
+	// 			Param:   "",
+	// 		},
+	// 		StatusCode: resp.StatusCode,
+	// 	}, nil
+	// }
+	fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, _ = c.Writer.Write(jsonResponse)
-	return nil, usage
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
 }
+

+ 6 - 0
relay/channel/ollama/adaptor.go

@@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 5 - 0
relay/channel/openai/adaptor.go

@@ -129,6 +129,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, errors.New("not implemented")
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	a.ResponseFormat = request.ResponseFormat
 	if info.RelayMode == constant.RelayModeAudioSpeech {

+ 6 - 0
relay/channel/palm/adaptor.go

@@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 6 - 0
relay/channel/perplexity/adaptor.go

@@ -52,6 +52,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 6 - 0
relay/channel/siliconflow/adaptor.go

@@ -58,6 +58,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return request, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	switch info.RelayMode {
 	case constant.RelayModeRerank:

+ 6 - 0
relay/channel/tencent/adaptor.go

@@ -73,6 +73,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 6 - 0
relay/channel/vertex/adaptor.go

@@ -151,6 +151,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 6 - 0
relay/channel/xunfei/adaptor.go

@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	// xunfei's request is not http request, so we don't need to do anything here
 	dummyResp := &http.Response{}

+ 6 - 0
relay/channel/zhipu/adaptor.go

@@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 6 - 0
relay/channel/zhipu_4v/adaptor.go

@@ -53,6 +53,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 }
 
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 127 - 0
relay/relay_embedding.go

@@ -0,0 +1,127 @@
+package relay
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
+	"one-api/setting"
+)
+
+func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
+	token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
+	return token
+}
+
+func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
+	relayInfo := relaycommon.GenRelayInfo(c)
+
+	var embeddingRequest *dto.EmbeddingRequest
+	err := common.UnmarshalBodyReusable(c, &embeddingRequest)
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
+	}
+	if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
+		embeddingRequest.Model = "m3e-base"
+	}
+	if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
+		embeddingRequest.Model = c.Param("model")
+	}
+	if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest)
+	}
+	// map model name
+	modelMapping := c.GetString("model_mapping")
+	//isModelMapped := false
+	if modelMapping != "" && modelMapping != "{}" {
+		modelMap := make(map[string]string)
+		err := json.Unmarshal([]byte(modelMapping), &modelMap)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+		}
+		if modelMap[embeddingRequest.Model] != "" {
+			embeddingRequest.Model = modelMap[embeddingRequest.Model]
+			// set upstream model name
+			//isModelMapped = true
+		}
+	}
+
+	relayInfo.UpstreamModelName = embeddingRequest.Model
+	modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
+	groupRatio := setting.GetGroupRatio(relayInfo.Group)
+
+	var preConsumedQuota int
+	var ratio float64
+	var modelRatio float64
+
+	promptToken := getEmbeddingPromptToken(*embeddingRequest)
+	if !success {
+		preConsumedTokens := promptToken
+		modelRatio = common.GetModelRatio(embeddingRequest.Model)
+		ratio = modelRatio * groupRatio
+		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+	} else {
+		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+	relayInfo.PromptTokens = promptToken
+
+	// pre-consume quota 预消耗配额
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+	adaptor.Init(relayInfo)
+
+	convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest)
+	
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
+	}
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+	}
+	requestBody := bytes.NewBuffer(jsonData)
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	resp, err := adaptor.DoRequest(c,relayInfo, requestBody)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	var httpResp *http.Response
+	if resp != nil {
+		httpResp = resp.(*http.Response)
+		if httpResp.StatusCode != http.StatusOK {
+			openaiErr = service.RelayErrorHandler(httpResp)
+			// reset status code 重置状态码
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
+	if openaiErr != nil {
+		// reset status code 重置状态码
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+	postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
+	return nil
+}