Просмотр исходного кода

fix(billing): preserve text tool surcharges in tiered settlement

yyhhyyyyyy 1 месяц назад
Родитель
Сommit
1fe9f6f989
3 измененных файлов с 174 добавлено и 55 удалено
  1. 2 0
      mise.toml
  2. 88 55
      service/text_quota.go
  3. 84 0
      service/text_quota_test.go

+ 2 - 0
mise.toml

@@ -0,0 +1,2 @@
+[tools]
+bun = "latest"

+ 88 - 55
service/text_quota.go

@@ -52,6 +52,7 @@ type textQuotaSummary struct {
 	FileSearchCallCount      int
 	FileSearchCallCount      int
 	AudioInputPrice          float64
 	AudioInputPrice          float64
 	ImageGenerationCallPrice float64
 	ImageGenerationCallPrice float64
+	ToolCallSurchargeQuota   decimal.Decimal
 }
 }
 
 
 func cacheWriteTokensTotal(summary textQuotaSummary) int {
 func cacheWriteTokensTotal(summary textQuotaSummary) int {
@@ -78,6 +79,89 @@ func isLegacyClaudeDerivedOpenAIUsage(relayInfo *relaycommon.RelayInfo, usage *d
 	return usage.ClaudeCacheCreation5mTokens > 0 || usage.ClaudeCacheCreation1hTokens > 0
 	return usage.ClaudeCacheCreation5mTokens > 0 || usage.ClaudeCacheCreation1hTokens > 0
 }
 }
 
 
+func calculateTextToolCallSurcharge(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, summary *textQuotaSummary) decimal.Decimal {
+	dGroupRatio := decimal.NewFromFloat(summary.GroupRatio)
+	dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+
+	var surcharge decimal.Decimal
+
+	if relayInfo.ResponsesUsageInfo != nil {
+		if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
+			summary.WebSearchCallCount = webSearchTool.CallCount
+			summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName)
+			surcharge = surcharge.Add(decimal.NewFromFloat(summary.WebSearchPrice).
+				Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
+				Div(decimal.NewFromInt(1000)).
+				Mul(dGroupRatio).
+				Mul(dQuotaPerUnit))
+		}
+	} else if strings.HasSuffix(summary.ModelName, "search-preview") {
+		summary.WebSearchCallCount = 1
+		summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName)
+		surcharge = surcharge.Add(decimal.NewFromFloat(summary.WebSearchPrice).
+			Div(decimal.NewFromInt(1000)).
+			Mul(dGroupRatio).
+			Mul(dQuotaPerUnit))
+	}
+
+	summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests")
+	if summary.ClaudeWebSearchCallCount > 0 {
+		summary.ClaudeWebSearchPrice = operation_setting.GetToolPrice("web_search")
+		surcharge = surcharge.Add(decimal.NewFromFloat(summary.ClaudeWebSearchPrice).
+			Div(decimal.NewFromInt(1000)).
+			Mul(dGroupRatio).
+			Mul(dQuotaPerUnit).
+			Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount))))
+	}
+
+	if relayInfo.ResponsesUsageInfo != nil {
+		if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
+			summary.FileSearchCallCount = fileSearchTool.CallCount
+			summary.FileSearchPrice = operation_setting.GetToolPrice("file_search")
+			surcharge = surcharge.Add(decimal.NewFromFloat(summary.FileSearchPrice).
+				Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
+				Div(decimal.NewFromInt(1000)).
+				Mul(dGroupRatio).
+				Mul(dQuotaPerUnit))
+		}
+	}
+
+	if ctx.GetBool("image_generation_call") {
+		summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
+		surcharge = surcharge.Add(decimal.NewFromFloat(summary.ImageGenerationCallPrice).
+			Mul(dGroupRatio).
+			Mul(dQuotaPerUnit))
+	}
+
+	return surcharge
+}
+
+func composeTieredTextQuota(relayInfo *relaycommon.RelayInfo, summary textQuotaSummary, tieredQuota int, tieredResult *billingexpr.TieredResult) int {
+	if summary.ToolCallSurchargeQuota.IsZero() {
+		return tieredQuota
+	}
+
+	if tieredResult != nil {
+		if snap := relayInfo.TieredBillingSnapshot; snap != nil {
+			return int(decimal.NewFromFloat(tieredResult.ActualQuotaBeforeGroup).
+				Mul(decimal.NewFromFloat(snap.GroupRatio)).
+				Add(summary.ToolCallSurchargeQuota).
+				Round(0).
+				IntPart())
+		}
+	}
+
+	if snap := relayInfo.TieredBillingSnapshot; snap != nil {
+		return int(decimal.NewFromFloat(snap.EstimatedQuotaBeforeGroup).
+			Mul(decimal.NewFromFloat(snap.GroupRatio)).
+			Add(summary.ToolCallSurchargeQuota).
+			Round(0).
+			IntPart())
+	}
+
+	return tieredQuota + int(summary.ToolCallSurchargeQuota.Round(0).IntPart())
+}
+
 func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) textQuotaSummary {
 func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) textQuotaSummary {
 	summary := textQuotaSummary{
 	summary := textQuotaSummary{
 		ModelName:            relayInfo.OriginModelName,
 		ModelName:            relayInfo.OriginModelName,
@@ -148,52 +232,7 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf
 	dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
 	dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
 
 
 	ratio := dModelRatio.Mul(dGroupRatio)
 	ratio := dModelRatio.Mul(dGroupRatio)
-
-	var dWebSearchQuota decimal.Decimal
-	if relayInfo.ResponsesUsageInfo != nil {
-		if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
-			summary.WebSearchCallCount = webSearchTool.CallCount
-			summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName)
-			dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice).
-				Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
-				Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
-		}
-	} else if strings.HasSuffix(summary.ModelName, "search-preview") {
-		searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
-		if searchContextSize == "" {
-			searchContextSize = "medium"
-		}
-		summary.WebSearchCallCount = 1
-		summary.WebSearchPrice = operation_setting.GetToolPriceForModel("web_search_preview", summary.ModelName)
-		dWebSearchQuota = decimal.NewFromFloat(summary.WebSearchPrice).
-			Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
-	}
-
-	var dClaudeWebSearchQuota decimal.Decimal
-	summary.ClaudeWebSearchCallCount = ctx.GetInt("claude_web_search_requests")
-	if summary.ClaudeWebSearchCallCount > 0 {
-		summary.ClaudeWebSearchPrice = operation_setting.GetToolPrice("web_search")
-		dClaudeWebSearchQuota = decimal.NewFromFloat(summary.ClaudeWebSearchPrice).
-			Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).
-			Mul(decimal.NewFromInt(int64(summary.ClaudeWebSearchCallCount)))
-	}
-
-	var dFileSearchQuota decimal.Decimal
-	if relayInfo.ResponsesUsageInfo != nil {
-		if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
-			summary.FileSearchCallCount = fileSearchTool.CallCount
-			summary.FileSearchPrice = operation_setting.GetToolPrice("file_search")
-			dFileSearchQuota = decimal.NewFromFloat(summary.FileSearchPrice).
-				Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
-				Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
-		}
-	}
-
-	var dImageGenerationCallQuota decimal.Decimal
-	if ctx.GetBool("image_generation_call") {
-		summary.ImageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
-		dImageGenerationCallQuota = decimal.NewFromFloat(summary.ImageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
-	}
+	summary.ToolCallSurchargeQuota = calculateTextToolCallSurcharge(ctx, relayInfo, &summary)
 
 
 	var audioInputQuota decimal.Decimal
 	var audioInputQuota decimal.Decimal
 	if !relayInfo.PriceData.UsePrice {
 	if !relayInfo.PriceData.UsePrice {
@@ -242,11 +281,8 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf
 		promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio).Add(cachedCreationTokensWithRatio)
 		promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio).Add(cachedCreationTokensWithRatio)
 		completionQuota := dCompletionTokens.Mul(dCompletionRatio)
 		completionQuota := dCompletionTokens.Mul(dCompletionRatio)
 		quotaCalculateDecimal := promptQuota.Add(completionQuota).Mul(ratio)
 		quotaCalculateDecimal := promptQuota.Add(completionQuota).Mul(ratio)
-		quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
-		quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota)
-		quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
+		quotaCalculateDecimal = quotaCalculateDecimal.Add(summary.ToolCallSurchargeQuota)
 		quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
 		quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
-		quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
 
 
 		if len(relayInfo.PriceData.OtherRatios) > 0 {
 		if len(relayInfo.PriceData.OtherRatios) > 0 {
 			for _, otherRatio := range relayInfo.PriceData.OtherRatios {
 			for _, otherRatio := range relayInfo.PriceData.OtherRatios {
@@ -260,11 +296,8 @@ func calculateTextQuotaSummary(ctx *gin.Context, relayInfo *relaycommon.RelayInf
 		summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart())
 		summary.Quota = int(quotaCalculateDecimal.Round(0).IntPart())
 	} else {
 	} else {
 		quotaCalculateDecimal := dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
 		quotaCalculateDecimal := dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
-		quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
-		quotaCalculateDecimal = quotaCalculateDecimal.Add(dClaudeWebSearchQuota)
-		quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
+		quotaCalculateDecimal = quotaCalculateDecimal.Add(summary.ToolCallSurchargeQuota)
 		quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
 		quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
-		quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
 		if len(relayInfo.PriceData.OtherRatios) > 0 {
 		if len(relayInfo.PriceData.OtherRatios) > 0 {
 			for _, otherRatio := range relayInfo.PriceData.OtherRatios {
 			for _, otherRatio := range relayInfo.PriceData.OtherRatios {
 				quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio))
 				quotaCalculateDecimal = quotaCalculateDecimal.Mul(decimal.NewFromFloat(otherRatio))
@@ -313,7 +346,7 @@ func PostTextConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, us
 		tieredOk, tieredQuota, tieredRes := TryTieredSettle(relayInfo, BuildTieredTokenParams(usage, summary.IsClaudeUsageSemantic, tieredUsedVars))
 		tieredOk, tieredQuota, tieredRes := TryTieredSettle(relayInfo, BuildTieredTokenParams(usage, summary.IsClaudeUsageSemantic, tieredUsedVars))
 		if tieredOk {
 		if tieredOk {
 			tieredResult = tieredRes
 			tieredResult = tieredRes
-			summary.Quota = tieredQuota
+			summary.Quota = composeTieredTextQuota(relayInfo, summary, tieredQuota, tieredRes)
 		}
 		}
 	}
 	}
 
 

+ 84 - 0
service/text_quota_test.go

@@ -7,6 +7,7 @@ import (
 
 
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/pkg/billingexpr"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/types"
 	"github.com/QuantumNous/new-api/types"
 
 
@@ -316,3 +317,86 @@ func TestCalculateTextQuotaSummaryKeepsPrePRClaudeOpenRouterBilling(t *testing.T
 	require.Equal(t, 172, summary.PromptTokens)
 	require.Equal(t, 172, summary.PromptTokens)
 	require.Equal(t, 798, summary.Quota)
 	require.Equal(t, 798, summary.Quota)
 }
 }
+
+func TestComposeTieredTextQuotaKeepsToolCallSurcharges(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	w := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(w)
+	ctx.Set("image_generation_call", true)
+	ctx.Set("image_generation_call_quality", "low")
+	ctx.Set("image_generation_call_size", "1024x1024")
+
+	relayInfo := &relaycommon.RelayInfo{
+		OriginModelName: "o1",
+		PriceData: types.PriceData{
+			ModelRatio:      1,
+			CompletionRatio: 1,
+			GroupRatioInfo:  types.GroupRatioInfo{GroupRatio: 1},
+		},
+		ResponsesUsageInfo: &relaycommon.ResponsesUsageInfo{
+			BuiltInTools: map[string]*relaycommon.BuildInToolInfo{
+				dto.BuildInToolWebSearchPreview: &relaycommon.BuildInToolInfo{
+					CallCount: 1,
+				},
+				dto.BuildInToolFileSearch: &relaycommon.BuildInToolInfo{
+					CallCount: 2,
+				},
+			},
+		},
+		TieredBillingSnapshot: &billingexpr.BillingSnapshot{
+			BillingMode:               "tiered_expr",
+			GroupRatio:                1,
+			EstimatedQuotaBeforeGroup: 1000,
+		},
+		StartTime: time.Now(),
+	}
+
+	usage := &dto.Usage{
+		PromptTokens:     100,
+		CompletionTokens: 50,
+		TotalTokens:      150,
+	}
+
+	summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
+	quota := composeTieredTextQuota(relayInfo, summary, 1000, &billingexpr.TieredResult{
+		ActualQuotaBeforeGroup: 1000,
+		ActualQuotaAfterGroup:  1000,
+	})
+
+	require.Equal(t, int64(13000), summary.ToolCallSurchargeQuota.Round(0).IntPart())
+	require.Equal(t, 14000, quota)
+}
+
+func TestComposeTieredTextQuotaFallbackKeepsToolCallSurcharges(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	w := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(w)
+	ctx.Set("claude_web_search_requests", 2)
+
+	relayInfo := &relaycommon.RelayInfo{
+		OriginModelName: "claude-3-7-sonnet",
+		PriceData: types.PriceData{
+			ModelRatio:      1,
+			CompletionRatio: 1,
+			GroupRatioInfo:  types.GroupRatioInfo{GroupRatio: 1.25},
+		},
+		TieredBillingSnapshot: &billingexpr.BillingSnapshot{
+			BillingMode:               "tiered_expr",
+			GroupRatio:                1.25,
+			EstimatedQuotaBeforeGroup: 1000,
+		},
+		StartTime: time.Now(),
+	}
+
+	usage := &dto.Usage{
+		PromptTokens:     100,
+		CompletionTokens: 50,
+		TotalTokens:      150,
+	}
+
+	summary := calculateTextQuotaSummary(ctx, relayInfo, usage)
+	quota := composeTieredTextQuota(relayInfo, summary, 1250, nil)
+
+	require.Equal(t, int64(12500), summary.ToolCallSurchargeQuota.Round(0).IntPart())
+	require.Equal(t, 13750, quota)
+}