Quellcode durchsuchen

Merge pull request #4401 from XiaoAI1024/codex/legacy-token-key-compat

Relax token key column length for legacy migration compatibility
Calcium-Ion vor 2 Wochen
Ursprung
Commit
11f8d42d66
2 geänderte Dateien mit 269 neuen und 3 gelöschten Zeilen
  1. 268 2
      controller/token_test.go
  2. 1 1
      model/token.go

+ 268 - 2
controller/token_test.go

@@ -2,10 +2,12 @@ package controller
 
 import (
 	"bytes"
+	"database/sql"
 	"encoding/json"
 	"fmt"
 	"net/http"
 	"net/http/httptest"
+	"os"
 	"strconv"
 	"strings"
 	"testing"
@@ -14,6 +16,8 @@ import (
 	"github.com/QuantumNous/new-api/model"
 	"github.com/gin-gonic/gin"
 	"github.com/glebarez/sqlite"
+	"gorm.io/driver/mysql"
+	"gorm.io/driver/postgres"
 	"gorm.io/gorm"
 )
 
@@ -38,7 +42,36 @@ type tokenKeyResponse struct {
 	Key string `json:"key"`
 }
 
-func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
+type sqliteColumnInfo struct {
+	Name string `gorm:"column:name"`
+	Type string `gorm:"column:type"`
+}
+
+type legacyToken struct {
+	Id                 int            `gorm:"primaryKey"`
+	UserId             int            `gorm:"index"`
+	Key                string         `gorm:"column:key;type:char(48);uniqueIndex"`
+	Status             int            `gorm:"default:1"`
+	Name               string         `gorm:"index"`
+	CreatedTime        int64          `gorm:"bigint"`
+	AccessedTime       int64          `gorm:"bigint"`
+	ExpiredTime        int64          `gorm:"bigint;default:-1"`
+	RemainQuota        int            `gorm:"default:0"`
+	UnlimitedQuota     bool
+	ModelLimitsEnabled bool
+	ModelLimits        string         `gorm:"type:text"`
+	AllowIps           *string        `gorm:"default:''"`
+	UsedQuota          int            `gorm:"default:0"`
+	Group              string         `gorm:"column:group;default:''"`
+	CrossGroupRetry    bool
+	DeletedAt          gorm.DeletedAt `gorm:"index"`
+}
+
+func (legacyToken) TableName() string {
+	return "tokens"
+}
+
+func openTokenControllerTestDB(t *testing.T) *gorm.DB {
 	t.Helper()
 
 	gin.SetMode(gin.TestMode)
@@ -55,18 +88,77 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
 	model.DB = db
 	model.LOG_DB = db
 
+	t.Cleanup(func() {
+		sqlDB, err := db.DB()
+		if err == nil {
+			_ = sqlDB.Close()
+		}
+	})
+
+	return db
+}
+
+func migrateTokenControllerTestDB(t *testing.T, db *gorm.DB) {
+	t.Helper()
+
 	if err := db.AutoMigrate(&model.Token{}); err != nil {
 		t.Fatalf("failed to migrate token table: %v", err)
 	}
+}
+
+func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
+	t.Helper()
+
+	db := openTokenControllerTestDB(t)
+	migrateTokenControllerTestDB(t, db)
+	return db
+}
+
+func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) (*gorm.DB, *bool) {
+	t.Helper()
+
+	gin.SetMode(gin.TestMode)
+	common.RedisEnabled = false
+	common.UsingSQLite = false
+	common.UsingMySQL = dialect == "mysql"
+	common.UsingPostgreSQL = dialect == "postgres"
+
+	var (
+		db  *gorm.DB
+		err error
+	)
+	switch dialect {
+	case "mysql":
+		db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
+	case "postgres":
+		db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
+	default:
+		t.Fatalf("unsupported dialect %q", dialect)
+	}
+	if err != nil {
+		t.Fatalf("failed to open %s db: %v", dialect, err)
+	}
+
+	model.DB = db
+	model.LOG_DB = db
+
+	if db.Migrator().HasTable("tokens") {
+		t.Skipf("refusing to run %s migration compatibility test against external database because tokens table already exists", dialect)
+	}
+
+	managedTokensTable := new(bool)
 
 	t.Cleanup(func() {
+		if *managedTokensTable && db.Migrator().HasTable("tokens") {
+			_ = db.Migrator().DropTable("tokens")
+		}
 		sqlDB, err := db.DB()
 		if err == nil {
 			_ = sqlDB.Close()
 		}
 	})
 
-	return db
+	return db, managedTokensTable
 }
 
 func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
@@ -124,6 +216,180 @@ func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenA
 	return response
 }
 
+func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName string) string {
+	t.Helper()
+
+	var columns []sqliteColumnInfo
+	if err := db.Raw("PRAGMA table_info(" + tableName + ")").Scan(&columns).Error; err != nil {
+		t.Fatalf("failed to inspect %s schema: %v", tableName, err)
+	}
+
+	for _, column := range columns {
+		if column.Name == columnName {
+			return strings.ToLower(column.Type)
+		}
+	}
+
+	t.Fatalf("column %s not found in %s schema", columnName, tableName)
+	return ""
+}
+
+func getTokenKeyColumnType(t *testing.T, db *gorm.DB, dialect string) string {
+	t.Helper()
+
+	switch dialect {
+	case "sqlite":
+		return getSQLiteColumnType(t, db, "tokens", "key")
+	case "mysql":
+		var columnType string
+		if err := db.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
+			WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
+			"tokens", "key").Scan(&columnType).Error; err != nil {
+			t.Fatalf("failed to inspect mysql token key column: %v", err)
+		}
+		return strings.ToLower(columnType)
+	case "postgres":
+		var dataType string
+		var maxLength sql.NullInt64
+		if err := db.Raw(`SELECT data_type, character_maximum_length
+			FROM information_schema.columns
+			WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
+			"tokens", "key").Row().Scan(&dataType, &maxLength); err != nil {
+			t.Fatalf("failed to inspect postgres token key column: %v", err)
+		}
+		switch strings.ToLower(dataType) {
+		case "character varying":
+			return fmt.Sprintf("varchar(%d)", maxLength.Int64)
+		case "character":
+			return fmt.Sprintf("char(%d)", maxLength.Int64)
+		default:
+			if maxLength.Valid {
+				return fmt.Sprintf("%s(%d)", strings.ToLower(dataType), maxLength.Int64)
+			}
+			return strings.ToLower(dataType)
+		}
+	default:
+		t.Fatalf("unsupported dialect %q", dialect)
+		return ""
+	}
+}
+
+func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string, managedTokensTable *bool) {
+	t.Helper()
+
+	legacyKey := strings.Repeat("a", 48)
+	longKey := strings.Repeat("b", 64)
+
+	if err := db.AutoMigrate(&legacyToken{}); err != nil {
+		t.Fatalf("failed to create legacy token schema: %v", err)
+	}
+	if managedTokensTable != nil {
+		*managedTokensTable = true
+	}
+	if err := db.Create(&legacyToken{
+		UserId:             7,
+		Key:                legacyKey,
+		Status:             common.TokenStatusEnabled,
+		Name:               "legacy-token",
+		CreatedTime:        1,
+		AccessedTime:       1,
+		ExpiredTime:        -1,
+		RemainQuota:        100,
+		UnlimitedQuota:     true,
+		ModelLimitsEnabled: false,
+		ModelLimits:        "",
+		AllowIps:           common.GetPointer(""),
+		UsedQuota:          0,
+		Group:              "default",
+		CrossGroupRetry:    false,
+	}).Error; err != nil {
+		t.Fatalf("failed to seed legacy token row: %v", err)
+	}
+
+	if got := getTokenKeyColumnType(t, db, dialect); got != "char(48)" {
+		t.Fatalf("expected legacy key column type char(48), got %q", got)
+	}
+
+	migrateTokenControllerTestDB(t, db)
+
+	if got := getTokenKeyColumnType(t, db, dialect); got != "varchar(128)" {
+		t.Fatalf("expected migrated key column type varchar(128), got %q", got)
+	}
+
+	var migratedToken model.Token
+	if err := db.First(&migratedToken, "name = ?", "legacy-token").Error; err != nil {
+		t.Fatalf("failed to load migrated token row: %v", err)
+	}
+	if migratedToken.Key != legacyKey {
+		t.Fatalf("expected migrated token key %q, got %q", legacyKey, migratedToken.Key)
+	}
+	if migratedToken.Name != "legacy-token" {
+		t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
+	}
+
+	inserted := model.Token{
+		UserId:             8,
+		Name:               "long-token",
+		Key:                longKey,
+		Status:             common.TokenStatusEnabled,
+		CreatedTime:        1,
+		AccessedTime:       1,
+		ExpiredTime:        -1,
+		RemainQuota:        200,
+		UnlimitedQuota:     true,
+		ModelLimitsEnabled: false,
+		ModelLimits:        "",
+		AllowIps:           common.GetPointer(""),
+		UsedQuota:          0,
+		Group:              "default",
+		CrossGroupRetry:    false,
+	}
+	if err := db.Create(&inserted).Error; err != nil {
+		t.Fatalf("failed to insert long token after migration: %v", err)
+	}
+
+	var fetched model.Token
+	if err := db.First(&fetched, "id = ?", inserted.Id).Error; err != nil {
+		t.Fatalf("failed to fetch long token after migration: %v", err)
+	}
+	if fetched.Key != longKey {
+		t.Fatalf("expected long token key %q, got %q", longKey, fetched.Key)
+	}
+}
+
+func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
+	db := setupTokenControllerTestDB(t)
+
+	if got := getTokenKeyColumnType(t, db, "sqlite"); got != "varchar(128)" {
+		t.Fatalf("expected key column type varchar(128), got %q", got)
+	}
+}
+
+func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
+	db := openTokenControllerTestDB(t)
+	runTokenMigrationCompatibilityTest(t, db, "sqlite", nil)
+}
+
+func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) {
+	dsn := os.Getenv("TEST_MYSQL_DSN")
+	if dsn == "" {
+		t.Skip("set TEST_MYSQL_DSN to run mysql migration compatibility test")
+	}
+
+	db, managedTokensTable := openTokenControllerExternalDB(t, "mysql", dsn)
+	runTokenMigrationCompatibilityTest(t, db, "mysql", managedTokensTable)
+}
+
+func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) {
+	dsn := os.Getenv("TEST_POSTGRES_DSN")
+	if dsn == "" {
+		t.Skip("set TEST_POSTGRES_DSN to run postgres migration compatibility test")
+	}
+
+	db, managedTokensTable := openTokenControllerExternalDB(t, "postgres", dsn)
+	runTokenMigrationCompatibilityTest(t, db, "postgres", managedTokensTable)
+}
+
 func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
 	db := setupTokenControllerTestDB(t)
 	token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")

+ 1 - 1
model/token.go

@@ -14,7 +14,7 @@ import (
 type Token struct {
 	Id                 int            `json:"id"`
 	UserId             int            `json:"user_id" gorm:"index"`
-	Key                string         `json:"key" gorm:"type:char(48);uniqueIndex"`
+	Key                string         `json:"key" gorm:"type:varchar(128);uniqueIndex"`
 	Status             int            `json:"status" gorm:"default:1"`
 	Name               string         `json:"name" gorm:"index" `
 	CreatedTime        int64          `json:"created_time" gorm:"bigint"`