model_list_test.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. package controller
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/http/httptest"
  6. "os"
  7. "strings"
  8. "testing"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/constant"
  11. "github.com/QuantumNous/new-api/dto"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/QuantumNous/new-api/setting/config"
  14. "github.com/QuantumNous/new-api/setting/operation_setting"
  15. "github.com/gin-gonic/gin"
  16. "github.com/glebarez/sqlite"
  17. "github.com/stretchr/testify/require"
  18. "gorm.io/gorm"
  19. )
  20. type listModelsResponse struct {
  21. Success bool `json:"success"`
  22. Data []dto.OpenAIModels `json:"data"`
  23. Object string `json:"object"`
  24. }
  25. func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
  26. t.Helper()
  27. initModelListColumnNames(t)
  28. gin.SetMode(gin.TestMode)
  29. common.UsingSQLite = true
  30. common.UsingMySQL = false
  31. common.UsingPostgreSQL = false
  32. common.RedisEnabled = false
  33. dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
  34. db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
  35. require.NoError(t, err)
  36. model.DB = db
  37. model.LOG_DB = db
  38. require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{}))
  39. t.Cleanup(func() {
  40. sqlDB, err := db.DB()
  41. if err == nil {
  42. _ = sqlDB.Close()
  43. }
  44. })
  45. return db
  46. }
  47. func initModelListColumnNames(t *testing.T) {
  48. t.Helper()
  49. originalIsMasterNode := common.IsMasterNode
  50. originalSQLitePath := common.SQLitePath
  51. originalUsingSQLite := common.UsingSQLite
  52. originalUsingMySQL := common.UsingMySQL
  53. originalUsingPostgreSQL := common.UsingPostgreSQL
  54. originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN")
  55. defer func() {
  56. common.IsMasterNode = originalIsMasterNode
  57. common.SQLitePath = originalSQLitePath
  58. common.UsingSQLite = originalUsingSQLite
  59. common.UsingMySQL = originalUsingMySQL
  60. common.UsingPostgreSQL = originalUsingPostgreSQL
  61. if hadSQLDSN {
  62. require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN))
  63. } else {
  64. require.NoError(t, os.Unsetenv("SQL_DSN"))
  65. }
  66. }()
  67. common.IsMasterNode = false
  68. common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
  69. common.UsingSQLite = false
  70. common.UsingMySQL = false
  71. common.UsingPostgreSQL = false
  72. require.NoError(t, os.Setenv("SQL_DSN", "local"))
  73. require.NoError(t, model.InitDB())
  74. if model.DB != nil {
  75. sqlDB, err := model.DB.DB()
  76. if err == nil {
  77. _ = sqlDB.Close()
  78. }
  79. }
  80. }
  81. func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) {
  82. t.Helper()
  83. saved := map[string]string{}
  84. require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
  85. if strings.HasPrefix(key, "billing_setting.") {
  86. saved[key] = value
  87. }
  88. return nil
  89. }))
  90. t.Cleanup(func() {
  91. require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
  92. model.InvalidatePricingCache()
  93. })
  94. modeBytes, err := common.Marshal(modes)
  95. require.NoError(t, err)
  96. exprBytes, err := common.Marshal(exprs)
  97. require.NoError(t, err)
  98. require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
  99. "billing_setting.billing_mode": string(modeBytes),
  100. "billing_setting.billing_expr": string(exprBytes),
  101. }))
  102. model.InvalidatePricingCache()
  103. }
  104. func withSelfUseModeDisabled(t *testing.T) {
  105. t.Helper()
  106. original := operation_setting.SelfUseModeEnabled
  107. operation_setting.SelfUseModeEnabled = false
  108. t.Cleanup(func() {
  109. operation_setting.SelfUseModeEnabled = original
  110. })
  111. }
  112. func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} {
  113. t.Helper()
  114. require.Equal(t, http.StatusOK, recorder.Code)
  115. var payload listModelsResponse
  116. require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))
  117. require.True(t, payload.Success)
  118. require.Equal(t, "list", payload.Object)
  119. ids := make(map[string]struct{}, len(payload.Data))
  120. for _, item := range payload.Data {
  121. ids[item.Id] = struct{}{}
  122. }
  123. return ids
  124. }
  125. func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
  126. byName := make(map[string]model.Pricing, len(pricings))
  127. for _, pricing := range pricings {
  128. byName[pricing.ModelName] = pricing
  129. }
  130. return byName
  131. }
  132. func TestListModelsIncludesTieredBillingModel(t *testing.T) {
  133. withSelfUseModeDisabled(t)
  134. withTieredBillingConfig(t, map[string]string{
  135. "zz-tiered-visible-model": "tiered_expr",
  136. "zz-tiered-empty-expr-model": "tiered_expr",
  137. "zz-tiered-missing-expr-model": "tiered_expr",
  138. }, map[string]string{
  139. "zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
  140. "zz-tiered-empty-expr-model": " ",
  141. })
  142. db := setupModelListControllerTestDB(t)
  143. require.NoError(t, db.Create(&model.User{
  144. Id: 1001,
  145. Username: "model-list-user",
  146. Password: "password",
  147. Group: "default",
  148. Status: common.UserStatusEnabled,
  149. }).Error)
  150. require.NoError(t, db.Create(&[]model.Ability{
  151. {Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true},
  152. {Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true},
  153. {Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true},
  154. {Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true},
  155. }).Error)
  156. recorder := httptest.NewRecorder()
  157. ctx, _ := gin.CreateTestContext(recorder)
  158. ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
  159. ctx.Set("id", 1001)
  160. ListModels(ctx, constant.ChannelTypeOpenAI)
  161. ids := decodeListModelsResponse(t, recorder)
  162. require.Contains(t, ids, "zz-tiered-visible-model")
  163. require.NotContains(t, ids, "zz-tiered-empty-expr-model")
  164. require.NotContains(t, ids, "zz-tiered-missing-expr-model")
  165. require.NotContains(t, ids, "zz-unpriced-model")
  166. pricingByName := pricingByModelName(model.GetPricing())
  167. visiblePricing, ok := pricingByName["zz-tiered-visible-model"]
  168. require.True(t, ok)
  169. require.Equal(t, "tiered_expr", visiblePricing.BillingMode)
  170. require.NotEmpty(t, visiblePricing.BillingExpr)
  171. emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"]
  172. require.True(t, ok)
  173. require.Empty(t, emptyExprPricing.BillingMode)
  174. require.Empty(t, emptyExprPricing.BillingExpr)
  175. missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"]
  176. require.True(t, ok)
  177. require.Empty(t, missingExprPricing.BillingMode)
  178. require.Empty(t, missingExprPricing.BillingExpr)
  179. }
  180. func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
  181. withSelfUseModeDisabled(t)
  182. withTieredBillingConfig(t, map[string]string{
  183. "zz-token-tiered-visible-model": "tiered_expr",
  184. "zz-token-tiered-empty-expr-model": "tiered_expr",
  185. "zz-token-tiered-missing-expr-model": "tiered_expr",
  186. }, map[string]string{
  187. "zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
  188. "zz-token-tiered-empty-expr-model": "",
  189. })
  190. recorder := httptest.NewRecorder()
  191. ctx, _ := gin.CreateTestContext(recorder)
  192. ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
  193. common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true)
  194. common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{
  195. "zz-token-tiered-visible-model": true,
  196. "zz-token-tiered-empty-expr-model": true,
  197. "zz-token-tiered-missing-expr-model": true,
  198. "zz-token-unpriced-model": true,
  199. })
  200. ListModels(ctx, constant.ChannelTypeOpenAI)
  201. ids := decodeListModelsResponse(t, recorder)
  202. require.Contains(t, ids, "zz-token-tiered-visible-model")
  203. require.NotContains(t, ids, "zz-token-tiered-empty-expr-model")
  204. require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
  205. require.NotContains(t, ids, "zz-token-unpriced-model")
  206. }