token_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. package controller
  2. import (
  3. "bytes"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "net/http"
  8. "net/http/httptest"
  9. "os"
  10. "strconv"
  11. "strings"
  12. "testing"
  13. "github.com/QuantumNous/new-api/common"
  14. "github.com/QuantumNous/new-api/model"
  15. "github.com/gin-gonic/gin"
  16. "github.com/glebarez/sqlite"
  17. "gorm.io/driver/mysql"
  18. "gorm.io/driver/postgres"
  19. "gorm.io/gorm"
  20. )
  // Shapes decoded from the token controller handlers under test, plus the row
// shape produced by SQLite's PRAGMA table_info.
type (
	// tokenAPIResponse is the {success, message, data} envelope. Data is kept
	// raw so each test can decode it into whichever payload it expects.
	tokenAPIResponse struct {
		Success bool            `json:"success"`
		Message string          `json:"message"`
		Data    json.RawMessage `json:"data"`
	}

	// tokenPageResponse is the paginated payload carried under "items".
	tokenPageResponse struct {
		Items []tokenResponseItem `json:"items"`
	}

	// tokenResponseItem holds the subset of token fields the tests inspect.
	tokenResponseItem struct {
		ID     int    `json:"id"`
		Name   string `json:"name"`
		Key    string `json:"key"`
		Status int    `json:"status"`
	}

	// tokenKeyResponse is the payload of the dedicated key-retrieval endpoint.
	tokenKeyResponse struct {
		Key string `json:"key"`
	}

	// sqliteColumnInfo receives the name/type columns of PRAGMA table_info.
	sqliteColumnInfo struct {
		Name string `gorm:"column:name"`
		Type string `gorm:"column:type"`
	}
)
  42. type legacyToken struct {
  43. Id int `gorm:"primaryKey"`
  44. UserId int `gorm:"index"`
  45. Key string `gorm:"column:key;type:char(48);uniqueIndex"`
  46. Status int `gorm:"default:1"`
  47. Name string `gorm:"index"`
  48. CreatedTime int64 `gorm:"bigint"`
  49. AccessedTime int64 `gorm:"bigint"`
  50. ExpiredTime int64 `gorm:"bigint;default:-1"`
  51. RemainQuota int `gorm:"default:0"`
  52. UnlimitedQuota bool
  53. ModelLimitsEnabled bool
  54. ModelLimits string `gorm:"type:text"`
  55. AllowIps *string `gorm:"default:''"`
  56. UsedQuota int `gorm:"default:0"`
  57. Group string `gorm:"column:group;default:''"`
  58. CrossGroupRetry bool
  59. DeletedAt gorm.DeletedAt `gorm:"index"`
  60. }
  61. func (legacyToken) TableName() string {
  62. return "tokens"
  63. }
  64. func openTokenControllerTestDB(t *testing.T) *gorm.DB {
  65. t.Helper()
  66. gin.SetMode(gin.TestMode)
  67. common.UsingSQLite = true
  68. common.UsingMySQL = false
  69. common.UsingPostgreSQL = false
  70. common.RedisEnabled = false
  71. dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
  72. db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
  73. if err != nil {
  74. t.Fatalf("failed to open sqlite db: %v", err)
  75. }
  76. model.DB = db
  77. model.LOG_DB = db
  78. t.Cleanup(func() {
  79. sqlDB, err := db.DB()
  80. if err == nil {
  81. _ = sqlDB.Close()
  82. }
  83. })
  84. return db
  85. }
  86. func migrateTokenControllerTestDB(t *testing.T, db *gorm.DB) {
  87. t.Helper()
  88. if err := db.AutoMigrate(&model.Token{}); err != nil {
  89. t.Fatalf("failed to migrate token table: %v", err)
  90. }
  91. }
  92. func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
  93. t.Helper()
  94. db := openTokenControllerTestDB(t)
  95. migrateTokenControllerTestDB(t, db)
  96. return db
  97. }
  98. func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) (*gorm.DB, *bool) {
  99. t.Helper()
  100. gin.SetMode(gin.TestMode)
  101. common.RedisEnabled = false
  102. common.UsingSQLite = false
  103. common.UsingMySQL = dialect == "mysql"
  104. common.UsingPostgreSQL = dialect == "postgres"
  105. var (
  106. db *gorm.DB
  107. err error
  108. )
  109. switch dialect {
  110. case "mysql":
  111. db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
  112. case "postgres":
  113. db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
  114. default:
  115. t.Fatalf("unsupported dialect %q", dialect)
  116. }
  117. if err != nil {
  118. t.Fatalf("failed to open %s db: %v", dialect, err)
  119. }
  120. model.DB = db
  121. model.LOG_DB = db
  122. if db.Migrator().HasTable("tokens") {
  123. t.Skipf("refusing to run %s migration compatibility test against external database because tokens table already exists", dialect)
  124. }
  125. managedTokensTable := new(bool)
  126. t.Cleanup(func() {
  127. if *managedTokensTable && db.Migrator().HasTable("tokens") {
  128. _ = db.Migrator().DropTable("tokens")
  129. }
  130. sqlDB, err := db.DB()
  131. if err == nil {
  132. _ = sqlDB.Close()
  133. }
  134. })
  135. return db, managedTokensTable
  136. }
  137. func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
  138. t.Helper()
  139. token := &model.Token{
  140. UserId: userID,
  141. Name: name,
  142. Key: rawKey,
  143. Status: common.TokenStatusEnabled,
  144. CreatedTime: 1,
  145. AccessedTime: 1,
  146. ExpiredTime: -1,
  147. RemainQuota: 100,
  148. UnlimitedQuota: true,
  149. Group: "default",
  150. }
  151. if err := db.Create(token).Error; err != nil {
  152. t.Fatalf("failed to create token: %v", err)
  153. }
  154. return token
  155. }
  156. func newAuthenticatedContext(t *testing.T, method string, target string, body any, userID int) (*gin.Context, *httptest.ResponseRecorder) {
  157. t.Helper()
  158. var requestBody *bytes.Reader
  159. if body != nil {
  160. payload, err := common.Marshal(body)
  161. if err != nil {
  162. t.Fatalf("failed to marshal request body: %v", err)
  163. }
  164. requestBody = bytes.NewReader(payload)
  165. } else {
  166. requestBody = bytes.NewReader(nil)
  167. }
  168. recorder := httptest.NewRecorder()
  169. ctx, _ := gin.CreateTestContext(recorder)
  170. ctx.Request = httptest.NewRequest(method, target, requestBody)
  171. if body != nil {
  172. ctx.Request.Header.Set("Content-Type", "application/json")
  173. }
  174. ctx.Set("id", userID)
  175. return ctx, recorder
  176. }
  177. func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenAPIResponse {
  178. t.Helper()
  179. var response tokenAPIResponse
  180. if err := common.Unmarshal(recorder.Body.Bytes(), &response); err != nil {
  181. t.Fatalf("failed to decode api response: %v", err)
  182. }
  183. return response
  184. }
  185. func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName string) string {
  186. t.Helper()
  187. var columns []sqliteColumnInfo
  188. if err := db.Raw("PRAGMA table_info(" + tableName + ")").Scan(&columns).Error; err != nil {
  189. t.Fatalf("failed to inspect %s schema: %v", tableName, err)
  190. }
  191. for _, column := range columns {
  192. if column.Name == columnName {
  193. return strings.ToLower(column.Type)
  194. }
  195. }
  196. t.Fatalf("column %s not found in %s schema", columnName, tableName)
  197. return ""
  198. }
  199. func getTokenKeyColumnType(t *testing.T, db *gorm.DB, dialect string) string {
  200. t.Helper()
  201. switch dialect {
  202. case "sqlite":
  203. return getSQLiteColumnType(t, db, "tokens", "key")
  204. case "mysql":
  205. var columnType string
  206. if err := db.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
  207. WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
  208. "tokens", "key").Scan(&columnType).Error; err != nil {
  209. t.Fatalf("failed to inspect mysql token key column: %v", err)
  210. }
  211. return strings.ToLower(columnType)
  212. case "postgres":
  213. var dataType string
  214. var maxLength sql.NullInt64
  215. if err := db.Raw(`SELECT data_type, character_maximum_length
  216. FROM information_schema.columns
  217. WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
  218. "tokens", "key").Row().Scan(&dataType, &maxLength); err != nil {
  219. t.Fatalf("failed to inspect postgres token key column: %v", err)
  220. }
  221. switch strings.ToLower(dataType) {
  222. case "character varying":
  223. return fmt.Sprintf("varchar(%d)", maxLength.Int64)
  224. case "character":
  225. return fmt.Sprintf("char(%d)", maxLength.Int64)
  226. default:
  227. if maxLength.Valid {
  228. return fmt.Sprintf("%s(%d)", strings.ToLower(dataType), maxLength.Int64)
  229. }
  230. return strings.ToLower(dataType)
  231. }
  232. default:
  233. t.Fatalf("unsupported dialect %q", dialect)
  234. return ""
  235. }
  236. }
  237. func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string, managedTokensTable *bool) {
  238. t.Helper()
  239. legacyKey := strings.Repeat("a", 48)
  240. longKey := strings.Repeat("b", 64)
  241. if err := db.AutoMigrate(&legacyToken{}); err != nil {
  242. t.Fatalf("failed to create legacy token schema: %v", err)
  243. }
  244. if managedTokensTable != nil {
  245. *managedTokensTable = true
  246. }
  247. if err := db.Create(&legacyToken{
  248. UserId: 7,
  249. Key: legacyKey,
  250. Status: common.TokenStatusEnabled,
  251. Name: "legacy-token",
  252. CreatedTime: 1,
  253. AccessedTime: 1,
  254. ExpiredTime: -1,
  255. RemainQuota: 100,
  256. UnlimitedQuota: true,
  257. ModelLimitsEnabled: false,
  258. ModelLimits: "",
  259. AllowIps: common.GetPointer(""),
  260. UsedQuota: 0,
  261. Group: "default",
  262. CrossGroupRetry: false,
  263. }).Error; err != nil {
  264. t.Fatalf("failed to seed legacy token row: %v", err)
  265. }
  266. if got := getTokenKeyColumnType(t, db, dialect); got != "char(48)" {
  267. t.Fatalf("expected legacy key column type char(48), got %q", got)
  268. }
  269. migrateTokenControllerTestDB(t, db)
  270. if got := getTokenKeyColumnType(t, db, dialect); got != "varchar(128)" {
  271. t.Fatalf("expected migrated key column type varchar(128), got %q", got)
  272. }
  273. var migratedToken model.Token
  274. if err := db.First(&migratedToken, "name = ?", "legacy-token").Error; err != nil {
  275. t.Fatalf("failed to load migrated token row: %v", err)
  276. }
  277. if migratedToken.Key != legacyKey {
  278. t.Fatalf("expected migrated token key %q, got %q", legacyKey, migratedToken.Key)
  279. }
  280. if migratedToken.Name != "legacy-token" {
  281. t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
  282. }
  283. inserted := model.Token{
  284. UserId: 8,
  285. Name: "long-token",
  286. Key: longKey,
  287. Status: common.TokenStatusEnabled,
  288. CreatedTime: 1,
  289. AccessedTime: 1,
  290. ExpiredTime: -1,
  291. RemainQuota: 200,
  292. UnlimitedQuota: true,
  293. ModelLimitsEnabled: false,
  294. ModelLimits: "",
  295. AllowIps: common.GetPointer(""),
  296. UsedQuota: 0,
  297. Group: "default",
  298. CrossGroupRetry: false,
  299. }
  300. if err := db.Create(&inserted).Error; err != nil {
  301. t.Fatalf("failed to insert long token after migration: %v", err)
  302. }
  303. var fetched model.Token
  304. if err := db.First(&fetched, "id = ?", inserted.Id).Error; err != nil {
  305. t.Fatalf("failed to fetch long token after migration: %v", err)
  306. }
  307. if fetched.Key != longKey {
  308. t.Fatalf("expected long token key %q, got %q", longKey, fetched.Key)
  309. }
  310. }
  311. func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
  312. db := setupTokenControllerTestDB(t)
  313. if got := getTokenKeyColumnType(t, db, "sqlite"); got != "varchar(128)" {
  314. t.Fatalf("expected key column type varchar(128), got %q", got)
  315. }
  316. }
  317. func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
  318. db := openTokenControllerTestDB(t)
  319. runTokenMigrationCompatibilityTest(t, db, "sqlite", nil)
  320. }
  321. func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) {
  322. dsn := os.Getenv("TEST_MYSQL_DSN")
  323. if dsn == "" {
  324. t.Skip("set TEST_MYSQL_DSN to run mysql migration compatibility test")
  325. }
  326. db, managedTokensTable := openTokenControllerExternalDB(t, "mysql", dsn)
  327. runTokenMigrationCompatibilityTest(t, db, "mysql", managedTokensTable)
  328. }
  329. func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) {
  330. dsn := os.Getenv("TEST_POSTGRES_DSN")
  331. if dsn == "" {
  332. t.Skip("set TEST_POSTGRES_DSN to run postgres migration compatibility test")
  333. }
  334. db, managedTokensTable := openTokenControllerExternalDB(t, "postgres", dsn)
  335. runTokenMigrationCompatibilityTest(t, db, "postgres", managedTokensTable)
  336. }
  337. func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
  338. db := setupTokenControllerTestDB(t)
  339. token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
  340. seedToken(t, db, 2, "other-user-token", "zzzz1234yyyy5678")
  341. ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/?p=1&size=10", nil, 1)
  342. GetAllTokens(ctx)
  343. response := decodeAPIResponse(t, recorder)
  344. if !response.Success {
  345. t.Fatalf("expected success response, got message: %s", response.Message)
  346. }
  347. var page tokenPageResponse
  348. if err := common.Unmarshal(response.Data, &page); err != nil {
  349. t.Fatalf("failed to decode token page response: %v", err)
  350. }
  351. if len(page.Items) != 1 {
  352. t.Fatalf("expected exactly one token, got %d", len(page.Items))
  353. }
  354. if page.Items[0].Key != token.GetMaskedKey() {
  355. t.Fatalf("expected masked key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
  356. }
  357. if strings.Contains(recorder.Body.String(), token.Key) {
  358. t.Fatalf("list response leaked raw token key: %s", recorder.Body.String())
  359. }
  360. }
  361. func TestSearchTokensMasksKeyInResponse(t *testing.T) {
  362. db := setupTokenControllerTestDB(t)
  363. token := seedToken(t, db, 1, "searchable-token", "ijkl1234mnop5678")
  364. ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/search?keyword=searchable-token&p=1&size=10", nil, 1)
  365. SearchTokens(ctx)
  366. response := decodeAPIResponse(t, recorder)
  367. if !response.Success {
  368. t.Fatalf("expected success response, got message: %s", response.Message)
  369. }
  370. var page tokenPageResponse
  371. if err := common.Unmarshal(response.Data, &page); err != nil {
  372. t.Fatalf("failed to decode search response: %v", err)
  373. }
  374. if len(page.Items) != 1 {
  375. t.Fatalf("expected exactly one search result, got %d", len(page.Items))
  376. }
  377. if page.Items[0].Key != token.GetMaskedKey() {
  378. t.Fatalf("expected masked search key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
  379. }
  380. if strings.Contains(recorder.Body.String(), token.Key) {
  381. t.Fatalf("search response leaked raw token key: %s", recorder.Body.String())
  382. }
  383. }
  384. func TestGetTokenMasksKeyInResponse(t *testing.T) {
  385. db := setupTokenControllerTestDB(t)
  386. token := seedToken(t, db, 1, "detail-token", "qrst1234uvwx5678")
  387. ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/"+strconv.Itoa(token.Id), nil, 1)
  388. ctx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
  389. GetToken(ctx)
  390. response := decodeAPIResponse(t, recorder)
  391. if !response.Success {
  392. t.Fatalf("expected success response, got message: %s", response.Message)
  393. }
  394. var detail tokenResponseItem
  395. if err := common.Unmarshal(response.Data, &detail); err != nil {
  396. t.Fatalf("failed to decode token detail response: %v", err)
  397. }
  398. if detail.Key != token.GetMaskedKey() {
  399. t.Fatalf("expected masked detail key %q, got %q", token.GetMaskedKey(), detail.Key)
  400. }
  401. if strings.Contains(recorder.Body.String(), token.Key) {
  402. t.Fatalf("detail response leaked raw token key: %s", recorder.Body.String())
  403. }
  404. }
  405. func TestUpdateTokenMasksKeyInResponse(t *testing.T) {
  406. db := setupTokenControllerTestDB(t)
  407. token := seedToken(t, db, 1, "editable-token", "yzab1234cdef5678")
  408. body := map[string]any{
  409. "id": token.Id,
  410. "name": "updated-token",
  411. "expired_time": -1,
  412. "remain_quota": 100,
  413. "unlimited_quota": true,
  414. "model_limits_enabled": false,
  415. "model_limits": "",
  416. "group": "default",
  417. "cross_group_retry": false,
  418. }
  419. ctx, recorder := newAuthenticatedContext(t, http.MethodPut, "/api/token/", body, 1)
  420. UpdateToken(ctx)
  421. response := decodeAPIResponse(t, recorder)
  422. if !response.Success {
  423. t.Fatalf("expected success response, got message: %s", response.Message)
  424. }
  425. var detail tokenResponseItem
  426. if err := common.Unmarshal(response.Data, &detail); err != nil {
  427. t.Fatalf("failed to decode token update response: %v", err)
  428. }
  429. if detail.Key != token.GetMaskedKey() {
  430. t.Fatalf("expected masked update key %q, got %q", token.GetMaskedKey(), detail.Key)
  431. }
  432. if strings.Contains(recorder.Body.String(), token.Key) {
  433. t.Fatalf("update response leaked raw token key: %s", recorder.Body.String())
  434. }
  435. }
  436. func TestGetTokenKeyRequiresOwnershipAndReturnsFullKey(t *testing.T) {
  437. db := setupTokenControllerTestDB(t)
  438. token := seedToken(t, db, 1, "owned-token", "owner1234token5678")
  439. authorizedCtx, authorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 1)
  440. authorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
  441. GetTokenKey(authorizedCtx)
  442. authorizedResponse := decodeAPIResponse(t, authorizedRecorder)
  443. if !authorizedResponse.Success {
  444. t.Fatalf("expected authorized key fetch to succeed, got message: %s", authorizedResponse.Message)
  445. }
  446. var keyData tokenKeyResponse
  447. if err := common.Unmarshal(authorizedResponse.Data, &keyData); err != nil {
  448. t.Fatalf("failed to decode token key response: %v", err)
  449. }
  450. if keyData.Key != token.GetFullKey() {
  451. t.Fatalf("expected full key %q, got %q", token.GetFullKey(), keyData.Key)
  452. }
  453. unauthorizedCtx, unauthorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 2)
  454. unauthorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
  455. GetTokenKey(unauthorizedCtx)
  456. unauthorizedResponse := decodeAPIResponse(t, unauthorizedRecorder)
  457. if unauthorizedResponse.Success {
  458. t.Fatalf("expected unauthorized key fetch to fail")
  459. }
  460. if strings.Contains(unauthorizedRecorder.Body.String(), token.Key) {
  461. t.Fatalf("unauthorized key response leaked raw token key: %s", unauthorizedRecorder.Body.String())
  462. }
  463. }