Просмотр исходного кода

feat: add support for /v1/engines/text-embedding-ada-002/embeddings (#224, close #222)

玩牛牛 2 лет назад
Родитель
Сommit
81c5901123
4 измененных файлов с 19 добавлено и 4 удалено
  1. 5 1
      controller/relay-text.go
  2. 4 1
      controller/relay.go
  3. 7 1
      middleware/distributor.go
  4. 3 1
      router/relay-router.go

+ 5 - 1
controller/relay-text.go

@@ -6,12 +6,13 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -30,6 +31,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	if relayMode == RelayModeModerations && textRequest.Model == "" {
 		textRequest.Model = "text-moderation-latest"
 	}
+	if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
+		textRequest.Model = c.Param("model")
+	}
 	// request validation
 	if textRequest.Model == "" {
 		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)

+ 4 - 1
controller/relay.go

@@ -2,10 +2,11 @@ package controller
 
 import (
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 type Message struct {
@@ -100,6 +101,8 @@ func Relay(c *gin.Context) {
 		relayMode = RelayModeCompletions
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
 		relayMode = RelayModeEmbeddings
+	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+		relayMode = RelayModeEmbeddings
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 		relayMode = RelayModeModerations
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {

+ 7 - 1
middleware/distributor.go

@@ -2,12 +2,13 @@ package middleware
 
 import (
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
 	"strconv"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 type ModelRequest struct {
@@ -73,6 +74,11 @@ func Distribute() func(c *gin.Context) {
 					modelRequest.Model = "text-moderation-stable"
 				}
 			}
+			if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+				if modelRequest.Model == "" {
+					modelRequest.Model = c.Param("model")
+				}
+			}
 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
 				message := "无可用渠道"

+ 3 - 1
router/relay-router.go

@@ -1,9 +1,10 @@
 package router
 
 import (
-	"github.com/gin-gonic/gin"
 	"one-api/controller"
 	"one-api/middleware"
+
+	"github.com/gin-gonic/gin"
 )
 
 func SetRelayRouter(router *gin.Engine) {
@@ -24,6 +25,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
 		relayV1Router.POST("/embeddings", controller.Relay)
+		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
 		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
 		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
 		relayV1Router.GET("/files", controller.RelayNotImplemented)