token_test.go 12 KB

(lines 1–392)
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "net/http/httptest"
  8. "strconv"
  9. "strings"
  10. "testing"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/gin-gonic/gin"
  14. "github.com/glebarez/sqlite"
  15. "gorm.io/gorm"
  16. )
  // Wire shapes used by these tests to decode controller JSON responses, plus
  // the row shape returned by SQLite's PRAGMA table_info.
  type (
  	// tokenAPIResponse is the {success, message, data} envelope. Data is kept
  	// raw so each test can decode it into whatever shape it expects.
  	tokenAPIResponse struct {
  		Success bool            `json:"success"`
  		Message string          `json:"message"`
  		Data    json.RawMessage `json:"data"`
  	}

  	// tokenPageResponse is the payload of the list/search endpoints; only the
  	// returned items are inspected.
  	tokenPageResponse struct {
  		Items []tokenResponseItem `json:"items"`
  	}

  	// tokenResponseItem is the subset of a serialized token the tests assert on.
  	tokenResponseItem struct {
  		ID     int    `json:"id"`
  		Name   string `json:"name"`
  		Key    string `json:"key"`
  		Status int    `json:"status"`
  	}

  	// tokenKeyResponse is the payload of the key-reveal endpoint.
  	tokenKeyResponse struct {
  		Key string `json:"key"`
  	}

  	// sqliteColumnInfo maps the name and type columns of one PRAGMA
  	// table_info row; the remaining pragma columns are ignored.
  	sqliteColumnInfo struct {
  		Name string `gorm:"column:name"`
  		Type string `gorm:"column:type"`
  	}
  )
  // legacyToken reproduces the tokens table as it looked before the key column
  // was widened: Key is declared as char(48) here, while migrating with the
  // current model.Token is expected to turn it into varchar(128). It is used
  // only to create and seed that pre-migration schema.
  //
  // NOTE(review): the tags look copied from the historical model.Token
  // (including the bare `bigint` setting, which is not a `type:bigint`
  // declaration). Keep them faithful to that history rather than "fixing"
  // them — confirm against the old model before editing.
  type legacyToken struct {
  	Id     int    `gorm:"primaryKey"`
  	UserId int    `gorm:"index"`
  	// char(48) is the legacy key width whose migration is under test.
  	Key                string `gorm:"column:key;type:char(48);uniqueIndex"`
  	Status             int    `gorm:"default:1"`
  	Name               string `gorm:"index"`
  	CreatedTime        int64  `gorm:"bigint"`
  	AccessedTime       int64  `gorm:"bigint"`
  	ExpiredTime        int64  `gorm:"bigint;default:-1"`
  	RemainQuota        int    `gorm:"default:0"`
  	UnlimitedQuota     bool
  	ModelLimitsEnabled bool
  	ModelLimits        string  `gorm:"type:text"`
  	AllowIps           *string `gorm:"default:''"`
  	UsedQuota          int     `gorm:"default:0"`
  	// "group" is an SQL keyword, so the column name is set explicitly.
  	Group           string `gorm:"column:group;default:''"`
  	CrossGroupRetry bool
  	DeletedAt       gorm.DeletedAt `gorm:"index"`
  }

  // TableName points legacyToken at the same "tokens" table that
  // AutoMigrate(&model.Token{}) later inspects and migrates, so the seeded
  // legacy row is what the migration operates on.
  func (legacyToken) TableName() string {
  	return "tokens"
  }
  60. func openTokenControllerTestDB(t *testing.T) *gorm.DB {
  61. t.Helper()
  62. gin.SetMode(gin.TestMode)
  63. common.UsingSQLite = true
  64. common.UsingMySQL = false
  65. common.UsingPostgreSQL = false
  66. common.RedisEnabled = false
  67. dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
  68. db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
  69. if err != nil {
  70. t.Fatalf("failed to open sqlite db: %v", err)
  71. }
  72. model.DB = db
  73. model.LOG_DB = db
  74. t.Cleanup(func() {
  75. sqlDB, err := db.DB()
  76. if err == nil {
  77. _ = sqlDB.Close()
  78. }
  79. })
  80. return db
  81. }
  82. func migrateTokenControllerTestDB(t *testing.T, db *gorm.DB) {
  83. t.Helper()
  84. if err := db.AutoMigrate(&model.Token{}); err != nil {
  85. t.Fatalf("failed to migrate token table: %v", err)
  86. }
  87. }
  88. func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
  89. t.Helper()
  90. db := openTokenControllerTestDB(t)
  91. migrateTokenControllerTestDB(t, db)
  92. return db
  93. }
  94. func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
  95. t.Helper()
  96. token := &model.Token{
  97. UserId: userID,
  98. Name: name,
  99. Key: rawKey,
  100. Status: common.TokenStatusEnabled,
  101. CreatedTime: 1,
  102. AccessedTime: 1,
  103. ExpiredTime: -1,
  104. RemainQuota: 100,
  105. UnlimitedQuota: true,
  106. Group: "default",
  107. }
  108. if err := db.Create(token).Error; err != nil {
  109. t.Fatalf("failed to create token: %v", err)
  110. }
  111. return token
  112. }
  113. func newAuthenticatedContext(t *testing.T, method string, target string, body any, userID int) (*gin.Context, *httptest.ResponseRecorder) {
  114. t.Helper()
  115. var requestBody *bytes.Reader
  116. if body != nil {
  117. payload, err := common.Marshal(body)
  118. if err != nil {
  119. t.Fatalf("failed to marshal request body: %v", err)
  120. }
  121. requestBody = bytes.NewReader(payload)
  122. } else {
  123. requestBody = bytes.NewReader(nil)
  124. }
  125. recorder := httptest.NewRecorder()
  126. ctx, _ := gin.CreateTestContext(recorder)
  127. ctx.Request = httptest.NewRequest(method, target, requestBody)
  128. if body != nil {
  129. ctx.Request.Header.Set("Content-Type", "application/json")
  130. }
  131. ctx.Set("id", userID)
  132. return ctx, recorder
  133. }
  134. func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenAPIResponse {
  135. t.Helper()
  136. var response tokenAPIResponse
  137. if err := common.Unmarshal(recorder.Body.Bytes(), &response); err != nil {
  138. t.Fatalf("failed to decode api response: %v", err)
  139. }
  140. return response
  141. }
  142. func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName string) string {
  143. t.Helper()
  144. var columns []sqliteColumnInfo
  145. if err := db.Raw("PRAGMA table_info(" + tableName + ")").Scan(&columns).Error; err != nil {
  146. t.Fatalf("failed to inspect %s schema: %v", tableName, err)
  147. }
  148. for _, column := range columns {
  149. if column.Name == columnName {
  150. return strings.ToLower(column.Type)
  151. }
  152. }
  153. t.Fatalf("column %s not found in %s schema", columnName, tableName)
  154. return ""
  155. }
  156. func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
  157. db := setupTokenControllerTestDB(t)
  158. if got := getSQLiteColumnType(t, db, "tokens", "key"); got != "varchar(128)" {
  159. t.Fatalf("expected key column type varchar(128), got %q", got)
  160. }
  161. }
  162. func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
  163. db := openTokenControllerTestDB(t)
  164. legacyKey := strings.Repeat("a", 48)
  165. if err := db.AutoMigrate(&legacyToken{}); err != nil {
  166. t.Fatalf("failed to create legacy token schema: %v", err)
  167. }
  168. if err := db.Create(&legacyToken{
  169. Id: 1,
  170. UserId: 7,
  171. Key: legacyKey,
  172. Status: common.TokenStatusEnabled,
  173. Name: "legacy-token",
  174. CreatedTime: 1,
  175. AccessedTime: 1,
  176. ExpiredTime: -1,
  177. RemainQuota: 100,
  178. UnlimitedQuota: true,
  179. ModelLimitsEnabled: false,
  180. ModelLimits: "",
  181. AllowIps: common.GetPointer(""),
  182. UsedQuota: 0,
  183. Group: "default",
  184. CrossGroupRetry: false,
  185. }).Error; err != nil {
  186. t.Fatalf("failed to seed legacy token row: %v", err)
  187. }
  188. if got := getSQLiteColumnType(t, db, "tokens", "key"); got != "char(48)" {
  189. t.Fatalf("expected legacy key column type char(48), got %q", got)
  190. }
  191. migrateTokenControllerTestDB(t, db)
  192. if got := getSQLiteColumnType(t, db, "tokens", "key"); got != "varchar(128)" {
  193. t.Fatalf("expected migrated key column type varchar(128), got %q", got)
  194. }
  195. var migratedToken model.Token
  196. if err := db.First(&migratedToken, "id = ?", 1).Error; err != nil {
  197. t.Fatalf("failed to load migrated token row: %v", err)
  198. }
  199. if migratedToken.Key != legacyKey {
  200. t.Fatalf("expected migrated token key %q, got %q", legacyKey, migratedToken.Key)
  201. }
  202. if migratedToken.Name != "legacy-token" {
  203. t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
  204. }
  205. }
  206. func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
  207. db := setupTokenControllerTestDB(t)
  208. token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
  209. seedToken(t, db, 2, "other-user-token", "zzzz1234yyyy5678")
  210. ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/?p=1&size=10", nil, 1)
  211. GetAllTokens(ctx)
  212. response := decodeAPIResponse(t, recorder)
  213. if !response.Success {
  214. t.Fatalf("expected success response, got message: %s", response.Message)
  215. }
  216. var page tokenPageResponse
  217. if err := common.Unmarshal(response.Data, &page); err != nil {
  218. t.Fatalf("failed to decode token page response: %v", err)
  219. }
  220. if len(page.Items) != 1 {
  221. t.Fatalf("expected exactly one token, got %d", len(page.Items))
  222. }
  223. if page.Items[0].Key != token.GetMaskedKey() {
  224. t.Fatalf("expected masked key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
  225. }
  226. if strings.Contains(recorder.Body.String(), token.Key) {
  227. t.Fatalf("list response leaked raw token key: %s", recorder.Body.String())
  228. }
  229. }
  230. func TestSearchTokensMasksKeyInResponse(t *testing.T) {
  231. db := setupTokenControllerTestDB(t)
  232. token := seedToken(t, db, 1, "searchable-token", "ijkl1234mnop5678")
  233. ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/search?keyword=searchable-token&p=1&size=10", nil, 1)
  234. SearchTokens(ctx)
  235. response := decodeAPIResponse(t, recorder)
  236. if !response.Success {
  237. t.Fatalf("expected success response, got message: %s", response.Message)
  238. }
  239. var page tokenPageResponse
  240. if err := common.Unmarshal(response.Data, &page); err != nil {
  241. t.Fatalf("failed to decode search response: %v", err)
  242. }
  243. if len(page.Items) != 1 {
  244. t.Fatalf("expected exactly one search result, got %d", len(page.Items))
  245. }
  246. if page.Items[0].Key != token.GetMaskedKey() {
  247. t.Fatalf("expected masked search key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
  248. }
  249. if strings.Contains(recorder.Body.String(), token.Key) {
  250. t.Fatalf("search response leaked raw token key: %s", recorder.Body.String())
  251. }
  252. }
  253. func TestGetTokenMasksKeyInResponse(t *testing.T) {
  254. db := setupTokenControllerTestDB(t)
  255. token := seedToken(t, db, 1, "detail-token", "qrst1234uvwx5678")
  256. ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/"+strconv.Itoa(token.Id), nil, 1)
  257. ctx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
  258. GetToken(ctx)
  259. response := decodeAPIResponse(t, recorder)
  260. if !response.Success {
  261. t.Fatalf("expected success response, got message: %s", response.Message)
  262. }
  263. var detail tokenResponseItem
  264. if err := common.Unmarshal(response.Data, &detail); err != nil {
  265. t.Fatalf("failed to decode token detail response: %v", err)
  266. }
  267. if detail.Key != token.GetMaskedKey() {
  268. t.Fatalf("expected masked detail key %q, got %q", token.GetMaskedKey(), detail.Key)
  269. }
  270. if strings.Contains(recorder.Body.String(), token.Key) {
  271. t.Fatalf("detail response leaked raw token key: %s", recorder.Body.String())
  272. }
  273. }
  274. func TestUpdateTokenMasksKeyInResponse(t *testing.T) {
  275. db := setupTokenControllerTestDB(t)
  276. token := seedToken(t, db, 1, "editable-token", "yzab1234cdef5678")
  277. body := map[string]any{
  278. "id": token.Id,
  279. "name": "updated-token",
  280. "expired_time": -1,
  281. "remain_quota": 100,
  282. "unlimited_quota": true,
  283. "model_limits_enabled": false,
  284. "model_limits": "",
  285. "group": "default",
  286. "cross_group_retry": false,
  287. }
  288. ctx, recorder := newAuthenticatedContext(t, http.MethodPut, "/api/token/", body, 1)
  289. UpdateToken(ctx)
  290. response := decodeAPIResponse(t, recorder)
  291. if !response.Success {
  292. t.Fatalf("expected success response, got message: %s", response.Message)
  293. }
  294. var detail tokenResponseItem
  295. if err := common.Unmarshal(response.Data, &detail); err != nil {
  296. t.Fatalf("failed to decode token update response: %v", err)
  297. }
  298. if detail.Key != token.GetMaskedKey() {
  299. t.Fatalf("expected masked update key %q, got %q", token.GetMaskedKey(), detail.Key)
  300. }
  301. if strings.Contains(recorder.Body.String(), token.Key) {
  302. t.Fatalf("update response leaked raw token key: %s", recorder.Body.String())
  303. }
  304. }
  305. func TestGetTokenKeyRequiresOwnershipAndReturnsFullKey(t *testing.T) {
  306. db := setupTokenControllerTestDB(t)
  307. token := seedToken(t, db, 1, "owned-token", "owner1234token5678")
  308. authorizedCtx, authorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 1)
  309. authorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
  310. GetTokenKey(authorizedCtx)
  311. authorizedResponse := decodeAPIResponse(t, authorizedRecorder)
  312. if !authorizedResponse.Success {
  313. t.Fatalf("expected authorized key fetch to succeed, got message: %s", authorizedResponse.Message)
  314. }
  315. var keyData tokenKeyResponse
  316. if err := common.Unmarshal(authorizedResponse.Data, &keyData); err != nil {
  317. t.Fatalf("failed to decode token key response: %v", err)
  318. }
  319. if keyData.Key != token.GetFullKey() {
  320. t.Fatalf("expected full key %q, got %q", token.GetFullKey(), keyData.Key)
  321. }
  322. unauthorizedCtx, unauthorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 2)
  323. unauthorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
  324. GetTokenKey(unauthorizedCtx)
  325. unauthorizedResponse := decodeAPIResponse(t, unauthorizedRecorder)
  326. if unauthorizedResponse.Success {
  327. t.Fatalf("expected unauthorized key fetch to fail")
  328. }
  329. if strings.Contains(unauthorizedRecorder.Body.String(), token.Key) {
  330. t.Fatalf("unauthorized key response leaked raw token key: %s", unauthorizedRecorder.Body.String())
  331. }
  332. }