package model

import (
	"encoding/json"
	"os"
	"sync"
	"testing"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/glebarez/sqlite"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/gorm"
)

// TestMain points the package-level DB and LOG_DB handles at a single
// in-memory SQLite database, sets the common.* flags these tests depend on,
// migrates the schema for the models under test, and then runs the suite.
func TestMain(m *testing.M) {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic("failed to open test db: " + err.Error())
	}
	DB = db
	LOG_DB = db

	common.UsingSQLite = true
	common.RedisEnabled = false
	common.BatchUpdateEnabled = false
	common.LogConsumeEnabled = true

	sqlDB, err := db.DB()
	if err != nil {
		panic("failed to get sql.DB: " + err.Error())
	}
	// Each SQLite connection to ":memory:" opens its own private database, so
	// the pool is capped at one connection to keep the migrated schema (and
	// inserted rows) visible to every query, including concurrent ones.
	sqlDB.SetMaxOpenConns(1)

	if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil {
		panic("failed to migrate: " + err.Error())
	}

	os.Exit(m.Run())
}

// truncateTables deletes every row from the tables used by these tests, both
// immediately and again when t finishes, so each test starts from and leaves
// behind an empty database. A failed delete is reported on t rather than
// silently ignored.
func truncateTables(t *testing.T) {
	t.Helper()
	wipe := func() {
		for _, table := range []string{"tasks", "users", "tokens", "logs", "channels"} {
			err := DB.Exec("DELETE FROM " + table).Error
			assert.NoError(t, err, "clearing table %s", table)
		}
	}
	wipe()
	t.Cleanup(wipe)
}

// insertTask stamps task's CreatedAt and UpdatedAt with the same current Unix
// time and inserts it, failing the test immediately if the insert errors.
// Callers read task.ID afterwards, relying on GORM's Create to fill in the
// generated primary key.
func insertTask(t *testing.T, task *Task) {
	t.Helper()
	now := time.Now().Unix()
	task.CreatedAt = now
	task.UpdatedAt = now
	require.NoError(t, DB.Create(task).Error)
}

// ---------------------------------------------------------------------------
// Snapshot / Equal — pure logic tests (no DB)
// ---------------------------------------------------------------------------

func TestSnapshotEqual_Same(t *testing.T) {
	s := taskSnapshot{
		Status:     TaskStatusInProgress,
		Progress:   "50%",
		StartTime:  1000,
		FinishTime: 0,
		FailReason: "",
		ResultURL:  "",
		Data:       json.RawMessage(`{"key":"value"}`),
	}
	assert.True(t, s.Equal(s))
}

func TestSnapshotEqual_DifferentStatus(t *testing.T) {
	a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
	b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
	assert.False(t, a.Equal(b))
}

func TestSnapshotEqual_DifferentProgress(t *testing.T) {
	a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
	b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
	assert.False(t, a.Equal(b))
}

func TestSnapshotEqual_DifferentData(t *testing.T) {
	a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
	b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
	assert.False(t, a.Equal(b))
}

func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
	a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
	b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
	// Equal is expected to compare Data with bytes.Equal semantics, where
	// bytes.Equal(nil, []byte{}) == true.
	assert.True(t, a.Equal(b))
}

func TestSnapshot_Roundtrip(t *testing.T) {
	task := &Task{
		Status:     TaskStatusInProgress,
		Progress:   "42%",
		StartTime:  1234,
		FinishTime: 5678,
		FailReason: "timeout",
		PrivateData: TaskPrivateData{
			ResultURL: "https://example.com/result.mp4",
		},
		Data: json.RawMessage(`{"model":"test-model"}`),
	}

	snap := task.Snapshot()

	assert.Equal(t, task.Status, snap.Status)
	assert.Equal(t, task.Progress, snap.Progress)
	assert.Equal(t, task.StartTime, snap.StartTime)
	assert.Equal(t, task.FinishTime, snap.FinishTime)
	assert.Equal(t, task.FailReason, snap.FailReason)
	assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
	assert.JSONEq(t, string(task.Data), string(snap.Data))
}

// ---------------------------------------------------------------------------
// UpdateWithStatus CAS — DB integration tests
// ---------------------------------------------------------------------------

func TestUpdateWithStatus_Win(t *testing.T) {
	truncateTables(t)

	task := &Task{
		TaskID:   "task_cas_win",
		Status:   TaskStatusInProgress,
		Progress: "50%",
		Data:     json.RawMessage(`{}`),
	}
	insertTask(t, task)

	task.Status = TaskStatusSuccess
	task.Progress = "100%"
	won, err := task.UpdateWithStatus(TaskStatusInProgress)
	require.NoError(t, err)
	assert.True(t, won)

	var reloaded Task
	require.NoError(t, DB.First(&reloaded, task.ID).Error)
	assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
	assert.Equal(t, "100%", reloaded.Progress)
}

func TestUpdateWithStatus_Lose(t *testing.T) {
	truncateTables(t)

	task := &Task{
		TaskID: "task_cas_lose",
		Status: TaskStatusFailure,
		Data:   json.RawMessage(`{}`),
	}
	insertTask(t, task)

	task.Status = TaskStatusSuccess
	won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
	require.NoError(t, err)
	assert.False(t, won)

	var reloaded Task
	require.NoError(t, DB.First(&reloaded, task.ID).Error)
	assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
}

func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
	truncateTables(t)

	task := &Task{
		TaskID: "task_cas_race",
		Status: TaskStatusInProgress,
		Quota:  1000,
		Data:   json.RawMessage(`{}`),
	}
	insertTask(t, task)

	const goroutines = 5
	wins := make([]bool, goroutines)
	errs := make([]error, goroutines)

	var wg sync.WaitGroup
	wg.Add(goroutines)
	for i := 0; i < goroutines; i++ {
		go func(idx int) {
			defer wg.Done()
			// Each goroutine works on its own copy of the row, so the database
			// is the only state the competing CAS attempts share.
			candidate := &Task{
				ID:       task.ID,
				TaskID:   task.TaskID,
				Status:   TaskStatusSuccess,
				Progress: "100%",
				Quota:    task.Quota,
				Data:     json.RawMessage(`{}`),
			}
			candidate.CreatedAt = task.CreatedAt
			candidate.UpdatedAt = time.Now().Unix()
			// Results are recorded per index (no shared writes); assertions
			// happen after wg.Wait because t.FailNow must not be called from
			// a goroutine other than the test's own.
			wins[idx], errs[idx] = candidate.UpdateWithStatus(TaskStatusInProgress)
		}(i)
	}
	wg.Wait()

	winCount := 0
	for i := range wins {
		assert.NoError(t, errs[i], "goroutine %d returned an error", i)
		if wins[i] {
			winCount++
		}
	}
	assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")

	// The single winning write must actually be persisted.
	var reloaded Task
	require.NoError(t, DB.First(&reloaded, task.ID).Error)
	assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
	assert.Equal(t, "100%", reloaded.Progress)
}