Просмотр исходного кода

feat: supper whisper now (close #197)

JustSong 2 лет назад
Родитель
Сommit
d09d317459
6 измененных файлов с 177 добавлено и 4 удалено
  1. 1 1
      common/model-ratio.go
  2. 9 0
      controller/model.go
  3. 147 0
      controller/relay-audio.go
  4. 9 0
      controller/relay.go
  5. 9 1
      middleware/distributor.go
  6. 2 2
      router/relay-router.go

+ 1 - 1
common/model-ratio.go

@@ -31,7 +31,7 @@ var ModelRatio = map[string]float64{
 	"text-davinci-003":          10,
 	"text-davinci-edit-001":     10,
 	"code-davinci-edit-001":     10,
-	"whisper-1":                 10,
+	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
 	"davinci":                   10,
 	"curie":                     10,
 	"babbage":                   10,

+ 9 - 0
controller/model.go

@@ -63,6 +63,15 @@ func init() {
 			Root:       "dall-e",
 			Parent:     nil,
 		},
+		{
+			Id:         "whisper-1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "whisper-1",
+			Parent:     nil,
+		},
 		{
 			Id:         "gpt-3.5-turbo",
 			Object:     "model",

+ 147 - 0
controller/relay-audio.go

@@ -0,0 +1,147 @@
+package controller
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+
+	"github.com/gin-gonic/gin"
+)
+
+func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+	audioModel := "whisper-1"
+
+	tokenId := c.GetInt("token_id")
+	channelType := c.GetInt("channel")
+	userId := c.GetInt("id")
+	group := c.GetString("group")
+
+	preConsumedTokens := common.PreConsumedQuota
+	modelRatio := common.GetModelRatio(audioModel)
+	groupRatio := common.GetGroupRatio(group)
+	ratio := modelRatio * groupRatio
+	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+	userQuota, err := model.CacheGetUserQuota(userId)
+	if err != nil {
+		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+	}
+	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
+	if err != nil {
+		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+	}
+	if userQuota > 100*preConsumedQuota {
+		// in this case, we do not pre-consume quota
+		// because the user has enough quota
+		preConsumedQuota = 0
+	}
+	if preConsumedQuota > 0 {
+		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+		if err != nil {
+			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+		}
+	}
+
+	// map model name
+	modelMapping := c.GetString("model_mapping")
+	if modelMapping != "" {
+		modelMap := make(map[string]string)
+		err := json.Unmarshal([]byte(modelMapping), &modelMap)
+		if err != nil {
+			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+		}
+		if modelMap[audioModel] != "" {
+			audioModel = modelMap[audioModel]
+		}
+	}
+
+	baseURL := common.ChannelBaseURLs[channelType]
+	requestURL := c.Request.URL.String()
+
+	if c.GetString("base_url") != "" {
+		baseURL = c.GetString("base_url")
+	}
+
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+	requestBody := c.Request.Body
+
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	err = req.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+	}
+	err = c.Request.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+	}
+	var audioResponse AudioResponse
+
+	defer func() {
+		go func() {
+			quota := countTokenText(audioResponse.Text, audioModel)
+			quotaDelta := quota - preConsumedQuota
+			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+			if err != nil {
+				common.SysError("error consuming token remain quota: " + err.Error())
+			}
+			err = model.CacheUpdateUserQuota(userId)
+			if err != nil {
+				common.SysError("error update user quota cache: " + err.Error())
+			}
+			if quota != 0 {
+				tokenName := c.GetString("token_name")
+				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
+				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+				channelId := c.GetInt("channel_id")
+				model.UpdateChannelUsedQuota(channelId, quota)
+			}
+		}()
+	}()
+
+	responseBody, err := io.ReadAll(resp.Body)
+
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	err = json.Unmarshal(responseBody, &audioResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	return nil
+}

+ 9 - 0
controller/relay.go

@@ -24,6 +24,7 @@ const (
 	RelayModeModerations
 	RelayModeImagesGenerations
 	RelayModeEdits
+	RelayModeAudio
 )
 
 // https://platform.openai.com/docs/api-reference/chat
@@ -63,6 +64,10 @@ type ImageRequest struct {
 	Size   string `json:"size"`
 }
 
+type AudioResponse struct {
+	Text string `json:"text,omitempty"`
+}
+
 type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
@@ -159,11 +164,15 @@ func Relay(c *gin.Context) {
 		relayMode = RelayModeImagesGenerations
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 		relayMode = RelayModeEdits
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+		relayMode = RelayModeAudio
 	}
 	var err *OpenAIErrorWithStatusCode
 	switch relayMode {
 	case RelayModeImagesGenerations:
 		err = relayImageHelper(c, relayMode)
+	case RelayModeAudio:
+		err = relayAudioHelper(c, relayMode)
 	default:
 		err = relayTextHelper(c, relayMode)
 	}

+ 9 - 1
middleware/distributor.go

@@ -58,7 +58,10 @@ func Distribute() func(c *gin.Context) {
 		} else {
 			// Select a channel for the user
 			var modelRequest ModelRequest
-			err := common.UnmarshalBodyReusable(c, &modelRequest)
+			var err error
+			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+				err = common.UnmarshalBodyReusable(c, &modelRequest)
+			}
 			if err != nil {
 				c.JSON(http.StatusBadRequest, gin.H{
 					"error": gin.H{
@@ -84,6 +87,11 @@ func Distribute() func(c *gin.Context) {
 					modelRequest.Model = "dall-e"
 				}
 			}
+			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+				if modelRequest.Model == "" {
+					modelRequest.Model = "whisper-1"
+				}
+			}
 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)

+ 2 - 2
router/relay-router.go

@@ -26,8 +26,8 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
 		relayV1Router.POST("/embeddings", controller.Relay)
 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
-		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
-		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
+		relayV1Router.POST("/audio/transcriptions", controller.Relay)
+		relayV1Router.POST("/audio/translations", controller.Relay)
 		relayV1Router.GET("/files", controller.RelayNotImplemented)
 		relayV1Router.POST("/files", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)