Przeglądaj źródła

refactor: Improve token quota consumption logic

1808837298@qq.com 11 miesięcy temu
rodzic
commit
dd82618c05
3 zmienionych plików z 38 dodań i 27 usunięć
  1. 13 2
      relay/helper/price.go
  2. 9 8
      relay/relay-text.go
  3. 16 17
      service/quota.go

+ 13 - 2
relay/helper/price.go

@@ -20,6 +20,10 @@ type PriceData struct {
 	ShouldPreConsumedQuota int
 }
 
+func (p PriceData) ToSetting() string {
+	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota)
+}
+
 func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
 	modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
 	groupRatio := setting.GetGroupRatio(info.Group)
@@ -50,7 +54,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 	} else {
 		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
 	}
-	return PriceData{
+
+	priceData := PriceData{
 		ModelPrice:             modelPrice,
 		ModelRatio:             modelRatio,
 		CompletionRatio:        completionRatio,
@@ -59,5 +64,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 		CacheRatio:             cacheRatio,
 		CacheCreationRatio:     cacheCreationRatio,
 		ShouldPreConsumedQuota: preConsumedQuota,
-	}, nil
+	}
+
+	if common.DebugEnabled {
+		println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
+	}
+
+	return priceData, nil
 }

+ 9 - 8
relay/relay-text.go

@@ -109,7 +109,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		c.Set("prompt_tokens", promptTokens)
 	}
 
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
 	}
@@ -372,17 +372,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
 			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
 	} else {
-		quotaDelta := quota - preConsumedQuota
-		if quotaDelta != 0 {
-			err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
-			if err != nil {
-				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
-			}
-		}
 		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 
+	quotaDelta := quota - preConsumedQuota
+	if quotaDelta != 0 {
+		err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+		if err != nil {
+			common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+		}
+	}
+
 	logModel := modelName
 	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
 		logModel = "gpt-4-gizmo-*"

+ 16 - 17
service/quota.go

@@ -243,20 +243,18 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
 			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
 	} else {
-		//if sensitiveResp != nil {
-		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
-		//}
-		quotaDelta := quota - preConsumedQuota
-		if quotaDelta != 0 {
-			err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
-			if err != nil {
-				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
-			}
-		}
 		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 
+	quotaDelta := quota - preConsumedQuota
+	if quotaDelta != 0 {
+		err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+		if err != nil {
+			common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+		}
+	}
+
 	other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
 		cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
@@ -318,17 +316,18 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
 			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
 	} else {
-		quotaDelta := quota - preConsumedQuota
-		if quotaDelta != 0 {
-			err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
-			if err != nil {
-				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
-			}
-		}
 		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 
+	quotaDelta := quota - preConsumedQuota
+	if quotaDelta != 0 {
+		err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+		if err != nil {
+			common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+		}
+	}
+
 	logModel := relayInfo.OriginModelName
 	if extraContent != "" {
 		logContent += ", " + extraContent