task_billing_test.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716
  1. package service
  2. import (
  3. "context"
  4. "encoding/json"
  5. "net/http"
  6. "os"
  7. "testing"
  8. "time"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/model"
  11. relaycommon "github.com/QuantumNous/new-api/relay/common"
  12. "github.com/glebarez/sqlite"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. "gorm.io/gorm"
  16. )
  17. func TestMain(m *testing.M) {
  18. db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
  19. if err != nil {
  20. panic("failed to open test db: " + err.Error())
  21. }
  22. sqlDB, err := db.DB()
  23. if err != nil {
  24. panic("failed to get sql.DB: " + err.Error())
  25. }
  26. sqlDB.SetMaxOpenConns(1)
  27. model.DB = db
  28. model.LOG_DB = db
  29. common.UsingSQLite = true
  30. common.RedisEnabled = false
  31. common.BatchUpdateEnabled = false
  32. common.LogConsumeEnabled = true
  33. if err := db.AutoMigrate(
  34. &model.Task{},
  35. &model.User{},
  36. &model.Token{},
  37. &model.Log{},
  38. &model.Channel{},
  39. &model.TopUp{},
  40. &model.UserSubscription{},
  41. ); err != nil {
  42. panic("failed to migrate: " + err.Error())
  43. }
  44. os.Exit(m.Run())
  45. }
  46. // ---------------------------------------------------------------------------
  47. // Seed helpers
  48. // ---------------------------------------------------------------------------
  49. func truncate(t *testing.T) {
  50. t.Helper()
  51. t.Cleanup(func() {
  52. model.DB.Exec("DELETE FROM tasks")
  53. model.DB.Exec("DELETE FROM users")
  54. model.DB.Exec("DELETE FROM tokens")
  55. model.DB.Exec("DELETE FROM logs")
  56. model.DB.Exec("DELETE FROM channels")
  57. model.DB.Exec("DELETE FROM top_ups")
  58. model.DB.Exec("DELETE FROM user_subscriptions")
  59. })
  60. }
  61. func seedUser(t *testing.T, id int, quota int) {
  62. t.Helper()
  63. user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled}
  64. require.NoError(t, model.DB.Create(user).Error)
  65. }
  66. func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) {
  67. t.Helper()
  68. token := &model.Token{
  69. Id: id,
  70. UserId: userId,
  71. Key: key,
  72. Name: "test_token",
  73. Status: common.TokenStatusEnabled,
  74. RemainQuota: remainQuota,
  75. UsedQuota: 0,
  76. }
  77. require.NoError(t, model.DB.Create(token).Error)
  78. }
  79. func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) {
  80. t.Helper()
  81. sub := &model.UserSubscription{
  82. Id: id,
  83. UserId: userId,
  84. AmountTotal: amountTotal,
  85. AmountUsed: amountUsed,
  86. Status: "active",
  87. StartTime: time.Now().Unix(),
  88. EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(),
  89. }
  90. require.NoError(t, model.DB.Create(sub).Error)
  91. }
  92. func seedChannel(t *testing.T, id int) {
  93. t.Helper()
  94. ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled}
  95. require.NoError(t, model.DB.Create(ch).Error)
  96. }
  97. func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task {
  98. return &model.Task{
  99. TaskID: "task_" + time.Now().Format("150405.000"),
  100. UserId: userId,
  101. ChannelId: channelId,
  102. Quota: quota,
  103. Status: model.TaskStatus(model.TaskStatusInProgress),
  104. Group: "default",
  105. Data: json.RawMessage(`{}`),
  106. CreatedAt: time.Now().Unix(),
  107. UpdatedAt: time.Now().Unix(),
  108. Properties: model.Properties{
  109. OriginModelName: "test-model",
  110. },
  111. PrivateData: model.TaskPrivateData{
  112. BillingSource: billingSource,
  113. SubscriptionId: subscriptionId,
  114. TokenId: tokenId,
  115. BillingContext: &model.TaskBillingContext{
  116. ModelPrice: 0.02,
  117. GroupRatio: 1.0,
  118. OriginModelName: "test-model",
  119. },
  120. },
  121. }
  122. }
  123. // ---------------------------------------------------------------------------
  124. // Read-back helpers
  125. // ---------------------------------------------------------------------------
  126. func getUserQuota(t *testing.T, id int) int {
  127. t.Helper()
  128. var user model.User
  129. require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error)
  130. return user.Quota
  131. }
  132. func getTokenRemainQuota(t *testing.T, id int) int {
  133. t.Helper()
  134. var token model.Token
  135. require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error)
  136. return token.RemainQuota
  137. }
  138. func getTokenUsedQuota(t *testing.T, id int) int {
  139. t.Helper()
  140. var token model.Token
  141. require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error)
  142. return token.UsedQuota
  143. }
  144. func getSubscriptionUsed(t *testing.T, id int) int64 {
  145. t.Helper()
  146. var sub model.UserSubscription
  147. require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error)
  148. return sub.AmountUsed
  149. }
  150. func getLastLog(t *testing.T) *model.Log {
  151. t.Helper()
  152. var log model.Log
  153. err := model.LOG_DB.Order("id desc").First(&log).Error
  154. if err != nil {
  155. return nil
  156. }
  157. return &log
  158. }
  159. func countLogs(t *testing.T) int64 {
  160. t.Helper()
  161. var count int64
  162. model.LOG_DB.Model(&model.Log{}).Count(&count)
  163. return count
  164. }
  165. // ===========================================================================
  166. // RefundTaskQuota tests
  167. // ===========================================================================
  168. func TestRefundTaskQuota_Wallet(t *testing.T) {
  169. truncate(t)
  170. ctx := context.Background()
  171. const userID, tokenID, channelID = 1, 1, 1
  172. const initQuota, preConsumed = 10000, 3000
  173. const tokenRemain = 5000
  174. seedUser(t, userID, initQuota)
  175. seedToken(t, tokenID, userID, "sk-test-key", tokenRemain)
  176. seedChannel(t, channelID)
  177. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  178. RefundTaskQuota(ctx, task, "task failed: upstream error")
  179. // User quota should increase by preConsumed
  180. assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
  181. // Token remain_quota should increase, used_quota should decrease
  182. assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
  183. assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID))
  184. // A refund log should be created
  185. log := getLastLog(t)
  186. require.NotNil(t, log)
  187. assert.Equal(t, model.LogTypeRefund, log.Type)
  188. assert.Equal(t, preConsumed, log.Quota)
  189. assert.Equal(t, "test-model", log.ModelName)
  190. }
  191. func TestRefundTaskQuota_Subscription(t *testing.T) {
  192. truncate(t)
  193. ctx := context.Background()
  194. const userID, tokenID, channelID, subID = 2, 2, 2, 1
  195. const preConsumed = 2000
  196. const subTotal, subUsed int64 = 100000, 50000
  197. const tokenRemain = 8000
  198. seedUser(t, userID, 0)
  199. seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain)
  200. seedChannel(t, channelID)
  201. seedSubscription(t, subID, userID, subTotal, subUsed)
  202. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
  203. RefundTaskQuota(ctx, task, "subscription task failed")
  204. // Subscription used should decrease by preConsumed
  205. assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID))
  206. // Token should also be refunded
  207. assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
  208. log := getLastLog(t)
  209. require.NotNil(t, log)
  210. assert.Equal(t, model.LogTypeRefund, log.Type)
  211. }
  212. func TestRefundTaskQuota_ZeroQuota(t *testing.T) {
  213. truncate(t)
  214. ctx := context.Background()
  215. const userID = 3
  216. seedUser(t, userID, 5000)
  217. task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0)
  218. RefundTaskQuota(ctx, task, "zero quota task")
  219. // No change to user quota
  220. assert.Equal(t, 5000, getUserQuota(t, userID))
  221. // No log created
  222. assert.Equal(t, int64(0), countLogs(t))
  223. }
  224. func TestRefundTaskQuota_NoToken(t *testing.T) {
  225. truncate(t)
  226. ctx := context.Background()
  227. const userID, channelID = 4, 4
  228. const initQuota, preConsumed = 10000, 1500
  229. seedUser(t, userID, initQuota)
  230. seedChannel(t, channelID)
  231. task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0
  232. RefundTaskQuota(ctx, task, "no token task failed")
  233. // User quota refunded
  234. assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
  235. // Log created
  236. log := getLastLog(t)
  237. require.NotNil(t, log)
  238. assert.Equal(t, model.LogTypeRefund, log.Type)
  239. }
  240. // ===========================================================================
  241. // RecalculateTaskQuota tests
  242. // ===========================================================================
  243. func TestRecalculate_PositiveDelta(t *testing.T) {
  244. truncate(t)
  245. ctx := context.Background()
  246. const userID, tokenID, channelID = 10, 10, 10
  247. const initQuota, preConsumed = 10000, 2000
  248. const actualQuota = 3000 // under-charged by 1000
  249. const tokenRemain = 5000
  250. seedUser(t, userID, initQuota)
  251. seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain)
  252. seedChannel(t, channelID)
  253. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  254. RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
  255. // User quota should decrease by the delta (1000 additional charge)
  256. assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID))
  257. // Token should also be charged the delta
  258. assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID))
  259. // task.Quota should be updated to actualQuota
  260. assert.Equal(t, actualQuota, task.Quota)
  261. // Log type should be Consume (additional charge)
  262. log := getLastLog(t)
  263. require.NotNil(t, log)
  264. assert.Equal(t, model.LogTypeConsume, log.Type)
  265. assert.Equal(t, actualQuota-preConsumed, log.Quota)
  266. }
  267. func TestRecalculate_NegativeDelta(t *testing.T) {
  268. truncate(t)
  269. ctx := context.Background()
  270. const userID, tokenID, channelID = 11, 11, 11
  271. const initQuota, preConsumed = 10000, 5000
  272. const actualQuota = 3000 // over-charged by 2000
  273. const tokenRemain = 5000
  274. seedUser(t, userID, initQuota)
  275. seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain)
  276. seedChannel(t, channelID)
  277. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  278. RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
  279. // User quota should increase by abs(delta) = 2000 (refund overpayment)
  280. assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
  281. // Token should be refunded the difference
  282. assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
  283. // task.Quota updated
  284. assert.Equal(t, actualQuota, task.Quota)
  285. // Log type should be Refund
  286. log := getLastLog(t)
  287. require.NotNil(t, log)
  288. assert.Equal(t, model.LogTypeRefund, log.Type)
  289. assert.Equal(t, preConsumed-actualQuota, log.Quota)
  290. }
  291. func TestRecalculate_ZeroDelta(t *testing.T) {
  292. truncate(t)
  293. ctx := context.Background()
  294. const userID = 12
  295. const initQuota, preConsumed = 10000, 3000
  296. seedUser(t, userID, initQuota)
  297. task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0)
  298. RecalculateTaskQuota(ctx, task, preConsumed, "exact match")
  299. // No change to user quota
  300. assert.Equal(t, initQuota, getUserQuota(t, userID))
  301. // No log created (delta is zero)
  302. assert.Equal(t, int64(0), countLogs(t))
  303. }
  304. func TestRecalculate_ActualQuotaZero(t *testing.T) {
  305. truncate(t)
  306. ctx := context.Background()
  307. const userID = 13
  308. const initQuota = 10000
  309. seedUser(t, userID, initQuota)
  310. task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0)
  311. RecalculateTaskQuota(ctx, task, 0, "zero actual")
  312. // No change (early return)
  313. assert.Equal(t, initQuota, getUserQuota(t, userID))
  314. assert.Equal(t, int64(0), countLogs(t))
  315. }
  316. func TestRecalculate_Subscription_NegativeDelta(t *testing.T) {
  317. truncate(t)
  318. ctx := context.Background()
  319. const userID, tokenID, channelID, subID = 14, 14, 14, 2
  320. const preConsumed = 5000
  321. const actualQuota = 2000 // over-charged by 3000
  322. const subTotal, subUsed int64 = 100000, 50000
  323. const tokenRemain = 8000
  324. seedUser(t, userID, 0)
  325. seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain)
  326. seedChannel(t, channelID)
  327. seedSubscription(t, subID, userID, subTotal, subUsed)
  328. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
  329. RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge")
  330. // Subscription used should decrease by delta (refund 3000)
  331. assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID))
  332. // Token refunded
  333. assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
  334. assert.Equal(t, actualQuota, task.Quota)
  335. log := getLastLog(t)
  336. require.NotNil(t, log)
  337. assert.Equal(t, model.LogTypeRefund, log.Type)
  338. }
  339. // ===========================================================================
  340. // CAS + Billing integration tests
  341. // Simulates the flow in updateVideoSingleTask (service/task_polling.go)
  342. // ===========================================================================
  343. // simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask.
  344. // It takes a persisted task (already in DB), applies the new status, and performs
  345. // the conditional update + billing exactly as the polling loop does.
  346. func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) {
  347. snap := task.Snapshot()
  348. shouldRefund := false
  349. shouldSettle := false
  350. quota := task.Quota
  351. task.Status = newStatus
  352. switch string(newStatus) {
  353. case model.TaskStatusSuccess:
  354. task.Progress = "100%"
  355. task.FinishTime = 9999
  356. shouldSettle = true
  357. case model.TaskStatusFailure:
  358. task.Progress = "100%"
  359. task.FinishTime = 9999
  360. task.FailReason = "upstream error"
  361. if quota != 0 {
  362. shouldRefund = true
  363. }
  364. default:
  365. task.Progress = "50%"
  366. }
  367. isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure)
  368. if isDone && snap.Status != task.Status {
  369. won, err := task.UpdateWithStatus(snap.Status)
  370. if err != nil {
  371. shouldRefund = false
  372. shouldSettle = false
  373. } else if !won {
  374. shouldRefund = false
  375. shouldSettle = false
  376. }
  377. } else if !snap.Equal(task.Snapshot()) {
  378. _, _ = task.UpdateWithStatus(snap.Status)
  379. }
  380. if shouldSettle && actualQuota > 0 {
  381. RecalculateTaskQuota(ctx, task, actualQuota, "test settle")
  382. }
  383. if shouldRefund {
  384. RefundTaskQuota(ctx, task, task.FailReason)
  385. }
  386. }
  387. func TestCASGuardedRefund_Win(t *testing.T) {
  388. truncate(t)
  389. ctx := context.Background()
  390. const userID, tokenID, channelID = 20, 20, 20
  391. const initQuota, preConsumed = 10000, 4000
  392. const tokenRemain = 6000
  393. seedUser(t, userID, initQuota)
  394. seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain)
  395. seedChannel(t, channelID)
  396. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  397. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  398. require.NoError(t, model.DB.Create(task).Error)
  399. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
  400. // CAS wins: task in DB should now be FAILURE
  401. var reloaded model.Task
  402. require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
  403. assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status)
  404. // Refund should have happened
  405. assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
  406. assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
  407. log := getLastLog(t)
  408. require.NotNil(t, log)
  409. assert.Equal(t, model.LogTypeRefund, log.Type)
  410. }
  411. func TestCASGuardedRefund_Lose(t *testing.T) {
  412. truncate(t)
  413. ctx := context.Background()
  414. const userID, tokenID, channelID = 21, 21, 21
  415. const initQuota, preConsumed = 10000, 4000
  416. const tokenRemain = 6000
  417. seedUser(t, userID, initQuota)
  418. seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain)
  419. seedChannel(t, channelID)
  420. // Create task with IN_PROGRESS in DB
  421. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  422. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  423. require.NoError(t, model.DB.Create(task).Error)
  424. // Simulate another process already transitioning to FAILURE
  425. model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure)
  426. // Our process still has the old in-memory state (IN_PROGRESS) and tries to transition
  427. // task.Status is still IN_PROGRESS in the snapshot
  428. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
  429. // CAS lost: user quota should NOT change (no double refund)
  430. assert.Equal(t, initQuota, getUserQuota(t, userID))
  431. assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
  432. // No billing log should be created
  433. assert.Equal(t, int64(0), countLogs(t))
  434. }
  435. func TestCASGuardedSettle_Win(t *testing.T) {
  436. truncate(t)
  437. ctx := context.Background()
  438. const userID, tokenID, channelID = 22, 22, 22
  439. const initQuota, preConsumed = 10000, 5000
  440. const actualQuota = 3000 // over-charged, should get partial refund
  441. const tokenRemain = 8000
  442. seedUser(t, userID, initQuota)
  443. seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain)
  444. seedChannel(t, channelID)
  445. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  446. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  447. require.NoError(t, model.DB.Create(task).Error)
  448. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota)
  449. // CAS wins: task should be SUCCESS
  450. var reloaded model.Task
  451. require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
  452. assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status)
  453. // Settlement should refund the over-charge (5000 - 3000 = 2000 back to user)
  454. assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
  455. assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
  456. // task.Quota should be updated to actualQuota
  457. assert.Equal(t, actualQuota, task.Quota)
  458. }
  459. func TestNonTerminalUpdate_NoBilling(t *testing.T) {
  460. truncate(t)
  461. ctx := context.Background()
  462. const userID, channelID = 23, 23
  463. const initQuota, preConsumed = 10000, 3000
  464. seedUser(t, userID, initQuota)
  465. seedChannel(t, channelID)
  466. task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0)
  467. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  468. task.Progress = "20%"
  469. require.NoError(t, model.DB.Create(task).Error)
  470. // Simulate a non-terminal poll update (still IN_PROGRESS, progress changed)
  471. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0)
  472. // User quota should NOT change
  473. assert.Equal(t, initQuota, getUserQuota(t, userID))
  474. // No billing log
  475. assert.Equal(t, int64(0), countLogs(t))
  476. // Task progress should be updated in DB
  477. var reloaded model.Task
  478. require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
  479. assert.Equal(t, "50%", reloaded.Progress)
  480. }
  481. // ===========================================================================
  482. // Mock adaptor for settleTaskBillingOnComplete tests
  483. // ===========================================================================
  484. type mockAdaptor struct {
  485. adjustReturn int
  486. }
  487. func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {}
  488. func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) {
  489. return nil, nil
  490. }
  491. func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil }
  492. func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
  493. return m.adjustReturn
  494. }
  495. // ===========================================================================
  496. // PerCallBilling tests — settleTaskBillingOnComplete
  497. // ===========================================================================
  498. func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
  499. truncate(t)
  500. ctx := context.Background()
  501. const userID, tokenID, channelID = 30, 30, 30
  502. const initQuota, preConsumed = 10000, 5000
  503. const tokenRemain = 8000
  504. seedUser(t, userID, initQuota)
  505. seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
  506. seedChannel(t, channelID)
  507. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  508. task.PrivateData.BillingContext.PerCallBilling = true
  509. adaptor := &mockAdaptor{adjustReturn: 2000}
  510. taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
  511. settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
  512. // Per-call: no adjustment despite adaptor returning 2000
  513. assert.Equal(t, initQuota, getUserQuota(t, userID))
  514. assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
  515. assert.Equal(t, preConsumed, task.Quota)
  516. assert.Equal(t, int64(0), countLogs(t))
  517. }
  518. func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
  519. truncate(t)
  520. ctx := context.Background()
  521. const userID, tokenID, channelID = 31, 31, 31
  522. const initQuota, preConsumed = 10000, 4000
  523. const tokenRemain = 7000
  524. seedUser(t, userID, initQuota)
  525. seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
  526. seedChannel(t, channelID)
  527. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  528. task.PrivateData.BillingContext.PerCallBilling = true
  529. adaptor := &mockAdaptor{adjustReturn: 0}
  530. taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
  531. settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
  532. // Per-call: no recalculation by tokens
  533. assert.Equal(t, initQuota, getUserQuota(t, userID))
  534. assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
  535. assert.Equal(t, preConsumed, task.Quota)
  536. assert.Equal(t, int64(0), countLogs(t))
  537. }
  538. func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
  539. truncate(t)
  540. ctx := context.Background()
  541. const userID, tokenID, channelID = 32, 32, 32
  542. const initQuota, preConsumed = 10000, 5000
  543. const adaptorQuota = 3000
  544. const tokenRemain = 8000
  545. seedUser(t, userID, initQuota)
  546. seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
  547. seedChannel(t, channelID)
  548. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  549. // PerCallBilling defaults to false
  550. adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
  551. taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
  552. settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
  553. // Non-per-call: adaptor adjustment applies (refund 2000)
  554. assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
  555. assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
  556. assert.Equal(t, adaptorQuota, task.Quota)
  557. log := getLastLog(t)
  558. require.NotNil(t, log)
  559. assert.Equal(t, model.LogTypeRefund, log.Type)
  560. }