|
|
@@ -2,10 +2,12 @@ package controller
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
+ "database/sql"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
+ "os"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
"testing"
|
|
|
@@ -14,6 +16,8 @@ import (
|
|
|
"github.com/QuantumNous/new-api/model"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
"github.com/glebarez/sqlite"
|
|
|
+ "gorm.io/driver/mysql"
|
|
|
+ "gorm.io/driver/postgres"
|
|
|
"gorm.io/gorm"
|
|
|
)
|
|
|
|
|
|
@@ -110,6 +114,45 @@ func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
|
|
|
return db
|
|
|
}
|
|
|
|
|
|
+func openTokenControllerExternalDB(t *testing.T, dialect string, dsn string) *gorm.DB {
|
|
|
+ t.Helper()
|
|
|
+
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
+ common.RedisEnabled = false
|
|
|
+ common.UsingSQLite = false
|
|
|
+ common.UsingMySQL = dialect == "mysql"
|
|
|
+ common.UsingPostgreSQL = dialect == "postgres"
|
|
|
+
|
|
|
+ var (
|
|
|
+ db *gorm.DB
|
|
|
+ err error
|
|
|
+ )
|
|
|
+ switch dialect {
|
|
|
+ case "mysql":
|
|
|
+ db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
|
|
+ case "postgres":
|
|
|
+ db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
|
|
+ default:
|
|
|
+ t.Fatalf("unsupported dialect %q", dialect)
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("failed to open %s db: %v", dialect, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ model.DB = db
|
|
|
+ model.LOG_DB = db
|
|
|
+
|
|
|
+ t.Cleanup(func() {
|
|
|
+ _ = db.Exec("DROP TABLE IF EXISTS tokens").Error
|
|
|
+ sqlDB, err := db.DB()
|
|
|
+ if err == nil {
|
|
|
+ _ = sqlDB.Close()
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ return db
|
|
|
+}
|
|
|
+
|
|
|
func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
|
|
|
t.Helper()
|
|
|
|
|
|
@@ -183,23 +226,59 @@ func getSQLiteColumnType(t *testing.T, db *gorm.DB, tableName string, columnName
|
|
|
return ""
|
|
|
}
|
|
|
|
|
|
-func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
|
|
|
- db := setupTokenControllerTestDB(t)
|
|
|
+func getTokenKeyColumnType(t *testing.T, db *gorm.DB, dialect string) string {
|
|
|
+ t.Helper()
|
|
|
|
|
|
- if got := getSQLiteColumnType(t, db, "tokens", "key"); got != "varchar(128)" {
|
|
|
- t.Fatalf("expected key column type varchar(128), got %q", got)
|
|
|
+ switch dialect {
|
|
|
+ case "sqlite":
|
|
|
+ return getSQLiteColumnType(t, db, "tokens", "key")
|
|
|
+ case "mysql":
|
|
|
+ var columnType string
|
|
|
+ if err := db.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
|
|
|
+ WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
|
|
|
+ "tokens", "key").Scan(&columnType).Error; err != nil {
|
|
|
+ t.Fatalf("failed to inspect mysql token key column: %v", err)
|
|
|
+ }
|
|
|
+ return strings.ToLower(columnType)
|
|
|
+ case "postgres":
|
|
|
+ var dataType string
|
|
|
+ var maxLength sql.NullInt64
|
|
|
+ if err := db.Raw(`SELECT data_type, character_maximum_length
|
|
|
+ FROM information_schema.columns
|
|
|
+ WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
|
|
|
+ "tokens", "key").Row().Scan(&dataType, &maxLength); err != nil {
|
|
|
+ t.Fatalf("failed to inspect postgres token key column: %v", err)
|
|
|
+ }
|
|
|
+ switch strings.ToLower(dataType) {
|
|
|
+ case "character varying":
|
|
|
+ return fmt.Sprintf("varchar(%d)", maxLength.Int64)
|
|
|
+ case "character":
|
|
|
+ return fmt.Sprintf("char(%d)", maxLength.Int64)
|
|
|
+ default:
|
|
|
+ if maxLength.Valid {
|
|
|
+ return fmt.Sprintf("%s(%d)", strings.ToLower(dataType), maxLength.Int64)
|
|
|
+ }
|
|
|
+ return strings.ToLower(dataType)
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ t.Fatalf("unsupported dialect %q", dialect)
|
|
|
+ return ""
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
|
|
|
- db := openTokenControllerTestDB(t)
|
|
|
+func runTokenMigrationCompatibilityTest(t *testing.T, db *gorm.DB, dialect string) {
|
|
|
+ t.Helper()
|
|
|
+
|
|
|
legacyKey := strings.Repeat("a", 48)
|
|
|
+ longKey := strings.Repeat("b", 64)
|
|
|
|
|
|
+ if err := db.Exec("DROP TABLE IF EXISTS tokens").Error; err != nil {
|
|
|
+ t.Fatalf("failed to drop pre-existing token table: %v", err)
|
|
|
+ }
|
|
|
if err := db.AutoMigrate(&legacyToken{}); err != nil {
|
|
|
t.Fatalf("failed to create legacy token schema: %v", err)
|
|
|
}
|
|
|
if err := db.Create(&legacyToken{
|
|
|
- Id: 1,
|
|
|
UserId: 7,
|
|
|
Key: legacyKey,
|
|
|
Status: common.TokenStatusEnabled,
|
|
|
@@ -219,18 +298,18 @@ func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
|
|
|
t.Fatalf("failed to seed legacy token row: %v", err)
|
|
|
}
|
|
|
|
|
|
- if got := getSQLiteColumnType(t, db, "tokens", "key"); got != "char(48)" {
|
|
|
+ if got := getTokenKeyColumnType(t, db, dialect); got != "char(48)" {
|
|
|
t.Fatalf("expected legacy key column type char(48), got %q", got)
|
|
|
}
|
|
|
|
|
|
migrateTokenControllerTestDB(t, db)
|
|
|
|
|
|
- if got := getSQLiteColumnType(t, db, "tokens", "key"); got != "varchar(128)" {
|
|
|
+ if got := getTokenKeyColumnType(t, db, dialect); got != "varchar(128)" {
|
|
|
t.Fatalf("expected migrated key column type varchar(128), got %q", got)
|
|
|
}
|
|
|
|
|
|
var migratedToken model.Token
|
|
|
- if err := db.First(&migratedToken, "id = ?", 1).Error; err != nil {
|
|
|
+ if err := db.First(&migratedToken, "name = ?", "legacy-token").Error; err != nil {
|
|
|
t.Fatalf("failed to load migrated token row: %v", err)
|
|
|
}
|
|
|
if migratedToken.Key != legacyKey {
|
|
|
@@ -239,6 +318,68 @@ func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
|
|
|
if migratedToken.Name != "legacy-token" {
|
|
|
t.Fatalf("expected migrated token name to be preserved, got %q", migratedToken.Name)
|
|
|
}
|
|
|
+
|
|
|
+ inserted := model.Token{
|
|
|
+ UserId: 8,
|
|
|
+ Name: "long-token",
|
|
|
+ Key: longKey,
|
|
|
+ Status: common.TokenStatusEnabled,
|
|
|
+ CreatedTime: 1,
|
|
|
+ AccessedTime: 1,
|
|
|
+ ExpiredTime: -1,
|
|
|
+ RemainQuota: 200,
|
|
|
+ UnlimitedQuota: true,
|
|
|
+ ModelLimitsEnabled: false,
|
|
|
+ ModelLimits: "",
|
|
|
+ AllowIps: common.GetPointer(""),
|
|
|
+ UsedQuota: 0,
|
|
|
+ Group: "default",
|
|
|
+ CrossGroupRetry: false,
|
|
|
+ }
|
|
|
+ if err := db.Create(&inserted).Error; err != nil {
|
|
|
+ t.Fatalf("failed to insert long token after migration: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var fetched model.Token
|
|
|
+ if err := db.First(&fetched, "id = ?", inserted.Id).Error; err != nil {
|
|
|
+ t.Fatalf("failed to fetch long token after migration: %v", err)
|
|
|
+ }
|
|
|
+ if fetched.Key != longKey {
|
|
|
+ t.Fatalf("expected long token key %q, got %q", longKey, fetched.Key)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestTokenAutoMigrateUsesVarchar128KeyColumn(t *testing.T) {
|
|
|
+ db := setupTokenControllerTestDB(t)
|
|
|
+
|
|
|
+ if got := getTokenKeyColumnType(t, db, "sqlite"); got != "varchar(128)" {
|
|
|
+ t.Fatalf("expected key column type varchar(128), got %q", got)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestTokenMigrationFromChar48ToVarchar128(t *testing.T) {
|
|
|
+ db := openTokenControllerTestDB(t)
|
|
|
+ runTokenMigrationCompatibilityTest(t, db, "sqlite")
|
|
|
+}
|
|
|
+
|
|
|
+func TestTokenMigrationFromChar48ToVarchar128MySQL(t *testing.T) {
|
|
|
+ dsn := os.Getenv("TEST_MYSQL_DSN")
|
|
|
+ if dsn == "" {
|
|
|
+ t.Skip("set TEST_MYSQL_DSN to run mysql migration compatibility test")
|
|
|
+ }
|
|
|
+
|
|
|
+ db := openTokenControllerExternalDB(t, "mysql", dsn)
|
|
|
+ runTokenMigrationCompatibilityTest(t, db, "mysql")
|
|
|
+}
|
|
|
+
|
|
|
+func TestTokenMigrationFromChar48ToVarchar128Postgres(t *testing.T) {
|
|
|
+ dsn := os.Getenv("TEST_POSTGRES_DSN")
|
|
|
+ if dsn == "" {
|
|
|
+ t.Skip("set TEST_POSTGRES_DSN to run postgres migration compatibility test")
|
|
|
+ }
|
|
|
+
|
|
|
+ db := openTokenControllerExternalDB(t, "postgres", dsn)
|
|
|
+ runTokenMigrationCompatibilityTest(t, db, "postgres")
|
|
|
}
|
|
|
|
|
|
func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
|