Просмотр исходного кода

feat: dalle系列改为使用模型固定价格计费

CaIon 1 год назад
Родитель
Сommit
71547849bc
5 измененных файлов с 32 добавлено и 24 удалено
  1. 6 5
      common/model-ratio.go
  2. 16 9
      relay/relay-image.go
  3. 4 4
      relay/relay-mj.go
  4. 5 5
      relay/relay-text.go
  5. 1 1
      web/src/helpers/render.js

+ 6 - 5
common/model-ratio.go

@@ -61,8 +61,6 @@ var DefaultModelRatio = map[string]float64{
 	"text-search-ada-doc-001":      10,
 	"text-moderation-stable":       0.1,
 	"text-moderation-latest":       0.1,
-	"dall-e-2":                     8,
-	"dall-e-3":                     16,
 	"claude-instant-1":             0.4,    // $0.8 / 1M tokens
 	"claude-2.0":                   4,      // $8 / 1M tokens
 	"claude-2.1":                   4,      // $8 / 1M tokens
@@ -117,6 +115,8 @@ var DefaultModelRatio = map[string]float64{
 }
 
 var DefaultModelPrice = map[string]float64{
+	"dall-e-2":          0.02,
+	"dall-e-3":          0.04,
 	"gpt-4-gizmo-*":     0.1,
 	"mj_imagine":        0.1,
 	"mj_variation":      0.1,
@@ -160,7 +160,8 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
 	return json.Unmarshal([]byte(jsonStr), &modelPrice)
 }
 
-func GetModelPrice(name string, printErr bool) float64 {
+// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
+func GetModelPrice(name string, printErr bool) (float64, bool) {
 	if modelPrice == nil {
 		modelPrice = DefaultModelPrice
 	}
@@ -172,9 +173,9 @@ func GetModelPrice(name string, printErr bool) float64 {
 		if printErr {
 			SysError("model price not found: " + name)
 		}
-		return -1
+		return -1, false
 	}
-	return price
+	return price, true
 }
 
 func ModelRatio2JSONString() string {

+ 16 - 9
relay/relay-image.go

@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
+	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
@@ -106,21 +107,27 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 		requestBody = c.Request.Body
 	}
 
-	modelRatio := common.GetModelRatio(imageRequest.Model)
+	modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
+	if !success {
+		modelRatio := common.GetModelRatio(imageRequest.Model)
+		// modelRatio 16 = modelPrice $0.04
+		// per 1 modelRatio = $0.04 / 16
+		modelPrice = 0.0025 * modelRatio
+	}
+	log.Printf("modelPrice: %f", modelPrice)
 	groupRatio := common.GetGroupRatio(group)
-	ratio := modelRatio * groupRatio
 	userQuota, err := model.CacheGetUserQuota(userId)
 
 	sizeRatio := 1.0
 	// Size
 	if imageRequest.Size == "256x256" {
-		sizeRatio = 1
+		sizeRatio = 0.4
 	} else if imageRequest.Size == "512x512" {
-		sizeRatio = 1.125
+		sizeRatio = 0.45
 	} else if imageRequest.Size == "1024x1024" {
-		sizeRatio = 1.25
+		sizeRatio = 1
 	} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
-		sizeRatio = 2.5
+		sizeRatio = 2
 	}
 
 	qualityRatio := 1.0
@@ -131,7 +138,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 		}
 	}
 
-	quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
+	quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N
 
 	if userQuota-quota < 0 {
 		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@@ -190,9 +197,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 			if imageRequest.Quality == "hd" {
 				quality = "hd"
 			}
-			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelRatio, groupRatio, imageRequest.Size, quality)
+			logContent := fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelPrice, groupRatio, imageRequest.Size, quality)
 			other := make(map[string]interface{})
-			other["model_ratio"] = modelRatio
+			other["model_price"] = modelPrice
 			other["group_ratio"] = groupRatio
 			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)

+ 4 - 4
relay/relay-mj.go

@@ -155,9 +155,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
 		return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
 	}
 	modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
-	modelPrice := common.GetModelPrice(modelName, true)
+	modelPrice, success := common.GetModelPrice(modelName, true)
 	// 如果没有配置价格,则使用默认价格
-	if modelPrice == -1 {
+	if !success {
 		defaultPrice, ok := common.DefaultModelPrice[modelName]
 		if !ok {
 			modelPrice = 0.1
@@ -454,9 +454,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 
 	modelName := service.CoverActionToModelName(midjRequest.Action)
-	modelPrice := common.GetModelPrice(modelName, true)
+	modelPrice, success := common.GetModelPrice(modelName, true)
 	// 如果没有配置价格,则使用默认价格
-	if modelPrice == -1 {
+	if !success {
 		defaultPrice, ok := common.DefaultModelPrice[modelName]
 		if !ok {
 			modelPrice = 0.1

+ 5 - 5
relay/relay-text.go

@@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		}
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
-	modelPrice := common.GetModelPrice(textRequest.Model, false)
+	modelPrice, success := common.GetModelPrice(textRequest.Model, false)
 	groupRatio := common.GetGroupRatio(relayInfo.Group)
 
 	var preConsumedQuota int
@@ -108,7 +108,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 	}
 
-	if modelPrice == -1 {
+	if !success {
 		preConsumedTokens := common.PreConsumedQuota
 		if textRequest.MaxTokens != 0 {
 			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
@@ -178,7 +178,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
+	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
 	return nil
 }
 
@@ -257,7 +257,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
 
 func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
 	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
-	modelPrice float64) {
+	modelPrice float64, usePrice bool) {
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
@@ -267,7 +267,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 	completionRatio := common.GetCompletionRatio(textRequest.Model)
 
 	quota := 0
-	if modelPrice == -1 {
+	if !usePrice {
 		quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
 		quota = int(math.Round(float64(quota) * ratio))
 		if ratio != 0 && quota <= 0 {

+ 1 - 1
web/src/helpers/render.js

@@ -159,7 +159,7 @@ export function renderModelPrice(
         <article>
           <p>提示 ${inputRatioPrice} / 1M tokens</p>
           <p>补全 ${completionRatioPrice} / 1M tokens</p>
-          <p>计算过程:</p>
+          <p> </p>
           <p>
             提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
             {completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $