// Source file: task_billing_test.go

  1. package service
  2. import (
  3. "context"
  4. "encoding/json"
  5. "os"
  6. "testing"
  7. "time"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/model"
  10. "github.com/glebarez/sqlite"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. "gorm.io/gorm"
  14. )
  15. func TestMain(m *testing.M) {
  16. db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
  17. if err != nil {
  18. panic("failed to open test db: " + err.Error())
  19. }
  20. sqlDB, err := db.DB()
  21. if err != nil {
  22. panic("failed to get sql.DB: " + err.Error())
  23. }
  24. sqlDB.SetMaxOpenConns(1)
  25. model.DB = db
  26. model.LOG_DB = db
  27. common.UsingSQLite = true
  28. common.RedisEnabled = false
  29. common.BatchUpdateEnabled = false
  30. common.LogConsumeEnabled = true
  31. if err := db.AutoMigrate(
  32. &model.Task{},
  33. &model.User{},
  34. &model.Token{},
  35. &model.Log{},
  36. &model.Channel{},
  37. &model.UserSubscription{},
  38. ); err != nil {
  39. panic("failed to migrate: " + err.Error())
  40. }
  41. os.Exit(m.Run())
  42. }
  43. // ---------------------------------------------------------------------------
  44. // Seed helpers
  45. // ---------------------------------------------------------------------------
  46. func truncate(t *testing.T) {
  47. t.Helper()
  48. t.Cleanup(func() {
  49. model.DB.Exec("DELETE FROM tasks")
  50. model.DB.Exec("DELETE FROM users")
  51. model.DB.Exec("DELETE FROM tokens")
  52. model.DB.Exec("DELETE FROM logs")
  53. model.DB.Exec("DELETE FROM channels")
  54. model.DB.Exec("DELETE FROM user_subscriptions")
  55. })
  56. }
  57. func seedUser(t *testing.T, id int, quota int) {
  58. t.Helper()
  59. user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled}
  60. require.NoError(t, model.DB.Create(user).Error)
  61. }
  62. func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) {
  63. t.Helper()
  64. token := &model.Token{
  65. Id: id,
  66. UserId: userId,
  67. Key: key,
  68. Name: "test_token",
  69. Status: common.TokenStatusEnabled,
  70. RemainQuota: remainQuota,
  71. UsedQuota: 0,
  72. }
  73. require.NoError(t, model.DB.Create(token).Error)
  74. }
  75. func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) {
  76. t.Helper()
  77. sub := &model.UserSubscription{
  78. Id: id,
  79. UserId: userId,
  80. AmountTotal: amountTotal,
  81. AmountUsed: amountUsed,
  82. Status: "active",
  83. StartTime: time.Now().Unix(),
  84. EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(),
  85. }
  86. require.NoError(t, model.DB.Create(sub).Error)
  87. }
  88. func seedChannel(t *testing.T, id int) {
  89. t.Helper()
  90. ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled}
  91. require.NoError(t, model.DB.Create(ch).Error)
  92. }
  93. func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task {
  94. return &model.Task{
  95. TaskID: "task_" + time.Now().Format("150405.000"),
  96. UserId: userId,
  97. ChannelId: channelId,
  98. Quota: quota,
  99. Status: model.TaskStatus(model.TaskStatusInProgress),
  100. Group: "default",
  101. Data: json.RawMessage(`{}`),
  102. CreatedAt: time.Now().Unix(),
  103. UpdatedAt: time.Now().Unix(),
  104. Properties: model.Properties{
  105. OriginModelName: "test-model",
  106. },
  107. PrivateData: model.TaskPrivateData{
  108. BillingSource: billingSource,
  109. SubscriptionId: subscriptionId,
  110. TokenId: tokenId,
  111. BillingContext: &model.TaskBillingContext{
  112. ModelPrice: 0.02,
  113. GroupRatio: 1.0,
  114. ModelName: "test-model",
  115. },
  116. },
  117. }
  118. }
  119. // ---------------------------------------------------------------------------
  120. // Read-back helpers
  121. // ---------------------------------------------------------------------------
  122. func getUserQuota(t *testing.T, id int) int {
  123. t.Helper()
  124. var user model.User
  125. require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error)
  126. return user.Quota
  127. }
  128. func getTokenRemainQuota(t *testing.T, id int) int {
  129. t.Helper()
  130. var token model.Token
  131. require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error)
  132. return token.RemainQuota
  133. }
  134. func getTokenUsedQuota(t *testing.T, id int) int {
  135. t.Helper()
  136. var token model.Token
  137. require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error)
  138. return token.UsedQuota
  139. }
  140. func getSubscriptionUsed(t *testing.T, id int) int64 {
  141. t.Helper()
  142. var sub model.UserSubscription
  143. require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error)
  144. return sub.AmountUsed
  145. }
  146. func getLastLog(t *testing.T) *model.Log {
  147. t.Helper()
  148. var log model.Log
  149. err := model.LOG_DB.Order("id desc").First(&log).Error
  150. if err != nil {
  151. return nil
  152. }
  153. return &log
  154. }
  155. func countLogs(t *testing.T) int64 {
  156. t.Helper()
  157. var count int64
  158. model.LOG_DB.Model(&model.Log{}).Count(&count)
  159. return count
  160. }
  161. // ===========================================================================
  162. // RefundTaskQuota tests
  163. // ===========================================================================
  164. func TestRefundTaskQuota_Wallet(t *testing.T) {
  165. truncate(t)
  166. ctx := context.Background()
  167. const userID, tokenID, channelID = 1, 1, 1
  168. const initQuota, preConsumed = 10000, 3000
  169. const tokenRemain = 5000
  170. seedUser(t, userID, initQuota)
  171. seedToken(t, tokenID, userID, "sk-test-key", tokenRemain)
  172. seedChannel(t, channelID)
  173. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  174. RefundTaskQuota(ctx, task, "task failed: upstream error")
  175. // User quota should increase by preConsumed
  176. assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
  177. // Token remain_quota should increase, used_quota should decrease
  178. assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
  179. assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID))
  180. // A refund log should be created
  181. log := getLastLog(t)
  182. require.NotNil(t, log)
  183. assert.Equal(t, model.LogTypeRefund, log.Type)
  184. assert.Equal(t, preConsumed, log.Quota)
  185. assert.Equal(t, "test-model", log.ModelName)
  186. }
  187. func TestRefundTaskQuota_Subscription(t *testing.T) {
  188. truncate(t)
  189. ctx := context.Background()
  190. const userID, tokenID, channelID, subID = 2, 2, 2, 1
  191. const preConsumed = 2000
  192. const subTotal, subUsed int64 = 100000, 50000
  193. const tokenRemain = 8000
  194. seedUser(t, userID, 0)
  195. seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain)
  196. seedChannel(t, channelID)
  197. seedSubscription(t, subID, userID, subTotal, subUsed)
  198. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
  199. RefundTaskQuota(ctx, task, "subscription task failed")
  200. // Subscription used should decrease by preConsumed
  201. assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID))
  202. // Token should also be refunded
  203. assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
  204. log := getLastLog(t)
  205. require.NotNil(t, log)
  206. assert.Equal(t, model.LogTypeRefund, log.Type)
  207. }
  208. func TestRefundTaskQuota_ZeroQuota(t *testing.T) {
  209. truncate(t)
  210. ctx := context.Background()
  211. const userID = 3
  212. seedUser(t, userID, 5000)
  213. task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0)
  214. RefundTaskQuota(ctx, task, "zero quota task")
  215. // No change to user quota
  216. assert.Equal(t, 5000, getUserQuota(t, userID))
  217. // No log created
  218. assert.Equal(t, int64(0), countLogs(t))
  219. }
  220. func TestRefundTaskQuota_NoToken(t *testing.T) {
  221. truncate(t)
  222. ctx := context.Background()
  223. const userID, channelID = 4, 4
  224. const initQuota, preConsumed = 10000, 1500
  225. seedUser(t, userID, initQuota)
  226. seedChannel(t, channelID)
  227. task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0
  228. RefundTaskQuota(ctx, task, "no token task failed")
  229. // User quota refunded
  230. assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
  231. // Log created
  232. log := getLastLog(t)
  233. require.NotNil(t, log)
  234. assert.Equal(t, model.LogTypeRefund, log.Type)
  235. }
  236. // ===========================================================================
  237. // RecalculateTaskQuota tests
  238. // ===========================================================================
  239. func TestRecalculate_PositiveDelta(t *testing.T) {
  240. truncate(t)
  241. ctx := context.Background()
  242. const userID, tokenID, channelID = 10, 10, 10
  243. const initQuota, preConsumed = 10000, 2000
  244. const actualQuota = 3000 // under-charged by 1000
  245. const tokenRemain = 5000
  246. seedUser(t, userID, initQuota)
  247. seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain)
  248. seedChannel(t, channelID)
  249. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  250. RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
  251. // User quota should decrease by the delta (1000 additional charge)
  252. assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID))
  253. // Token should also be charged the delta
  254. assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID))
  255. // task.Quota should be updated to actualQuota
  256. assert.Equal(t, actualQuota, task.Quota)
  257. // Log type should be Consume (additional charge)
  258. log := getLastLog(t)
  259. require.NotNil(t, log)
  260. assert.Equal(t, model.LogTypeConsume, log.Type)
  261. assert.Equal(t, actualQuota-preConsumed, log.Quota)
  262. }
  263. func TestRecalculate_NegativeDelta(t *testing.T) {
  264. truncate(t)
  265. ctx := context.Background()
  266. const userID, tokenID, channelID = 11, 11, 11
  267. const initQuota, preConsumed = 10000, 5000
  268. const actualQuota = 3000 // over-charged by 2000
  269. const tokenRemain = 5000
  270. seedUser(t, userID, initQuota)
  271. seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain)
  272. seedChannel(t, channelID)
  273. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  274. RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
  275. // User quota should increase by abs(delta) = 2000 (refund overpayment)
  276. assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
  277. // Token should be refunded the difference
  278. assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
  279. // task.Quota updated
  280. assert.Equal(t, actualQuota, task.Quota)
  281. // Log type should be Refund
  282. log := getLastLog(t)
  283. require.NotNil(t, log)
  284. assert.Equal(t, model.LogTypeRefund, log.Type)
  285. assert.Equal(t, preConsumed-actualQuota, log.Quota)
  286. }
  287. func TestRecalculate_ZeroDelta(t *testing.T) {
  288. truncate(t)
  289. ctx := context.Background()
  290. const userID = 12
  291. const initQuota, preConsumed = 10000, 3000
  292. seedUser(t, userID, initQuota)
  293. task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0)
  294. RecalculateTaskQuota(ctx, task, preConsumed, "exact match")
  295. // No change to user quota
  296. assert.Equal(t, initQuota, getUserQuota(t, userID))
  297. // No log created (delta is zero)
  298. assert.Equal(t, int64(0), countLogs(t))
  299. }
  300. func TestRecalculate_ActualQuotaZero(t *testing.T) {
  301. truncate(t)
  302. ctx := context.Background()
  303. const userID = 13
  304. const initQuota = 10000
  305. seedUser(t, userID, initQuota)
  306. task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0)
  307. RecalculateTaskQuota(ctx, task, 0, "zero actual")
  308. // No change (early return)
  309. assert.Equal(t, initQuota, getUserQuota(t, userID))
  310. assert.Equal(t, int64(0), countLogs(t))
  311. }
  312. func TestRecalculate_Subscription_NegativeDelta(t *testing.T) {
  313. truncate(t)
  314. ctx := context.Background()
  315. const userID, tokenID, channelID, subID = 14, 14, 14, 2
  316. const preConsumed = 5000
  317. const actualQuota = 2000 // over-charged by 3000
  318. const subTotal, subUsed int64 = 100000, 50000
  319. const tokenRemain = 8000
  320. seedUser(t, userID, 0)
  321. seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain)
  322. seedChannel(t, channelID)
  323. seedSubscription(t, subID, userID, subTotal, subUsed)
  324. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
  325. RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge")
  326. // Subscription used should decrease by delta (refund 3000)
  327. assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID))
  328. // Token refunded
  329. assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
  330. assert.Equal(t, actualQuota, task.Quota)
  331. log := getLastLog(t)
  332. require.NotNil(t, log)
  333. assert.Equal(t, model.LogTypeRefund, log.Type)
  334. }
  335. // ===========================================================================
  336. // CAS + Billing integration tests
  337. // Simulates the flow in updateVideoSingleTask (service/task_polling.go)
  338. // ===========================================================================
  339. // simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask.
  340. // It takes a persisted task (already in DB), applies the new status, and performs
  341. // the conditional update + billing exactly as the polling loop does.
  342. func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) {
  343. snap := task.Snapshot()
  344. shouldRefund := false
  345. shouldSettle := false
  346. quota := task.Quota
  347. task.Status = newStatus
  348. switch string(newStatus) {
  349. case model.TaskStatusSuccess:
  350. task.Progress = "100%"
  351. task.FinishTime = 9999
  352. shouldSettle = true
  353. case model.TaskStatusFailure:
  354. task.Progress = "100%"
  355. task.FinishTime = 9999
  356. task.FailReason = "upstream error"
  357. if quota != 0 {
  358. shouldRefund = true
  359. }
  360. default:
  361. task.Progress = "50%"
  362. }
  363. isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure)
  364. if isDone && snap.Status != task.Status {
  365. won, err := task.UpdateWithStatus(snap.Status)
  366. if err != nil {
  367. shouldRefund = false
  368. shouldSettle = false
  369. } else if !won {
  370. shouldRefund = false
  371. shouldSettle = false
  372. }
  373. } else if !snap.Equal(task.Snapshot()) {
  374. _, _ = task.UpdateWithStatus(snap.Status)
  375. }
  376. if shouldSettle && actualQuota > 0 {
  377. RecalculateTaskQuota(ctx, task, actualQuota, "test settle")
  378. }
  379. if shouldRefund {
  380. RefundTaskQuota(ctx, task, task.FailReason)
  381. }
  382. }
  383. func TestCASGuardedRefund_Win(t *testing.T) {
  384. truncate(t)
  385. ctx := context.Background()
  386. const userID, tokenID, channelID = 20, 20, 20
  387. const initQuota, preConsumed = 10000, 4000
  388. const tokenRemain = 6000
  389. seedUser(t, userID, initQuota)
  390. seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain)
  391. seedChannel(t, channelID)
  392. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  393. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  394. require.NoError(t, model.DB.Create(task).Error)
  395. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
  396. // CAS wins: task in DB should now be FAILURE
  397. var reloaded model.Task
  398. require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
  399. assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status)
  400. // Refund should have happened
  401. assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
  402. assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
  403. log := getLastLog(t)
  404. require.NotNil(t, log)
  405. assert.Equal(t, model.LogTypeRefund, log.Type)
  406. }
  407. func TestCASGuardedRefund_Lose(t *testing.T) {
  408. truncate(t)
  409. ctx := context.Background()
  410. const userID, tokenID, channelID = 21, 21, 21
  411. const initQuota, preConsumed = 10000, 4000
  412. const tokenRemain = 6000
  413. seedUser(t, userID, initQuota)
  414. seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain)
  415. seedChannel(t, channelID)
  416. // Create task with IN_PROGRESS in DB
  417. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  418. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  419. require.NoError(t, model.DB.Create(task).Error)
  420. // Simulate another process already transitioning to FAILURE
  421. model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure)
  422. // Our process still has the old in-memory state (IN_PROGRESS) and tries to transition
  423. // task.Status is still IN_PROGRESS in the snapshot
  424. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
  425. // CAS lost: user quota should NOT change (no double refund)
  426. assert.Equal(t, initQuota, getUserQuota(t, userID))
  427. assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
  428. // No billing log should be created
  429. assert.Equal(t, int64(0), countLogs(t))
  430. }
  431. func TestCASGuardedSettle_Win(t *testing.T) {
  432. truncate(t)
  433. ctx := context.Background()
  434. const userID, tokenID, channelID = 22, 22, 22
  435. const initQuota, preConsumed = 10000, 5000
  436. const actualQuota = 3000 // over-charged, should get partial refund
  437. const tokenRemain = 8000
  438. seedUser(t, userID, initQuota)
  439. seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain)
  440. seedChannel(t, channelID)
  441. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  442. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  443. require.NoError(t, model.DB.Create(task).Error)
  444. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota)
  445. // CAS wins: task should be SUCCESS
  446. var reloaded model.Task
  447. require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
  448. assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status)
  449. // Settlement should refund the over-charge (5000 - 3000 = 2000 back to user)
  450. assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
  451. assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
  452. // task.Quota should be updated to actualQuota
  453. assert.Equal(t, actualQuota, task.Quota)
  454. }
  455. func TestNonTerminalUpdate_NoBilling(t *testing.T) {
  456. truncate(t)
  457. ctx := context.Background()
  458. const userID, channelID = 23, 23
  459. const initQuota, preConsumed = 10000, 3000
  460. seedUser(t, userID, initQuota)
  461. seedChannel(t, channelID)
  462. task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0)
  463. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  464. task.Progress = "20%"
  465. require.NoError(t, model.DB.Create(task).Error)
  466. // Simulate a non-terminal poll update (still IN_PROGRESS, progress changed)
  467. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0)
  468. // User quota should NOT change
  469. assert.Equal(t, initQuota, getUserQuota(t, userID))
  470. // No billing log
  471. assert.Equal(t, int64(0), countLogs(t))
  472. // Task progress should be updated in DB
  473. var reloaded model.Task
  474. require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
  475. assert.Equal(t, "50%", reloaded.Progress)
  476. }