Explorar o código

feat: dalle系列改为使用模型固定价格计费

CaIon hai 1 ano
pai
achega
71547849bc
Modificáronse 5 ficheiros con 32 adicións e 24 borrados
  1. 6 5
      common/model-ratio.go
  2. 16 9
      relay/relay-image.go
  3. 4 4
      relay/relay-mj.go
  4. 5 5
      relay/relay-text.go
  5. 1 1
      web/src/helpers/render.js

+ 6 - 5
common/model-ratio.go

@@ -61,8 +61,6 @@ var DefaultModelRatio = map[string]float64{
 	"text-search-ada-doc-001":      10,
 	"text-search-ada-doc-001":      10,
 	"text-moderation-stable":       0.1,
 	"text-moderation-stable":       0.1,
 	"text-moderation-latest":       0.1,
 	"text-moderation-latest":       0.1,
-	"dall-e-2":                     8,
-	"dall-e-3":                     16,
 	"claude-instant-1":             0.4,    // $0.8 / 1M tokens
 	"claude-instant-1":             0.4,    // $0.8 / 1M tokens
 	"claude-2.0":                   4,      // $8 / 1M tokens
 	"claude-2.0":                   4,      // $8 / 1M tokens
 	"claude-2.1":                   4,      // $8 / 1M tokens
 	"claude-2.1":                   4,      // $8 / 1M tokens
@@ -117,6 +115,8 @@ var DefaultModelRatio = map[string]float64{
 }
 }
 
 
 var DefaultModelPrice = map[string]float64{
 var DefaultModelPrice = map[string]float64{
+	"dall-e-2":          0.02,
+	"dall-e-3":          0.04,
 	"gpt-4-gizmo-*":     0.1,
 	"gpt-4-gizmo-*":     0.1,
 	"mj_imagine":        0.1,
 	"mj_imagine":        0.1,
 	"mj_variation":      0.1,
 	"mj_variation":      0.1,
@@ -160,7 +160,8 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
 	return json.Unmarshal([]byte(jsonStr), &modelPrice)
 	return json.Unmarshal([]byte(jsonStr), &modelPrice)
 }
 }
 
 
-func GetModelPrice(name string, printErr bool) float64 {
+// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
+func GetModelPrice(name string, printErr bool) (float64, bool) {
 	if modelPrice == nil {
 	if modelPrice == nil {
 		modelPrice = DefaultModelPrice
 		modelPrice = DefaultModelPrice
 	}
 	}
@@ -172,9 +173,9 @@ func GetModelPrice(name string, printErr bool) float64 {
 		if printErr {
 		if printErr {
 			SysError("model price not found: " + name)
 			SysError("model price not found: " + name)
 		}
 		}
-		return -1
+		return -1, false
 	}
 	}
-	return price
+	return price, true
 }
 }
 
 
 func ModelRatio2JSONString() string {
 func ModelRatio2JSONString() string {

+ 16 - 9
relay/relay-image.go

@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"io"
 	"io"
+	"log"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
@@ -106,21 +107,27 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 		requestBody = c.Request.Body
 		requestBody = c.Request.Body
 	}
 	}
 
 
-	modelRatio := common.GetModelRatio(imageRequest.Model)
+	modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
+	if !success {
+		modelRatio := common.GetModelRatio(imageRequest.Model)
+		// modelRatio 16 = modelPrice $0.04
+		// per 1 modelRatio = $0.04 / 16
+		modelPrice = 0.0025 * modelRatio
+	}
+	log.Printf("modelPrice: %f", modelPrice)
 	groupRatio := common.GetGroupRatio(group)
 	groupRatio := common.GetGroupRatio(group)
-	ratio := modelRatio * groupRatio
 	userQuota, err := model.CacheGetUserQuota(userId)
 	userQuota, err := model.CacheGetUserQuota(userId)
 
 
 	sizeRatio := 1.0
 	sizeRatio := 1.0
 	// Size
 	// Size
 	if imageRequest.Size == "256x256" {
 	if imageRequest.Size == "256x256" {
-		sizeRatio = 1
+		sizeRatio = 0.4
 	} else if imageRequest.Size == "512x512" {
 	} else if imageRequest.Size == "512x512" {
-		sizeRatio = 1.125
+		sizeRatio = 0.45
 	} else if imageRequest.Size == "1024x1024" {
 	} else if imageRequest.Size == "1024x1024" {
-		sizeRatio = 1.25
+		sizeRatio = 1
 	} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
 	} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
-		sizeRatio = 2.5
+		sizeRatio = 2
 	}
 	}
 
 
 	qualityRatio := 1.0
 	qualityRatio := 1.0
@@ -131,7 +138,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 		}
 		}
 	}
 	}
 
 
-	quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
+	quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N
 
 
 	if userQuota-quota < 0 {
 	if userQuota-quota < 0 {
 		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@@ -190,9 +197,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 			if imageRequest.Quality == "hd" {
 			if imageRequest.Quality == "hd" {
 				quality = "hd"
 				quality = "hd"
 			}
 			}
-			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelRatio, groupRatio, imageRequest.Size, quality)
+			logContent := fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelPrice, groupRatio, imageRequest.Size, quality)
 			other := make(map[string]interface{})
 			other := make(map[string]interface{})
-			other["model_ratio"] = modelRatio
+			other["model_price"] = modelPrice
 			other["group_ratio"] = groupRatio
 			other["group_ratio"] = groupRatio
 			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
 			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)

+ 4 - 4
relay/relay-mj.go

@@ -155,9 +155,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
 		return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
 		return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
 	}
 	}
 	modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
 	modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
-	modelPrice := common.GetModelPrice(modelName, true)
+	modelPrice, success := common.GetModelPrice(modelName, true)
 	// 如果没有配置价格,则使用默认价格
 	// 如果没有配置价格,则使用默认价格
-	if modelPrice == -1 {
+	if !success {
 		defaultPrice, ok := common.DefaultModelPrice[modelName]
 		defaultPrice, ok := common.DefaultModelPrice[modelName]
 		if !ok {
 		if !ok {
 			modelPrice = 0.1
 			modelPrice = 0.1
@@ -454,9 +454,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 
 
 	modelName := service.CoverActionToModelName(midjRequest.Action)
 	modelName := service.CoverActionToModelName(midjRequest.Action)
-	modelPrice := common.GetModelPrice(modelName, true)
+	modelPrice, success := common.GetModelPrice(modelName, true)
 	// 如果没有配置价格,则使用默认价格
 	// 如果没有配置价格,则使用默认价格
-	if modelPrice == -1 {
+	if !success {
 		defaultPrice, ok := common.DefaultModelPrice[modelName]
 		defaultPrice, ok := common.DefaultModelPrice[modelName]
 		if !ok {
 		if !ok {
 			modelPrice = 0.1
 			modelPrice = 0.1

+ 5 - 5
relay/relay-text.go

@@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		}
 		}
 	}
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
 	relayInfo.UpstreamModelName = textRequest.Model
-	modelPrice := common.GetModelPrice(textRequest.Model, false)
+	modelPrice, success := common.GetModelPrice(textRequest.Model, false)
 	groupRatio := common.GetGroupRatio(relayInfo.Group)
 	groupRatio := common.GetGroupRatio(relayInfo.Group)
 
 
 	var preConsumedQuota int
 	var preConsumedQuota int
@@ -108,7 +108,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 	}
 	}
 
 
-	if modelPrice == -1 {
+	if !success {
 		preConsumedTokens := common.PreConsumedQuota
 		preConsumedTokens := common.PreConsumedQuota
 		if textRequest.MaxTokens != 0 {
 		if textRequest.MaxTokens != 0 {
 			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
 			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
@@ -178,7 +178,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 		return openaiErr
 	}
 	}
-	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
+	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
 	return nil
 	return nil
 }
 }
 
 
@@ -257,7 +257,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
 
 
 func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
 func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
 	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
 	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
-	modelPrice float64) {
+	modelPrice float64, usePrice bool) {
 
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
 	promptTokens := usage.PromptTokens
@@ -267,7 +267,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 	completionRatio := common.GetCompletionRatio(textRequest.Model)
 	completionRatio := common.GetCompletionRatio(textRequest.Model)
 
 
 	quota := 0
 	quota := 0
-	if modelPrice == -1 {
+	if !usePrice {
 		quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
 		quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
 		quota = int(math.Round(float64(quota) * ratio))
 		quota = int(math.Round(float64(quota) * ratio))
 		if ratio != 0 && quota <= 0 {
 		if ratio != 0 && quota <= 0 {

+ 1 - 1
web/src/helpers/render.js

@@ -159,7 +159,7 @@ export function renderModelPrice(
         <article>
         <article>
           <p>提示 ${inputRatioPrice} / 1M tokens</p>
           <p>提示 ${inputRatioPrice} / 1M tokens</p>
           <p>补全 ${completionRatioPrice} / 1M tokens</p>
           <p>补全 ${completionRatioPrice} / 1M tokens</p>
-          <p>计算过程:</p>
+          <p> </p>
           <p>
           <p>
             提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
             提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
             {completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $
             {completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $