| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541 |
- package controller
- import (
- "bytes"
- "database/sql"
- "encoding/json"
- "fmt"
- "net/http"
- "net/http/httptest"
- "os"
- "strconv"
- "strings"
- "testing"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/model"
- "github.com/gin-gonic/gin"
- "github.com/glebarez/sqlite"
- "gorm.io/driver/mysql"
- "gorm.io/driver/postgres"
- "gorm.io/gorm"
- )
// tokenAPIResponse is the JSON envelope the tests decode controller responses
// into. Data is kept as raw JSON so each test can unmarshal it into the
// payload shape it expects.
type tokenAPIResponse struct {
	Success bool            `json:"success"`
	Message string          `json:"message"`
	Data    json.RawMessage `json:"data"`
}
- type tokenPageResponse struct {
- Items []tokenResponseItem `json:"items"`
- }
// tokenResponseItem holds the subset of a token's JSON fields that the tests
// inspect: its id, name, key and status.
type tokenResponseItem struct {
	ID     int    `json:"id"`
	Name   string `json:"name"`
	Key    string `json:"key"`
	Status int    `json:"status"`
}
// tokenKeyResponse is the payload decoded from the token key endpoint's
// "data" field; it carries a single "key" string.
type tokenKeyResponse struct {
	Key string `json:"key"`
}
// sqliteColumnInfo receives the "name" and "type" columns of a
// PRAGMA table_info result row.
type sqliteColumnInfo struct {
	Name string `gorm:"column:name"`
	Type string `gorm:"column:type"`
}
- type legacyToken struct {
- Id int `gorm:"primaryKey"`
- UserId int `gorm:"index"`
- Key string `gorm:"column:key;type:char(48);uniqueIndex"`
- Status int `gorm:"default:1"`
- Name string `gorm:"index"`
- CreatedTime int64 `gorm:"bigint"`
- AccessedTime int64 `gorm:"bigint"`
- ExpiredTime int64 `gorm:"bigint;default:-1"`
- RemainQuota int `gorm:"default:0"`
- UnlimitedQuota bool
- ModelLimitsEnabled bool
- ModelLimits string `gorm:"type:text"`
- AllowIps *string `gorm:"default:''"`
- UsedQuota int `gorm:"default:0"`
- Group string `gorm:"column:group;default:''"`
- CrossGroupRetry bool
- DeletedAt gorm.DeletedAt `gorm:"index"`
- }
- func (legacyToken) TableName() string {
- return "tokens"
- }
- func openTokenControllerTestDB(t *testing.T) *gorm.DB {
- t.Helper()
- gin.SetMode(gin.TestMode)
- common.UsingSQLite = true
- common.UsingMySQL = false
- common.UsingPostgreSQL = false
- common.RedisEnabled = false
- dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
- db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
- if err != nil {
- t.Fatalf("failed to open sqlite db: %v", err)
- }
- model.DB = db
- model.LOG_DB = db
- t.Cleanup(func() {
- sqlDB, err := db.DB()
- if err == nil {
- _ = sqlDB.Close()
- }
- })
- return db
- }
- func migrateTokenControllerTestDB(t *testing.T, db *gorm.DB) {
- t.Helper()
- if err := db.AutoMigrate(&model.Token{}); err != nil {
- t.Fatalf("failed to migrate token table: %v", err)
- }
- }
- func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
- t.Helper()
- db := openTokenControllerTestDB(t)
- migrateTokenControllerTestDB(t, db)
- return db
- }
- func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) (*gorm.DB, *bool) {
- t.Helper()
- gin.SetMode(gin.TestMode)
- common.RedisEnabled = false
- common.UsingSQLite = false
- common.UsingMySQL = dialect == "mysql"
- common.UsingPostgreSQL = dialect == "postgres"
- var (
- db *gorm.DB
- err error
- )
- switch dialect {
- case "mysql":
- db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
- case "postgres":
- db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
- default:
- t.Fatalf("unsupported dialect %q", dialect)
- }
- if err != nil {
- t.Fatalf("failed to open %s db: %v", dialect, err)
- }
- model.DB = db
- model.LOG_DB = db
- if db.Migrator().HasTable("tokens") {
- t.Skipf("refusing to run %s migration compatibility test against external database because tokens table already exists", dialect)
- }
- managedTokensTable := new(bool)
- t.Cleanup(func() {
- if *managedTokensTable && db.Migrator().HasTable("tokens") {
- _ = db.Migrator().DropTable("tokens")
- }
- sqlDB, err := db.DB()
- if err == nil {
- _ = sqlDB.Close()
- }
- })
- return db, managedTokensTable
- }
- func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
- t.Helper()
- token := &model.Token{
- UserId: userID,
- Name: name,
- Key: rawKey,
- Status: common.TokenStatusEnabled,
- CreatedTime: 1,
- AccessedTime: 1,
- ExpiredTime: -1,
- RemainQuota: 100,
- UnlimitedQuota: true,
- Group: "default",
- }
- if err := db.Create(token).Error; err != nil {
- t.Fatalf("failed to create token: %v", err)
- }
- return token
- }
- func newAuthenticatedContext(t *testing.T, method string, target string, body any, userID int) (*gin.Context, *httptest.ResponseRecorder) {
- t.Helper()
- var requestBody *bytes.Reader
- if body != nil {
- payload, err := common.Marshal(body)
- if err != nil {
- t.Fatalf("failed to marshal request body: %v", err)
- }
- requestBody = bytes.NewReader(payload)
- } else {
- requestBody = bytes.NewReader(nil)
- }
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(method, target, requestBody)
- if body != nil {
- ctx.Request.Header.Set("Content-Type", "application/json")
- }
- ctx.Set("id", userID)
- return ctx, recorder
- }
- func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenAPIResponse {
- t.Helper()
- var response tokenAPIResponse
- if err := common.Unmarshal(recorder.Body.Bytes(), &response); err != nil {
- t.Fatalf("failed to decode api response: %v", err)
- }
- return response
- }
- func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName string) string {
- t.Helper()
- var columns []sqliteColumnInfo
- if err := db.Raw("PRAGMA table_info(" + tableName + ")").Scan(&columns).Error; err != nil {
- t.Fatalf("failed to inspect %s schema: %v", tableName, err)
- }
- for _, column := range columns {
- if column.Name == columnName {
- return strings.ToLower(column.Type)
- }
- }
- t.Fatalf("column %s not found in %s schema", columnName, tableName)
- return ""
- }
- func getTokenKeyColumnType(t *testing.T, db *gorm.DB, dialect string) string {
- t.Helper()
- switch dialect {
- case "sqlite":
- return getSQLiteColumnType(t, db, "tokens", "key")
- case "mysql":
- var columnType string
- if err := db.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
- WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
- "tokens", "key").Scan(&columnType).Error; err != nil {
- t.Fatalf("failed to inspect mysql token key column: %v", err)
- }
- return strings.ToLower(columnType)
- case "postgres":
- var dataType string
- var maxLength sql.NullInt64
- if err := db.Raw(`SELECT data_type, character_maximum_length
- FROM information_schema.columns
- WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
- "tokens", "key").Row().Scan(&dataType, &maxLength); err != nil {
- t.Fatalf("failed to inspect postgres token key column: %v", err)
- }
- switch strings.ToLower(dataType) {
- case "character varying":
- return fmt.Sprintf("varchar(%d)", maxLength.Int64)
- case "character":
- return fmt.Sprintf("char(%d)", maxLength.Int64)
- default:
- if maxLength.Valid {
- return fmt.Sprintf("%s(%d)", strings.ToLower(dataType), maxLength.Int64)
- }
- return strings.ToLower(dataType)
- }
- default:
- t.Fatalf("unsupported dialect %q", dialect)
- return ""
- }
- }
- func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string, managedTokensTable *bool) {
- t.Helper()
- legacyKey := strings.Repeat("a", 48)
- longKey := strings.Repeat("b", 64)
- if err := db.AutoMigrate(&legacyToken{}); err != nil {
- t.Fatalf("failed to create legacy token schema: %v", err)
- }
- if managedTokensTable != nil {
- *managedTokensTable = true
- }
- if err := db.Create(&legacyToken{
- UserId: 7,
- Key: legacyKey,
- Status: common.TokenStatusEnabled,
- Name: "legacy-token",
- CreatedTime: 1,
- AccessedTime: 1,
- ExpiredTime: -1,
- RemainQuota: 100,
- UnlimitedQuota: true,
- ModelLimitsEnabled: false,
- ModelLimits: "",
- AllowIps: common.GetPointer(""),
- UsedQuota: 0,
- Group: "default",
- CrossGroupRetry: false,
- }).Error; err != nil {
- t.Fatalf("failed to seed legacy token row: %v", err)
- }
- if got := getTokenKeyColumnType(t, db, dialect); got != "char(48)" {
- t.Fatalf("expected legacy key column type char(48), got %q", got)
- }
- migrateTokenControllerTestDB(t, db)
- if got := getTokenKeyColumnType(t, db, dialect); got != "varchar(128)" {
- t.Fatalf("expected migrated key column type varchar(128), got %q", got)
- }
- var migratedToken model.Token
- if err := db.First(&migratedToken, "name = ?", "legacy-token").Error; err != nil {
- t.Fatalf("failed to load migrated token row: %v", err)
- }
- if migratedToken.Key != legacyKey {
- t.Fatalf("expected migrated token key %q, got %q", legacyKey, migratedToken.Key)
- }
- if migratedToken.Name != "legacy-token" {
- t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
- }
- inserted := model.Token{
- UserId: 8,
- Name: "long-token",
- Key: longKey,
- Status: common.TokenStatusEnabled,
- CreatedTime: 1,
- AccessedTime: 1,
- ExpiredTime: -1,
- RemainQuota: 200,
- UnlimitedQuota: true,
- ModelLimitsEnabled: false,
- ModelLimits: "",
- AllowIps: common.GetPointer(""),
- UsedQuota: 0,
- Group: "default",
- CrossGroupRetry: false,
- }
- if err := db.Create(&inserted).Error; err != nil {
- t.Fatalf("failed to insert long token after migration: %v", err)
- }
- var fetched model.Token
- if err := db.First(&fetched, "id = ?", inserted.Id).Error; err != nil {
- t.Fatalf("failed to fetch long token after migration: %v", err)
- }
- if fetched.Key != longKey {
- t.Fatalf("expected long token key %q, got %q", longKey, fetched.Key)
- }
- }
- func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
- db := setupTokenControllerTestDB(t)
- if got := getTokenKeyColumnType(t, db, "sqlite"); got != "varchar(128)" {
- t.Fatalf("expected key column type varchar(128), got %q", got)
- }
- }
- func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
- db := openTokenControllerTestDB(t)
- runTokenMigrationCompatibilityTest(t, db, "sqlite", nil)
- }
- func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) {
- dsn := os.Getenv("TEST_MYSQL_DSN")
- if dsn == "" {
- t.Skip("set TEST_MYSQL_DSN to run mysql migration compatibility test")
- }
- db, managedTokensTable := openTokenControllerExternalDB(t, "mysql", dsn)
- runTokenMigrationCompatibilityTest(t, db, "mysql", managedTokensTable)
- }
- func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) {
- dsn := os.Getenv("TEST_POSTGRES_DSN")
- if dsn == "" {
- t.Skip("set TEST_POSTGRES_DSN to run postgres migration compatibility test")
- }
- db, managedTokensTable := openTokenControllerExternalDB(t, "postgres", dsn)
- runTokenMigrationCompatibilityTest(t, db, "postgres", managedTokensTable)
- }
- func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
- db := setupTokenControllerTestDB(t)
- token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
- seedToken(t, db, 2, "other-user-token", "zzzz1234yyyy5678")
- ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/?p=1&size=10", nil, 1)
- GetAllTokens(ctx)
- response := decodeAPIResponse(t, recorder)
- if !response.Success {
- t.Fatalf("expected success response, got message: %s", response.Message)
- }
- var page tokenPageResponse
- if err := common.Unmarshal(response.Data, &page); err != nil {
- t.Fatalf("failed to decode token page response: %v", err)
- }
- if len(page.Items) != 1 {
- t.Fatalf("expected exactly one token, got %d", len(page.Items))
- }
- if page.Items[0].Key != token.GetMaskedKey() {
- t.Fatalf("expected masked key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
- }
- if strings.Contains(recorder.Body.String(), token.Key) {
- t.Fatalf("list response leaked raw token key: %s", recorder.Body.String())
- }
- }
- func TestSearchTokensMasksKeyInResponse(t *testing.T) {
- db := setupTokenControllerTestDB(t)
- token := seedToken(t, db, 1, "searchable-token", "ijkl1234mnop5678")
- ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/search?keyword=searchable-token&p=1&size=10", nil, 1)
- SearchTokens(ctx)
- response := decodeAPIResponse(t, recorder)
- if !response.Success {
- t.Fatalf("expected success response, got message: %s", response.Message)
- }
- var page tokenPageResponse
- if err := common.Unmarshal(response.Data, &page); err != nil {
- t.Fatalf("failed to decode search response: %v", err)
- }
- if len(page.Items) != 1 {
- t.Fatalf("expected exactly one search result, got %d", len(page.Items))
- }
- if page.Items[0].Key != token.GetMaskedKey() {
- t.Fatalf("expected masked search key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
- }
- if strings.Contains(recorder.Body.String(), token.Key) {
- t.Fatalf("search response leaked raw token key: %s", recorder.Body.String())
- }
- }
- func TestGetTokenMasksKeyInResponse(t *testing.T) {
- db := setupTokenControllerTestDB(t)
- token := seedToken(t, db, 1, "detail-token", "qrst1234uvwx5678")
- ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/"+strconv.Itoa(token.Id), nil, 1)
- ctx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
- GetToken(ctx)
- response := decodeAPIResponse(t, recorder)
- if !response.Success {
- t.Fatalf("expected success response, got message: %s", response.Message)
- }
- var detail tokenResponseItem
- if err := common.Unmarshal(response.Data, &detail); err != nil {
- t.Fatalf("failed to decode token detail response: %v", err)
- }
- if detail.Key != token.GetMaskedKey() {
- t.Fatalf("expected masked detail key %q, got %q", token.GetMaskedKey(), detail.Key)
- }
- if strings.Contains(recorder.Body.String(), token.Key) {
- t.Fatalf("detail response leaked raw token key: %s", recorder.Body.String())
- }
- }
- func TestUpdateTokenMasksKeyInResponse(t *testing.T) {
- db := setupTokenControllerTestDB(t)
- token := seedToken(t, db, 1, "editable-token", "yzab1234cdef5678")
- body := map[string]any{
- "id": token.Id,
- "name": "updated-token",
- "expired_time": -1,
- "remain_quota": 100,
- "unlimited_quota": true,
- "model_limits_enabled": false,
- "model_limits": "",
- "group": "default",
- "cross_group_retry": false,
- }
- ctx, recorder := newAuthenticatedContext(t, http.MethodPut, "/api/token/", body, 1)
- UpdateToken(ctx)
- response := decodeAPIResponse(t, recorder)
- if !response.Success {
- t.Fatalf("expected success response, got message: %s", response.Message)
- }
- var detail tokenResponseItem
- if err := common.Unmarshal(response.Data, &detail); err != nil {
- t.Fatalf("failed to decode token update response: %v", err)
- }
- if detail.Key != token.GetMaskedKey() {
- t.Fatalf("expected masked update key %q, got %q", token.GetMaskedKey(), detail.Key)
- }
- if strings.Contains(recorder.Body.String(), token.Key) {
- t.Fatalf("update response leaked raw token key: %s", recorder.Body.String())
- }
- }
- func TestGetTokenKeyRequiresOwnershipAndReturnsFullKey(t *testing.T) {
- db := setupTokenControllerTestDB(t)
- token := seedToken(t, db, 1, "owned-token", "owner1234token5678")
- authorizedCtx, authorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 1)
- authorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
- GetTokenKey(authorizedCtx)
- authorizedResponse := decodeAPIResponse(t, authorizedRecorder)
- if !authorizedResponse.Success {
- t.Fatalf("expected authorized key fetch to succeed, got message: %s", authorizedResponse.Message)
- }
- var keyData tokenKeyResponse
- if err := common.Unmarshal(authorizedResponse.Data, &keyData); err != nil {
- t.Fatalf("failed to decode token key response: %v", err)
- }
- if keyData.Key != token.GetFullKey() {
- t.Fatalf("expected full key %q, got %q", token.GetFullKey(), keyData.Key)
- }
- unauthorizedCtx, unauthorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 2)
- unauthorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
- GetTokenKey(unauthorizedCtx)
- unauthorizedResponse := decodeAPIResponse(t, unauthorizedRecorder)
- if unauthorizedResponse.Success {
- t.Fatalf("expected unauthorized key fetch to fail")
- }
- if strings.Contains(unauthorizedRecorder.Body.String(), token.Key) {
- t.Fatalf("unauthorized key response leaked raw token key: %s", unauthorizedRecorder.Body.String())
- }
- }
|