task_cas_test.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. package model
  2. import (
  3. "encoding/json"
  4. "os"
  5. "sync"
  6. "testing"
  7. "time"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/glebarez/sqlite"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. "gorm.io/gorm"
  13. )
  14. func TestMain(m *testing.M) {
  15. db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
  16. if err != nil {
  17. panic("failed to open test db: " + err.Error())
  18. }
  19. DB = db
  20. LOG_DB = db
  21. common.UsingSQLite = true
  22. common.RedisEnabled = false
  23. common.BatchUpdateEnabled = false
  24. common.LogConsumeEnabled = true
  25. sqlDB, err := db.DB()
  26. if err != nil {
  27. panic("failed to get sql.DB: " + err.Error())
  28. }
  29. sqlDB.SetMaxOpenConns(1)
  30. if err := db.AutoMigrate(
  31. &Task{},
  32. &User{},
  33. &Token{},
  34. &Log{},
  35. &Channel{},
  36. &TopUp{},
  37. &SubscriptionPlan{},
  38. &SubscriptionOrder{},
  39. &UserSubscription{},
  40. ); err != nil {
  41. panic("failed to migrate: " + err.Error())
  42. }
  43. os.Exit(m.Run())
  44. }
  45. func truncateTables(t *testing.T) {
  46. t.Helper()
  47. t.Cleanup(func() {
  48. DB.Exec("DELETE FROM tasks")
  49. DB.Exec("DELETE FROM users")
  50. DB.Exec("DELETE FROM tokens")
  51. DB.Exec("DELETE FROM logs")
  52. DB.Exec("DELETE FROM channels")
  53. DB.Exec("DELETE FROM top_ups")
  54. DB.Exec("DELETE FROM subscription_orders")
  55. DB.Exec("DELETE FROM subscription_plans")
  56. DB.Exec("DELETE FROM user_subscriptions")
  57. })
  58. }
  59. func insertTask(t *testing.T, task *Task) {
  60. t.Helper()
  61. task.CreatedAt = time.Now().Unix()
  62. task.UpdatedAt = time.Now().Unix()
  63. require.NoError(t, DB.Create(task).Error)
  64. }
  65. // ---------------------------------------------------------------------------
  66. // Snapshot / Equal — pure logic tests (no DB)
  67. // ---------------------------------------------------------------------------
  68. func TestSnapshotEqual_Same(t *testing.T) {
  69. s := taskSnapshot{
  70. Status: TaskStatusInProgress,
  71. Progress: "50%",
  72. StartTime: 1000,
  73. FinishTime: 0,
  74. FailReason: "",
  75. ResultURL: "",
  76. Data: json.RawMessage(`{"key":"value"}`),
  77. }
  78. assert.True(t, s.Equal(s))
  79. }
  80. func TestSnapshotEqual_DifferentStatus(t *testing.T) {
  81. a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
  82. b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
  83. assert.False(t, a.Equal(b))
  84. }
  85. func TestSnapshotEqual_DifferentProgress(t *testing.T) {
  86. a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
  87. b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
  88. assert.False(t, a.Equal(b))
  89. }
  90. func TestSnapshotEqual_DifferentData(t *testing.T) {
  91. a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
  92. b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
  93. assert.False(t, a.Equal(b))
  94. }
  95. func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
  96. a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
  97. b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
  98. // bytes.Equal(nil, []byte{}) == true
  99. assert.True(t, a.Equal(b))
  100. }
  101. func TestSnapshot_Roundtrip(t *testing.T) {
  102. task := &Task{
  103. Status: TaskStatusInProgress,
  104. Progress: "42%",
  105. StartTime: 1234,
  106. FinishTime: 5678,
  107. FailReason: "timeout",
  108. PrivateData: TaskPrivateData{
  109. ResultURL: "https://example.com/result.mp4",
  110. },
  111. Data: json.RawMessage(`{"model":"test-model"}`),
  112. }
  113. snap := task.Snapshot()
  114. assert.Equal(t, task.Status, snap.Status)
  115. assert.Equal(t, task.Progress, snap.Progress)
  116. assert.Equal(t, task.StartTime, snap.StartTime)
  117. assert.Equal(t, task.FinishTime, snap.FinishTime)
  118. assert.Equal(t, task.FailReason, snap.FailReason)
  119. assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
  120. assert.JSONEq(t, string(task.Data), string(snap.Data))
  121. }
  122. // ---------------------------------------------------------------------------
  123. // UpdateWithStatus CAS — DB integration tests
  124. // ---------------------------------------------------------------------------
  125. func TestUpdateWithStatus_Win(t *testing.T) {
  126. truncateTables(t)
  127. task := &Task{
  128. TaskID: "task_cas_win",
  129. Status: TaskStatusInProgress,
  130. Progress: "50%",
  131. Data: json.RawMessage(`{}`),
  132. }
  133. insertTask(t, task)
  134. task.Status = TaskStatusSuccess
  135. task.Progress = "100%"
  136. won, err := task.UpdateWithStatus(TaskStatusInProgress)
  137. require.NoError(t, err)
  138. assert.True(t, won)
  139. var reloaded Task
  140. require.NoError(t, DB.First(&reloaded, task.ID).Error)
  141. assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
  142. assert.Equal(t, "100%", reloaded.Progress)
  143. }
  144. func TestUpdateWithStatus_Lose(t *testing.T) {
  145. truncateTables(t)
  146. task := &Task{
  147. TaskID: "task_cas_lose",
  148. Status: TaskStatusFailure,
  149. Data: json.RawMessage(`{}`),
  150. }
  151. insertTask(t, task)
  152. task.Status = TaskStatusSuccess
  153. won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
  154. require.NoError(t, err)
  155. assert.False(t, won)
  156. var reloaded Task
  157. require.NoError(t, DB.First(&reloaded, task.ID).Error)
  158. assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
  159. }
  160. func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
  161. truncateTables(t)
  162. task := &Task{
  163. TaskID: "task_cas_race",
  164. Status: TaskStatusInProgress,
  165. Quota: 1000,
  166. Data: json.RawMessage(`{}`),
  167. }
  168. insertTask(t, task)
  169. const goroutines = 5
  170. wins := make([]bool, goroutines)
  171. var wg sync.WaitGroup
  172. wg.Add(goroutines)
  173. for i := 0; i < goroutines; i++ {
  174. go func(idx int) {
  175. defer wg.Done()
  176. t := &Task{}
  177. *t = Task{
  178. ID: task.ID,
  179. TaskID: task.TaskID,
  180. Status: TaskStatusSuccess,
  181. Progress: "100%",
  182. Quota: task.Quota,
  183. Data: json.RawMessage(`{}`),
  184. }
  185. t.CreatedAt = task.CreatedAt
  186. t.UpdatedAt = time.Now().Unix()
  187. won, err := t.UpdateWithStatus(TaskStatusInProgress)
  188. if err == nil {
  189. wins[idx] = won
  190. }
  191. }(i)
  192. }
  193. wg.Wait()
  194. winCount := 0
  195. for _, w := range wins {
  196. if w {
  197. winCount++
  198. }
  199. }
  200. assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")
  201. }