Procházet zdrojové kódy

Merge pull request #4431 from yyhhyyyyyy/fix/tiered-billing-model-list

fix: include tiered billing models in model listing
yyhhyyyyyy před 2 týdny
rodič
revize
e3d64cb76d
5 změnil soubory, kde provedl 272 přidání a 19 odebrání
  1. 2 5
      controller/model.go
  2. 242 0
      controller/model_list_test.go
  3. 3 0
      model/option.go
  4. 24 1
      model/pricing.go
  5. 1 13
      relay/helper/price.go

+ 2 - 5
controller/model.go

@@ -17,7 +17,6 @@ import (
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 	"github.com/QuantumNous/new-api/setting/operation_setting"
-	"github.com/QuantumNous/new-api/setting/ratio_setting"
 	"github.com/QuantumNous/new-api/types"
 	"github.com/gin-gonic/gin"
 	"github.com/samber/lo"
@@ -134,8 +133,7 @@ func ListModels(c *gin.Context, modelType int) {
 		}
 		for allowModel, _ := range tokenModelLimit {
 			if !acceptUnsetRatioModel {
-				_, _, exist := ratio_setting.GetModelRatioOrPrice(allowModel)
-				if !exist {
+				if !model.HasModelBillingConfig(allowModel) {
 					continue
 				}
 			}
@@ -182,8 +180,7 @@ func ListModels(c *gin.Context, modelType int) {
 		}
 		for _, modelName := range models {
 			if !acceptUnsetRatioModel {
-				_, _, exist := ratio_setting.GetModelRatioOrPrice(modelName)
-				if !exist {
+				if !model.HasModelBillingConfig(modelName) {
 					continue
 				}
 			}

+ 242 - 0
controller/model_list_test.go

@@ -0,0 +1,242 @@
+package controller
+
+import (
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"strings"
+	"testing"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/model"
+	"github.com/QuantumNous/new-api/setting/config"
+	"github.com/QuantumNous/new-api/setting/operation_setting"
+	"github.com/gin-gonic/gin"
+	"github.com/glebarez/sqlite"
+	"github.com/stretchr/testify/require"
+	"gorm.io/gorm"
+)
+
+type listModelsResponse struct {
+	Success bool               `json:"success"`
+	Data    []dto.OpenAIModels `json:"data"`
+	Object  string             `json:"object"`
+}
+
+func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
+	t.Helper()
+
+	initModelListColumnNames(t)
+
+	gin.SetMode(gin.TestMode)
+	common.UsingSQLite = true
+	common.UsingMySQL = false
+	common.UsingPostgreSQL = false
+	common.RedisEnabled = false
+
+	dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
+	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
+	require.NoError(t, err)
+	model.DB = db
+	model.LOG_DB = db
+
+	require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{}))
+
+	t.Cleanup(func() {
+		sqlDB, err := db.DB()
+		if err == nil {
+			_ = sqlDB.Close()
+		}
+	})
+
+	return db
+}
+
+func initModelListColumnNames(t *testing.T) {
+	t.Helper()
+
+	originalIsMasterNode := common.IsMasterNode
+	originalSQLitePath := common.SQLitePath
+	originalUsingSQLite := common.UsingSQLite
+	originalUsingMySQL := common.UsingMySQL
+	originalUsingPostgreSQL := common.UsingPostgreSQL
+	originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN")
+	defer func() {
+		common.IsMasterNode = originalIsMasterNode
+		common.SQLitePath = originalSQLitePath
+		common.UsingSQLite = originalUsingSQLite
+		common.UsingMySQL = originalUsingMySQL
+		common.UsingPostgreSQL = originalUsingPostgreSQL
+		if hadSQLDSN {
+			require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN))
+		} else {
+			require.NoError(t, os.Unsetenv("SQL_DSN"))
+		}
+	}()
+
+	common.IsMasterNode = false
+	common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
+	common.UsingSQLite = false
+	common.UsingMySQL = false
+	common.UsingPostgreSQL = false
+	require.NoError(t, os.Setenv("SQL_DSN", "local"))
+
+	require.NoError(t, model.InitDB())
+	if model.DB != nil {
+		sqlDB, err := model.DB.DB()
+		if err == nil {
+			_ = sqlDB.Close()
+		}
+	}
+}
+
+func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) {
+	t.Helper()
+
+	saved := map[string]string{}
+	require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
+		if strings.HasPrefix(key, "billing_setting.") {
+			saved[key] = value
+		}
+		return nil
+	}))
+	t.Cleanup(func() {
+		require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
+		model.InvalidatePricingCache()
+	})
+
+	modeBytes, err := common.Marshal(modes)
+	require.NoError(t, err)
+	exprBytes, err := common.Marshal(exprs)
+	require.NoError(t, err)
+
+	require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
+		"billing_setting.billing_mode": string(modeBytes),
+		"billing_setting.billing_expr": string(exprBytes),
+	}))
+	model.InvalidatePricingCache()
+}
+
+func withSelfUseModeDisabled(t *testing.T) {
+	t.Helper()
+
+	original := operation_setting.SelfUseModeEnabled
+	operation_setting.SelfUseModeEnabled = false
+	t.Cleanup(func() {
+		operation_setting.SelfUseModeEnabled = original
+	})
+}
+
+func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} {
+	t.Helper()
+
+	require.Equal(t, http.StatusOK, recorder.Code)
+	var payload listModelsResponse
+	require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))
+	require.True(t, payload.Success)
+	require.Equal(t, "list", payload.Object)
+
+	ids := make(map[string]struct{}, len(payload.Data))
+	for _, item := range payload.Data {
+		ids[item.Id] = struct{}{}
+	}
+	return ids
+}
+
+func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
+	byName := make(map[string]model.Pricing, len(pricings))
+	for _, pricing := range pricings {
+		byName[pricing.ModelName] = pricing
+	}
+	return byName
+}
+
+func TestListModelsIncludesTieredBillingModel(t *testing.T) {
+	withSelfUseModeDisabled(t)
+	withTieredBillingConfig(t, map[string]string{
+		"zz-tiered-visible-model":      "tiered_expr",
+		"zz-tiered-empty-expr-model":   "tiered_expr",
+		"zz-tiered-missing-expr-model": "tiered_expr",
+	}, map[string]string{
+		"zz-tiered-visible-model":    `tier("base", p * 1 + c * 2)`,
+		"zz-tiered-empty-expr-model": "   ",
+	})
+
+	db := setupModelListControllerTestDB(t)
+	require.NoError(t, db.Create(&model.User{
+		Id:       1001,
+		Username: "model-list-user",
+		Password: "password",
+		Group:    "default",
+		Status:   common.UserStatusEnabled,
+	}).Error)
+	require.NoError(t, db.Create(&[]model.Ability{
+		{Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true},
+		{Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true},
+		{Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true},
+		{Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true},
+	}).Error)
+
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
+	ctx.Set("id", 1001)
+
+	ListModels(ctx, constant.ChannelTypeOpenAI)
+
+	ids := decodeListModelsResponse(t, recorder)
+	require.Contains(t, ids, "zz-tiered-visible-model")
+	require.NotContains(t, ids, "zz-tiered-empty-expr-model")
+	require.NotContains(t, ids, "zz-tiered-missing-expr-model")
+	require.NotContains(t, ids, "zz-unpriced-model")
+
+	pricingByName := pricingByModelName(model.GetPricing())
+	visiblePricing, ok := pricingByName["zz-tiered-visible-model"]
+	require.True(t, ok)
+	require.Equal(t, "tiered_expr", visiblePricing.BillingMode)
+	require.NotEmpty(t, visiblePricing.BillingExpr)
+
+	emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"]
+	require.True(t, ok)
+	require.Empty(t, emptyExprPricing.BillingMode)
+	require.Empty(t, emptyExprPricing.BillingExpr)
+
+	missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"]
+	require.True(t, ok)
+	require.Empty(t, missingExprPricing.BillingMode)
+	require.Empty(t, missingExprPricing.BillingExpr)
+}
+
+func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
+	withSelfUseModeDisabled(t)
+	withTieredBillingConfig(t, map[string]string{
+		"zz-token-tiered-visible-model":      "tiered_expr",
+		"zz-token-tiered-empty-expr-model":   "tiered_expr",
+		"zz-token-tiered-missing-expr-model": "tiered_expr",
+	}, map[string]string{
+		"zz-token-tiered-visible-model":    `tier("base", p * 1 + c * 2)`,
+		"zz-token-tiered-empty-expr-model": "",
+	})
+
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
+	common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true)
+	common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{
+		"zz-token-tiered-visible-model":      true,
+		"zz-token-tiered-empty-expr-model":   true,
+		"zz-token-tiered-missing-expr-model": true,
+		"zz-token-unpriced-model":            true,
+	})
+
+	ListModels(ctx, constant.ChannelTypeOpenAI)
+
+	ids := decodeListModelsResponse(t, recorder)
+	require.Contains(t, ids, "zz-token-tiered-visible-model")
+	require.NotContains(t, ids, "zz-token-tiered-empty-expr-model")
+	require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
+	require.NotContains(t, ids, "zz-token-unpriced-model")
+}

+ 3 - 0
model/option.go

@@ -578,6 +578,9 @@ func handleConfigUpdate(key, value string) bool {
 		performance_setting.UpdateAndSync()
 	} else if configName == "tool_price_setting" {
 		operation_setting.RebuildToolPriceIndex()
+	} else if configName == "billing_setting" {
+		InvalidatePricingCache()
+		ratio_setting.InvalidateExposedDataCache()
 	}
 
 	return true // 已处理

+ 24 - 1
model/pricing.go

@@ -77,6 +77,29 @@ func GetPricing() []Pricing {
 	return pricingMap
 }
 
+func InvalidatePricingCache() {
+	updatePricingLock.Lock()
+	defer updatePricingLock.Unlock()
+
+	pricingMap = nil
+	vendorsList = nil
+	lastGetPricingTime = time.Time{}
+}
+
+func HasModelBillingConfig(modelName string) bool {
+	if _, ok := ratio_setting.GetModelPrice(modelName, false); ok {
+		return true
+	}
+	if _, ok, _ := ratio_setting.GetModelRatio(modelName); ok {
+		return true
+	}
+	if billing_setting.GetBillingMode(modelName) != billing_setting.BillingModeTieredExpr {
+		return false
+	}
+	expr, ok := billing_setting.GetBillingExpr(modelName)
+	return ok && strings.TrimSpace(expr) != ""
+}
+
 // GetVendors 返回当前定价接口使用到的供应商信息
 func GetVendors() []PricingVendor {
 	if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
@@ -323,7 +346,7 @@ func updatePricing() {
 			pricing.AudioCompletionRatio = &audioCompletionRatio
 		}
 		if billingMode := billing_setting.GetBillingMode(model); billingMode == "tiered_expr" {
-			if expr, ok := billing_setting.GetBillingExpr(model); ok && expr != "" {
+			if expr, ok := billing_setting.GetBillingExpr(model); ok && strings.TrimSpace(expr) != "" {
 				pricing.BillingMode = billingMode
 				pricing.BillingExpr = expr
 			}

+ 1 - 13
relay/helper/price.go

@@ -224,19 +224,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) (types
 }
 
 func ContainPriceOrRatio(modelName string) bool {
-	_, ok := ratio_setting.GetModelPrice(modelName, false)
-	if ok {
-		return true
-	}
-	_, ok, _ = ratio_setting.GetModelRatio(modelName)
-	if ok {
-		return true
-	}
-	if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr {
-		_, ok = billing_setting.GetBillingExpr(modelName)
-		return ok
-	}
-	return false
+	return model.HasModelBillingConfig(modelName)
 }
 
 func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) {