瀏覽代碼

✨ feat: enhance group ratio handling in pricing calculations

CaIon 8 月之前
父節點
當前提交
0f35d2368f
共有 8 個文件被更改,包括 74 次插入104 次删除
  1. 3 3
      controller/channel-test.go
  2. 6 3
      model/cache.go
  3. 3 3
      model/option.go
  4. 41 20
      relay/helper/price.go
  5. 1 1
      relay/relay-image.go
  6. 2 3
      relay/relay-text.go
  7. 6 37
      relay/websocket.go
  8. 12 34
      service/quota.go

+ 3 - 3
controller/channel-test.go

@@ -165,8 +165,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	consumedTime := float64(milliseconds) / 1000.0
-	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
-		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.UserGroupRatio)
+	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
+		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
 	model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
 		quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
 	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
@@ -312,7 +312,7 @@ func testAllChannels(notify bool) error {
 			channel.UpdateResponseTime(milliseconds)
 			time.Sleep(common.RequestInterval)
 		}
-		
+
 		if notify {
 			service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
 		}

+ 6 - 3
model/cache.go

@@ -3,7 +3,6 @@ package model
 import (
 	"errors"
 	"fmt"
-	"log"
 	"math/rand"
 	"one-api/common"
 	"one-api/setting"
@@ -88,14 +87,18 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string,
 			return nil, selectGroup, errors.New("auto groups is not enabled")
 		}
 		for _, autoGroup := range setting.AutoGroups {
-			log.Printf("autoGroup: %s", autoGroup)
+			if common.DebugEnabled {
+				println("autoGroup:", autoGroup)
+			}
 			channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
 			if channel == nil {
 				continue
 			} else {
 				c.Set("auto_group", autoGroup)
 				selectGroup = autoGroup
-				log.Printf("selectGroup: %s", selectGroup)
+				if common.DebugEnabled {
+					println("selectGroup:", selectGroup)
+				}
 				break
 			}
 		}

+ 3 - 3
model/option.go

@@ -194,7 +194,7 @@ func updateOptionMap(key string, value string) (err error) {
 			common.ImageDownloadPermission = intValue
 		}
 	}
-	if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
+	if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
 		boolValue := value == "true"
 		switch key {
 		case "PasswordRegisterEnabled":
@@ -263,6 +263,8 @@ func updateOptionMap(key string, value string) (err error) {
 			common.SMTPSSLEnabled = boolValue
 		case "WorkerAllowHttpImageRequestEnabled":
 			setting.WorkerAllowHttpImageRequestEnabled = boolValue
+		case "DefaultUseAutoGroup":
+			setting.DefaultUseAutoGroup = boolValue
 		}
 	}
 	switch key {
@@ -291,8 +293,6 @@ func updateOptionMap(key string, value string) (err error) {
 		err = setting.UpdateChatsByJsonString(value)
 	case "AutoGroups":
 		err = setting.UpdateAutoGroupsByJsonString(value)
-	case "DefaultUseAutoGroup":
-		setting.DefaultUseAutoGroup = value == "true"
 	case "CustomCallbackAddress":
 		setting.CustomCallbackAddress = value
 	case "EpayId":

+ 41 - 20
relay/helper/price.go

@@ -2,7 +2,6 @@ package helper
 
 import (
 	"fmt"
-	"log"
 	"one-api/common"
 	constant2 "one-api/constant"
 	relaycommon "one-api/relay/common"
@@ -12,6 +11,11 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+type GroupRatioInfo struct {
+	GroupRatio        float64
+	GroupSpecialRatio float64
+}
+
 type PriceData struct {
 	ModelPrice             float64
 	ModelRatio             float64
@@ -19,32 +23,50 @@ type PriceData struct {
 	CacheRatio             float64
 	CacheCreationRatio     float64
 	ImageRatio             float64
-	GroupRatio             float64
-	UserGroupRatio         float64
 	UsePrice               bool
 	ShouldPreConsumedQuota int
+	GroupRatioInfo         GroupRatioInfo
 }
 
 func (p PriceData) ToSetting() string {
-	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
 }
 
-func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
-	modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
-	groupRatio := setting.GetGroupRatio(info.Group)
-	var userGroupRatio float64
-	autoGroup, exists := c.Get("auto_group")
+// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
+func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
+	groupRatioInfo := GroupRatioInfo{
+		GroupRatio:        1.0, // default ratio
+		GroupSpecialRatio: 1.0, // default user group ratio
+	}
+
+	// check auto group
+	autoGroup, exists := ctx.Get("auto_group")
 	if exists {
-		groupRatio = setting.GetGroupRatio(autoGroup.(string))
-		log.Printf("final group ratio: %f", groupRatio)
-		info.Group = autoGroup.(string)
+		if common.DebugEnabled {
+			println(fmt.Sprintf("final group: %s", autoGroup))
+		}
+		relayInfo.Group = autoGroup.(string)
 	}
-	actualGroupRatio := groupRatio
-	userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group)
+
+	// check user group special ratio
+	userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
 	if ok {
-		actualGroupRatio = userGroupRatio
+		// user group special ratio
+		groupRatioInfo.GroupSpecialRatio = userGroupRatio
+		groupRatioInfo.GroupRatio = userGroupRatio
+	} else {
+		// normal group ratio
+		groupRatioInfo.GroupRatio = setting.GetGroupRatio(relayInfo.Group)
 	}
-	groupRatio = actualGroupRatio
+
+	return groupRatioInfo
+}
+
+func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
+	modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
+
+	groupRatioInfo := HandleGroupRatio(c, info)
+
 	var preConsumedQuota int
 	var modelRatio float64
 	var completionRatio float64
@@ -74,18 +96,17 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 		cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
 		cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
 		imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
-		ratio := modelRatio * groupRatio
+		ratio := modelRatio * groupRatioInfo.GroupRatio
 		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
 	} else {
-		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
 	}
 
 	priceData := PriceData{
 		ModelPrice:             modelPrice,
 		ModelRatio:             modelRatio,
 		CompletionRatio:        completionRatio,
-		GroupRatio:             groupRatio,
-		UserGroupRatio:         userGroupRatio,
+		GroupRatioInfo:         groupRatioInfo,
 		UsePrice:               usePrice,
 		CacheRatio:             cacheRatio,
 		ImageRatio:             imageRatio,

+ 1 - 1
relay/relay-image.go

@@ -162,7 +162,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 
 		// reset model price
 		priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
-		quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
+		quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
 		userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
 		if err != nil {
 			return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)

+ 2 - 3
relay/relay-text.go

@@ -361,9 +361,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	cacheRatio := priceData.CacheRatio
 	imageRatio := priceData.ImageRatio
 	modelRatio := priceData.ModelRatio
-	groupRatio := priceData.GroupRatio
+	groupRatio := priceData.GroupRatioInfo.GroupRatio
 	modelPrice := priceData.ModelPrice
-	userGroupRatio := priceData.UserGroupRatio
 
 	// Convert values to decimal for precise calculation
 	dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -511,7 +510,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	if extraContent != "" {
 		logContent += ", " + extraContent
 	}
-	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
+	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
 	if imageTokens != 0 {
 		other["image"] = true
 		other["image_ratio"] = imageRatio

+ 6 - 37
relay/websocket.go

@@ -6,12 +6,10 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gorilla/websocket"
 	"net/http"
-	"one-api/common"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
-	"one-api/setting"
-	"one-api/setting/operation_setting"
 )
 
 func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
 			//isModelMapped = true
 		}
 	}
-	//relayInfo.UpstreamModelName = textRequest.Model
-	modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
-	groupRatio := setting.GetGroupRatio(relayInfo.Group)
 
-	var preConsumedQuota int
-	var ratio float64
-	var modelRatio float64
-	//err := service.SensitiveWordsCheck(textRequest)
-
-	//if constant.ShouldCheckPromptSensitive() {
-	//	err = checkRequestSensitive(textRequest, relayInfo)
-	//	if err != nil {
-	//		return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
-	//	}
-	//}
-
-	//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
-	//// count messages token error 计算promptTokens错误
-	//if err != nil {
-	//	return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
-	//}
-	//
-	if !getModelPriceSuccess {
-		preConsumedTokens := common.PreConsumedQuota
-		//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
-		//	preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
-		//}
-		modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
-		ratio = modelRatio * groupRatio
-		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
-	} else {
-		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
-		relayInfo.UsePrice = true
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
 	}
 
 	// pre-consume quota 预消耗配额
-	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
 	if openaiErr != nil {
 		return openaiErr
 	}
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
 		return openaiErr
 	}
 	service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
-		userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+		userQuota, priceData, "")
 	return nil
 }

+ 12 - 34
service/quota.go

@@ -144,8 +144,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
 }
 
 func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
-	usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
-	modelPrice float64, usePrice bool, extraContent string) {
+	usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	textInputTokens := usage.InputTokenDetails.TextTokens
@@ -159,18 +158,11 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
 	audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
 
-	autoGroup, exists := ctx.Get("auto_group")
-	if exists {
-		groupRatio = setting.GetGroupRatio(autoGroup.(string))
-		log.Printf("final group ratio: %f", groupRatio)
-		relayInfo.Group = autoGroup.(string)
-	}
+	modelRatio := priceData.ModelRatio
+	groupRatio := priceData.GroupRatioInfo.GroupRatio
+	modelPrice := priceData.ModelPrice
+	usePrice := priceData.UsePrice
 
-	actualGroupRatio := groupRatio
-	userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
-	if ok {
-		actualGroupRatio = userGroupRatio
-	}
 	quotaInfo := QuotaInfo{
 		InputDetails: TokenDetails{
 			TextTokens:  textInputTokens,
@@ -183,7 +175,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 		ModelName:  modelName,
 		UsePrice:   usePrice,
 		ModelRatio: modelRatio,
-		GroupRatio: actualGroupRatio,
+		GroupRatio: groupRatio,
 	}
 
 	quota := calculateAudioQuota(quotaInfo)
@@ -215,7 +207,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 		logContent += ", " + extraContent
 	}
 	other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
-		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
+		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
 		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
 }
@@ -231,9 +223,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	tokenName := ctx.GetString("token_name")
 	completionRatio := priceData.CompletionRatio
 	modelRatio := priceData.ModelRatio
-	groupRatio := priceData.GroupRatio
+	groupRatio := priceData.GroupRatioInfo.GroupRatio
 	modelPrice := priceData.ModelPrice
-	userGroupRatio := priceData.UserGroupRatio
 	cacheRatio := priceData.CacheRatio
 	cacheTokens := usage.PromptTokensDetails.CachedTokens
 
@@ -282,7 +273,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	}
 
 	other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
-		cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, userGroupRatio)
+		cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
 		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
 }
@@ -303,23 +294,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
 
 	modelRatio := priceData.ModelRatio
-	groupRatio := priceData.GroupRatio
+	groupRatio := priceData.GroupRatioInfo.GroupRatio
 	modelPrice := priceData.ModelPrice
 	usePrice := priceData.UsePrice
 
-	autoGroup, exists := ctx.Get("auto_group")
-	if exists {
-		groupRatio = setting.GetGroupRatio(autoGroup.(string))
-		log.Printf("final group ratio: %f", groupRatio)
-		relayInfo.Group = autoGroup.(string)
-	}
-
-	actualGroupRatio := groupRatio
-	userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
-	if ok {
-		actualGroupRatio = userGroupRatio
-	}
-
 	quotaInfo := QuotaInfo{
 		InputDetails: TokenDetails{
 			TextTokens:  textInputTokens,
@@ -332,7 +310,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		ModelName:  relayInfo.OriginModelName,
 		UsePrice:   usePrice,
 		ModelRatio: modelRatio,
-		GroupRatio: actualGroupRatio,
+		GroupRatio: groupRatio,
 	}
 
 	quota := calculateAudioQuota(quotaInfo)
@@ -372,7 +350,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		logContent += ", " + extraContent
 	}
 	other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
-		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
+		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
 		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
 }