Przeglądaj źródła

feat: support ollama multi-text embedding

CalciumIon 1 rok temu
rodzic
commit
0cbf8e07e7
2 zmienione pliki: 28 dodań i 5 usunięć
  1. 16 2
      relay/channel/ollama/dto.go
  2. 12 3
      relay/channel/ollama/relay-ollama.go

+ 16 - 2
relay/channel/ollama/dto.go

@@ -17,11 +17,25 @@ type OllamaRequest struct {
 	PresencePenalty  float64        `json:"presence_penalty,omitempty"`
 }
 
+type Options struct { // Options groups model parameters; OllamaEmbeddingRequest nests it under the "options" JSON key.
+	Seed             int     `json:"seed,omitempty"` // every field here uses omitempty, so a zero value (e.g. temperature 0) is dropped from the JSON
+	Temperature      float64 `json:"temperature,omitempty"`
+	TopK             int     `json:"top_k,omitempty"` // not populated by requestOpenAI2Embeddings
+	TopP             float64 `json:"top_p,omitempty"`
+	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
+	NumPredict       int     `json:"num_predict,omitempty"` // not populated by requestOpenAI2Embeddings
+	NumCtx           int     `json:"num_ctx,omitempty"` // not populated by requestOpenAI2Embeddings
+}
+
 type OllamaEmbeddingRequest struct {
-	Model  string `json:"model,omitempty"`
-	Prompt any    `json:"prompt,omitempty"`
+	Model   string   `json:"model,omitempty"` // request body for embeddings; model name is omitted from the JSON when empty
+	Input   []string `json:"input"` // replaces the single "prompt" value; no omitempty, so the key is always emitted (null for a nil slice)
+	Options *Options `json:"options,omitempty"` // omitted only when the pointer is nil
 }
 
 type OllamaEmbeddingResponse struct {
+	Error     string    `json:"error,omitempty"` // ollamaEmbeddingHandler returns an "ollama_error" when this is non-empty; NOTE(review): that call wraps the already-checked err, not this message — verify
+	Model     string    `json:"model"` // NOTE(review): requests now send a list of inputs, yet this struct still decodes a single "embedding" vector — confirm the endpoint's response shape
 	Embedding []float64 `json:"embedding,omitempty"`
 }

+ 12 - 3
relay/channel/ollama/relay-ollama.go

@@ -9,7 +9,6 @@ import (
 	"net/http"
 	"one-api/dto"
 	"one-api/service"
-	"strings"
 )
 
 func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -45,8 +44,15 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
 
 func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
 	return &OllamaEmbeddingRequest{
-		Model:  request.Model,
-		Prompt: strings.Join(request.ParseInput(), " "),
+		Model: request.Model, // builds the Ollama embedding payload from the OpenAI-style request
+		Input: request.ParseInput(), // each parsed input is forwarded as its own element instead of being space-joined into one prompt
+		Options: &Options{ // always non-nil, so an "options" object is sent even when every field is zero
+			Seed:             int(request.Seed), // NOTE(review): the declared type of request.Seed is not visible here — confirm the int conversion cannot truncate
+			Temperature:      request.Temperature,
+			TopP:             request.TopP,
+			FrequencyPenalty: request.FrequencyPenalty,
+			PresencePenalty:  request.PresencePenalty,
+		},
 	}
 }
 
@@ -64,6 +70,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
+	if ollamaEmbeddingResponse.Error != "" {
+		return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
+	}
 	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
 	data = append(data, dto.OpenAIEmbeddingResponseItem{
 		Embedding: ollamaEmbeddingResponse.Embedding,