task_cas_test.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. package model
  2. import (
  3. "encoding/json"
  4. "os"
  5. "sync"
  6. "testing"
  7. "time"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/glebarez/sqlite"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. "gorm.io/gorm"
  13. )
  14. func TestMain(m *testing.M) {
  15. db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
  16. if err != nil {
  17. panic("failed to open test db: " + err.Error())
  18. }
  19. DB = db
  20. LOG_DB = db
  21. common.UsingSQLite = true
  22. common.RedisEnabled = false
  23. common.BatchUpdateEnabled = false
  24. common.LogConsumeEnabled = true
  25. sqlDB, err := db.DB()
  26. if err != nil {
  27. panic("failed to get sql.DB: " + err.Error())
  28. }
  29. sqlDB.SetMaxOpenConns(1)
  30. if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil {
  31. panic("failed to migrate: " + err.Error())
  32. }
  33. os.Exit(m.Run())
  34. }
  35. func truncateTables(t *testing.T) {
  36. t.Helper()
  37. t.Cleanup(func() {
  38. DB.Exec("DELETE FROM tasks")
  39. DB.Exec("DELETE FROM users")
  40. DB.Exec("DELETE FROM tokens")
  41. DB.Exec("DELETE FROM logs")
  42. DB.Exec("DELETE FROM channels")
  43. })
  44. }
  45. func insertTask(t *testing.T, task *Task) {
  46. t.Helper()
  47. task.CreatedAt = time.Now().Unix()
  48. task.UpdatedAt = time.Now().Unix()
  49. require.NoError(t, DB.Create(task).Error)
  50. }
  51. // ---------------------------------------------------------------------------
  52. // Snapshot / Equal — pure logic tests (no DB)
  53. // ---------------------------------------------------------------------------
  54. func TestSnapshotEqual_Same(t *testing.T) {
  55. s := taskSnapshot{
  56. Status: TaskStatusInProgress,
  57. Progress: "50%",
  58. StartTime: 1000,
  59. FinishTime: 0,
  60. FailReason: "",
  61. ResultURL: "",
  62. Data: json.RawMessage(`{"key":"value"}`),
  63. }
  64. assert.True(t, s.Equal(s))
  65. }
  66. func TestSnapshotEqual_DifferentStatus(t *testing.T) {
  67. a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
  68. b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
  69. assert.False(t, a.Equal(b))
  70. }
  71. func TestSnapshotEqual_DifferentProgress(t *testing.T) {
  72. a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
  73. b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
  74. assert.False(t, a.Equal(b))
  75. }
  76. func TestSnapshotEqual_DifferentData(t *testing.T) {
  77. a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
  78. b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
  79. assert.False(t, a.Equal(b))
  80. }
  81. func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
  82. a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
  83. b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
  84. // bytes.Equal(nil, []byte{}) == true
  85. assert.True(t, a.Equal(b))
  86. }
  87. func TestSnapshot_Roundtrip(t *testing.T) {
  88. task := &Task{
  89. Status: TaskStatusInProgress,
  90. Progress: "42%",
  91. StartTime: 1234,
  92. FinishTime: 5678,
  93. FailReason: "timeout",
  94. PrivateData: TaskPrivateData{
  95. ResultURL: "https://example.com/result.mp4",
  96. },
  97. Data: json.RawMessage(`{"model":"test-model"}`),
  98. }
  99. snap := task.Snapshot()
  100. assert.Equal(t, task.Status, snap.Status)
  101. assert.Equal(t, task.Progress, snap.Progress)
  102. assert.Equal(t, task.StartTime, snap.StartTime)
  103. assert.Equal(t, task.FinishTime, snap.FinishTime)
  104. assert.Equal(t, task.FailReason, snap.FailReason)
  105. assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
  106. assert.JSONEq(t, string(task.Data), string(snap.Data))
  107. }
  108. // ---------------------------------------------------------------------------
  109. // UpdateWithStatus CAS — DB integration tests
  110. // ---------------------------------------------------------------------------
  111. func TestUpdateWithStatus_Win(t *testing.T) {
  112. truncateTables(t)
  113. task := &Task{
  114. TaskID: "task_cas_win",
  115. Status: TaskStatusInProgress,
  116. Progress: "50%",
  117. Data: json.RawMessage(`{}`),
  118. }
  119. insertTask(t, task)
  120. task.Status = TaskStatusSuccess
  121. task.Progress = "100%"
  122. won, err := task.UpdateWithStatus(TaskStatusInProgress)
  123. require.NoError(t, err)
  124. assert.True(t, won)
  125. var reloaded Task
  126. require.NoError(t, DB.First(&reloaded, task.ID).Error)
  127. assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
  128. assert.Equal(t, "100%", reloaded.Progress)
  129. }
  130. func TestUpdateWithStatus_Lose(t *testing.T) {
  131. truncateTables(t)
  132. task := &Task{
  133. TaskID: "task_cas_lose",
  134. Status: TaskStatusFailure,
  135. Data: json.RawMessage(`{}`),
  136. }
  137. insertTask(t, task)
  138. task.Status = TaskStatusSuccess
  139. won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
  140. require.NoError(t, err)
  141. assert.False(t, won)
  142. var reloaded Task
  143. require.NoError(t, DB.First(&reloaded, task.ID).Error)
  144. assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
  145. }
  146. func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
  147. truncateTables(t)
  148. task := &Task{
  149. TaskID: "task_cas_race",
  150. Status: TaskStatusInProgress,
  151. Quota: 1000,
  152. Data: json.RawMessage(`{}`),
  153. }
  154. insertTask(t, task)
  155. const goroutines = 5
  156. wins := make([]bool, goroutines)
  157. var wg sync.WaitGroup
  158. wg.Add(goroutines)
  159. for i := 0; i < goroutines; i++ {
  160. go func(idx int) {
  161. defer wg.Done()
  162. t := &Task{}
  163. *t = Task{
  164. ID: task.ID,
  165. TaskID: task.TaskID,
  166. Status: TaskStatusSuccess,
  167. Progress: "100%",
  168. Quota: task.Quota,
  169. Data: json.RawMessage(`{}`),
  170. }
  171. t.CreatedAt = task.CreatedAt
  172. t.UpdatedAt = time.Now().Unix()
  173. won, err := t.UpdateWithStatus(TaskStatusInProgress)
  174. if err == nil {
  175. wins[idx] = won
  176. }
  177. }(i)
  178. }
  179. wg.Wait()
  180. winCount := 0
  181. for _, w := range wins {
  182. if w {
  183. winCount++
  184. }
  185. }
  186. assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")
  187. }