| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
- package controller
- import (
- "fmt"
- "net/http"
- "net/http/httptest"
- "os"
- "strings"
- "testing"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/setting/config"
- "github.com/QuantumNous/new-api/setting/operation_setting"
- "github.com/gin-gonic/gin"
- "github.com/glebarez/sqlite"
- "github.com/stretchr/testify/require"
- "gorm.io/gorm"
- )
- type listModelsResponse struct {
- Success bool `json:"success"`
- Data []dto.OpenAIModels `json:"data"`
- Object string `json:"object"`
- }
- func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
- t.Helper()
- initModelListColumnNames(t)
- gin.SetMode(gin.TestMode)
- common.UsingSQLite = true
- common.UsingMySQL = false
- common.UsingPostgreSQL = false
- common.RedisEnabled = false
- dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
- db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
- require.NoError(t, err)
- model.DB = db
- model.LOG_DB = db
- require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{}))
- t.Cleanup(func() {
- sqlDB, err := db.DB()
- if err == nil {
- _ = sqlDB.Close()
- }
- })
- return db
- }
- func initModelListColumnNames(t *testing.T) {
- t.Helper()
- originalIsMasterNode := common.IsMasterNode
- originalSQLitePath := common.SQLitePath
- originalUsingSQLite := common.UsingSQLite
- originalUsingMySQL := common.UsingMySQL
- originalUsingPostgreSQL := common.UsingPostgreSQL
- originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN")
- defer func() {
- common.IsMasterNode = originalIsMasterNode
- common.SQLitePath = originalSQLitePath
- common.UsingSQLite = originalUsingSQLite
- common.UsingMySQL = originalUsingMySQL
- common.UsingPostgreSQL = originalUsingPostgreSQL
- if hadSQLDSN {
- require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN))
- } else {
- require.NoError(t, os.Unsetenv("SQL_DSN"))
- }
- }()
- common.IsMasterNode = false
- common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
- common.UsingSQLite = false
- common.UsingMySQL = false
- common.UsingPostgreSQL = false
- require.NoError(t, os.Setenv("SQL_DSN", "local"))
- require.NoError(t, model.InitDB())
- if model.DB != nil {
- sqlDB, err := model.DB.DB()
- if err == nil {
- _ = sqlDB.Close()
- }
- }
- }
- func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) {
- t.Helper()
- saved := map[string]string{}
- require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
- if strings.HasPrefix(key, "billing_setting.") {
- saved[key] = value
- }
- return nil
- }))
- t.Cleanup(func() {
- require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
- model.InvalidatePricingCache()
- })
- modeBytes, err := common.Marshal(modes)
- require.NoError(t, err)
- exprBytes, err := common.Marshal(exprs)
- require.NoError(t, err)
- require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
- "billing_setting.billing_mode": string(modeBytes),
- "billing_setting.billing_expr": string(exprBytes),
- }))
- model.InvalidatePricingCache()
- }
- func withSelfUseModeDisabled(t *testing.T) {
- t.Helper()
- original := operation_setting.SelfUseModeEnabled
- operation_setting.SelfUseModeEnabled = false
- t.Cleanup(func() {
- operation_setting.SelfUseModeEnabled = original
- })
- }
- func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} {
- t.Helper()
- require.Equal(t, http.StatusOK, recorder.Code)
- var payload listModelsResponse
- require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))
- require.True(t, payload.Success)
- require.Equal(t, "list", payload.Object)
- ids := make(map[string]struct{}, len(payload.Data))
- for _, item := range payload.Data {
- ids[item.Id] = struct{}{}
- }
- return ids
- }
- func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
- byName := make(map[string]model.Pricing, len(pricings))
- for _, pricing := range pricings {
- byName[pricing.ModelName] = pricing
- }
- return byName
- }
- func TestListModelsIncludesTieredBillingModel(t *testing.T) {
- withSelfUseModeDisabled(t)
- withTieredBillingConfig(t, map[string]string{
- "zz-tiered-visible-model": "tiered_expr",
- "zz-tiered-empty-expr-model": "tiered_expr",
- "zz-tiered-missing-expr-model": "tiered_expr",
- }, map[string]string{
- "zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
- "zz-tiered-empty-expr-model": " ",
- })
- db := setupModelListControllerTestDB(t)
- require.NoError(t, db.Create(&model.User{
- Id: 1001,
- Username: "model-list-user",
- Password: "password",
- Group: "default",
- Status: common.UserStatusEnabled,
- }).Error)
- require.NoError(t, db.Create(&[]model.Ability{
- {Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true},
- {Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true},
- {Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true},
- {Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true},
- }).Error)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
- ctx.Set("id", 1001)
- ListModels(ctx, constant.ChannelTypeOpenAI)
- ids := decodeListModelsResponse(t, recorder)
- require.Contains(t, ids, "zz-tiered-visible-model")
- require.NotContains(t, ids, "zz-tiered-empty-expr-model")
- require.NotContains(t, ids, "zz-tiered-missing-expr-model")
- require.NotContains(t, ids, "zz-unpriced-model")
- pricingByName := pricingByModelName(model.GetPricing())
- visiblePricing, ok := pricingByName["zz-tiered-visible-model"]
- require.True(t, ok)
- require.Equal(t, "tiered_expr", visiblePricing.BillingMode)
- require.NotEmpty(t, visiblePricing.BillingExpr)
- emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"]
- require.True(t, ok)
- require.Empty(t, emptyExprPricing.BillingMode)
- require.Empty(t, emptyExprPricing.BillingExpr)
- missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"]
- require.True(t, ok)
- require.Empty(t, missingExprPricing.BillingMode)
- require.Empty(t, missingExprPricing.BillingExpr)
- }
- func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
- withSelfUseModeDisabled(t)
- withTieredBillingConfig(t, map[string]string{
- "zz-token-tiered-visible-model": "tiered_expr",
- "zz-token-tiered-empty-expr-model": "tiered_expr",
- "zz-token-tiered-missing-expr-model": "tiered_expr",
- }, map[string]string{
- "zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
- "zz-token-tiered-empty-expr-model": "",
- })
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
- common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true)
- common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{
- "zz-token-tiered-visible-model": true,
- "zz-token-tiered-empty-expr-model": true,
- "zz-token-tiered-missing-expr-model": true,
- "zz-token-unpriced-model": true,
- })
- ListModels(ctx, constant.ChannelTypeOpenAI)
- ids := decodeListModelsResponse(t, recorder)
- require.Contains(t, ids, "zz-token-tiered-visible-model")
- require.NotContains(t, ids, "zz-token-tiered-empty-expr-model")
- require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
- require.NotContains(t, ids, "zz-token-unpriced-model")
- }
|