| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217 |
- package model
- import (
- "encoding/json"
- "os"
- "sync"
- "testing"
- "time"
- "github.com/QuantumNous/new-api/common"
- "github.com/glebarez/sqlite"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "gorm.io/gorm"
- )
- func TestMain(m *testing.M) {
- db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
- if err != nil {
- panic("failed to open test db: " + err.Error())
- }
- DB = db
- LOG_DB = db
- common.UsingSQLite = true
- common.RedisEnabled = false
- common.BatchUpdateEnabled = false
- common.LogConsumeEnabled = true
- sqlDB, err := db.DB()
- if err != nil {
- panic("failed to get sql.DB: " + err.Error())
- }
- sqlDB.SetMaxOpenConns(1)
- if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil {
- panic("failed to migrate: " + err.Error())
- }
- os.Exit(m.Run())
- }
- func truncateTables(t *testing.T) {
- t.Helper()
- t.Cleanup(func() {
- DB.Exec("DELETE FROM tasks")
- DB.Exec("DELETE FROM users")
- DB.Exec("DELETE FROM tokens")
- DB.Exec("DELETE FROM logs")
- DB.Exec("DELETE FROM channels")
- })
- }
- func insertTask(t *testing.T, task *Task) {
- t.Helper()
- task.CreatedAt = time.Now().Unix()
- task.UpdatedAt = time.Now().Unix()
- require.NoError(t, DB.Create(task).Error)
- }
- // ---------------------------------------------------------------------------
- // Snapshot / Equal — pure logic tests (no DB)
- // ---------------------------------------------------------------------------
- func TestSnapshotEqual_Same(t *testing.T) {
- s := taskSnapshot{
- Status: TaskStatusInProgress,
- Progress: "50%",
- StartTime: 1000,
- FinishTime: 0,
- FailReason: "",
- ResultURL: "",
- Data: json.RawMessage(`{"key":"value"}`),
- }
- assert.True(t, s.Equal(s))
- }
- func TestSnapshotEqual_DifferentStatus(t *testing.T) {
- a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
- b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
- assert.False(t, a.Equal(b))
- }
- func TestSnapshotEqual_DifferentProgress(t *testing.T) {
- a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
- b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
- assert.False(t, a.Equal(b))
- }
- func TestSnapshotEqual_DifferentData(t *testing.T) {
- a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
- b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
- assert.False(t, a.Equal(b))
- }
- func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
- a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
- b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
- // bytes.Equal(nil, []byte{}) == true
- assert.True(t, a.Equal(b))
- }
- func TestSnapshot_Roundtrip(t *testing.T) {
- task := &Task{
- Status: TaskStatusInProgress,
- Progress: "42%",
- StartTime: 1234,
- FinishTime: 5678,
- FailReason: "timeout",
- PrivateData: TaskPrivateData{
- ResultURL: "https://example.com/result.mp4",
- },
- Data: json.RawMessage(`{"model":"test-model"}`),
- }
- snap := task.Snapshot()
- assert.Equal(t, task.Status, snap.Status)
- assert.Equal(t, task.Progress, snap.Progress)
- assert.Equal(t, task.StartTime, snap.StartTime)
- assert.Equal(t, task.FinishTime, snap.FinishTime)
- assert.Equal(t, task.FailReason, snap.FailReason)
- assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
- assert.JSONEq(t, string(task.Data), string(snap.Data))
- }
- // ---------------------------------------------------------------------------
- // UpdateWithStatus CAS — DB integration tests
- // ---------------------------------------------------------------------------
- func TestUpdateWithStatus_Win(t *testing.T) {
- truncateTables(t)
- task := &Task{
- TaskID: "task_cas_win",
- Status: TaskStatusInProgress,
- Progress: "50%",
- Data: json.RawMessage(`{}`),
- }
- insertTask(t, task)
- task.Status = TaskStatusSuccess
- task.Progress = "100%"
- won, err := task.UpdateWithStatus(TaskStatusInProgress)
- require.NoError(t, err)
- assert.True(t, won)
- var reloaded Task
- require.NoError(t, DB.First(&reloaded, task.ID).Error)
- assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
- assert.Equal(t, "100%", reloaded.Progress)
- }
- func TestUpdateWithStatus_Lose(t *testing.T) {
- truncateTables(t)
- task := &Task{
- TaskID: "task_cas_lose",
- Status: TaskStatusFailure,
- Data: json.RawMessage(`{}`),
- }
- insertTask(t, task)
- task.Status = TaskStatusSuccess
- won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
- require.NoError(t, err)
- assert.False(t, won)
- var reloaded Task
- require.NoError(t, DB.First(&reloaded, task.ID).Error)
- assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
- }
- func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
- truncateTables(t)
- task := &Task{
- TaskID: "task_cas_race",
- Status: TaskStatusInProgress,
- Quota: 1000,
- Data: json.RawMessage(`{}`),
- }
- insertTask(t, task)
- const goroutines = 5
- wins := make([]bool, goroutines)
- var wg sync.WaitGroup
- wg.Add(goroutines)
- for i := 0; i < goroutines; i++ {
- go func(idx int) {
- defer wg.Done()
- t := &Task{}
- *t = Task{
- ID: task.ID,
- TaskID: task.TaskID,
- Status: TaskStatusSuccess,
- Progress: "100%",
- Quota: task.Quota,
- Data: json.RawMessage(`{}`),
- }
- t.CreatedAt = task.CreatedAt
- t.UpdatedAt = time.Now().Unix()
- won, err := t.UpdateWithStatus(TaskStatusInProgress)
- if err == nil {
- wins[idx] = won
- }
- }(i)
- }
- wg.Wait()
- winCount := 0
- for _, w := range wins {
- if w {
- winCount++
- }
- }
- assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")
- }
|