task_billing_test.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714
  1. package service
  2. import (
  3. "context"
  4. "encoding/json"
  5. "net/http"
  6. "os"
  7. "testing"
  8. "time"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/model"
  11. relaycommon "github.com/QuantumNous/new-api/relay/common"
  12. "github.com/glebarez/sqlite"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. "gorm.io/gorm"
  16. )
  17. func TestMain(m *testing.M) {
  18. db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
  19. if err != nil {
  20. panic("failed to open test db: " + err.Error())
  21. }
  22. sqlDB, err := db.DB()
  23. if err != nil {
  24. panic("failed to get sql.DB: " + err.Error())
  25. }
  26. sqlDB.SetMaxOpenConns(1)
  27. model.DB = db
  28. model.LOG_DB = db
  29. common.UsingSQLite = true
  30. common.RedisEnabled = false
  31. common.BatchUpdateEnabled = false
  32. common.LogConsumeEnabled = true
  33. if err := db.AutoMigrate(
  34. &model.Task{},
  35. &model.User{},
  36. &model.Token{},
  37. &model.Log{},
  38. &model.Channel{},
  39. &model.UserSubscription{},
  40. ); err != nil {
  41. panic("failed to migrate: " + err.Error())
  42. }
  43. os.Exit(m.Run())
  44. }
  45. // ---------------------------------------------------------------------------
  46. // Seed helpers
  47. // ---------------------------------------------------------------------------
  48. func truncate(t *testing.T) {
  49. t.Helper()
  50. t.Cleanup(func() {
  51. model.DB.Exec("DELETE FROM tasks")
  52. model.DB.Exec("DELETE FROM users")
  53. model.DB.Exec("DELETE FROM tokens")
  54. model.DB.Exec("DELETE FROM logs")
  55. model.DB.Exec("DELETE FROM channels")
  56. model.DB.Exec("DELETE FROM user_subscriptions")
  57. })
  58. }
  59. func seedUser(t *testing.T, id int, quota int) {
  60. t.Helper()
  61. user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled}
  62. require.NoError(t, model.DB.Create(user).Error)
  63. }
  64. func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) {
  65. t.Helper()
  66. token := &model.Token{
  67. Id: id,
  68. UserId: userId,
  69. Key: key,
  70. Name: "test_token",
  71. Status: common.TokenStatusEnabled,
  72. RemainQuota: remainQuota,
  73. UsedQuota: 0,
  74. }
  75. require.NoError(t, model.DB.Create(token).Error)
  76. }
  77. func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) {
  78. t.Helper()
  79. sub := &model.UserSubscription{
  80. Id: id,
  81. UserId: userId,
  82. AmountTotal: amountTotal,
  83. AmountUsed: amountUsed,
  84. Status: "active",
  85. StartTime: time.Now().Unix(),
  86. EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(),
  87. }
  88. require.NoError(t, model.DB.Create(sub).Error)
  89. }
  90. func seedChannel(t *testing.T, id int) {
  91. t.Helper()
  92. ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled}
  93. require.NoError(t, model.DB.Create(ch).Error)
  94. }
  95. func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task {
  96. return &model.Task{
  97. TaskID: "task_" + time.Now().Format("150405.000"),
  98. UserId: userId,
  99. ChannelId: channelId,
  100. Quota: quota,
  101. Status: model.TaskStatus(model.TaskStatusInProgress),
  102. Group: "default",
  103. Data: json.RawMessage(`{}`),
  104. CreatedAt: time.Now().Unix(),
  105. UpdatedAt: time.Now().Unix(),
  106. Properties: model.Properties{
  107. OriginModelName: "test-model",
  108. },
  109. PrivateData: model.TaskPrivateData{
  110. BillingSource: billingSource,
  111. SubscriptionId: subscriptionId,
  112. TokenId: tokenId,
  113. BillingContext: &model.TaskBillingContext{
  114. ModelPrice: 0.02,
  115. GroupRatio: 1.0,
  116. OriginModelName: "test-model",
  117. },
  118. },
  119. }
  120. }
  121. // ---------------------------------------------------------------------------
  122. // Read-back helpers
  123. // ---------------------------------------------------------------------------
  124. func getUserQuota(t *testing.T, id int) int {
  125. t.Helper()
  126. var user model.User
  127. require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error)
  128. return user.Quota
  129. }
  130. func getTokenRemainQuota(t *testing.T, id int) int {
  131. t.Helper()
  132. var token model.Token
  133. require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error)
  134. return token.RemainQuota
  135. }
  136. func getTokenUsedQuota(t *testing.T, id int) int {
  137. t.Helper()
  138. var token model.Token
  139. require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error)
  140. return token.UsedQuota
  141. }
  142. func getSubscriptionUsed(t *testing.T, id int) int64 {
  143. t.Helper()
  144. var sub model.UserSubscription
  145. require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error)
  146. return sub.AmountUsed
  147. }
  148. func getLastLog(t *testing.T) *model.Log {
  149. t.Helper()
  150. var log model.Log
  151. err := model.LOG_DB.Order("id desc").First(&log).Error
  152. if err != nil {
  153. return nil
  154. }
  155. return &log
  156. }
  157. func countLogs(t *testing.T) int64 {
  158. t.Helper()
  159. var count int64
  160. model.LOG_DB.Model(&model.Log{}).Count(&count)
  161. return count
  162. }
  163. // ===========================================================================
  164. // RefundTaskQuota tests
  165. // ===========================================================================
  166. func TestRefundTaskQuota_Wallet(t *testing.T) {
  167. truncate(t)
  168. ctx := context.Background()
  169. const userID, tokenID, channelID = 1, 1, 1
  170. const initQuota, preConsumed = 10000, 3000
  171. const tokenRemain = 5000
  172. seedUser(t, userID, initQuota)
  173. seedToken(t, tokenID, userID, "sk-test-key", tokenRemain)
  174. seedChannel(t, channelID)
  175. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  176. RefundTaskQuota(ctx, task, "task failed: upstream error")
  177. // User quota should increase by preConsumed
  178. assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
  179. // Token remain_quota should increase, used_quota should decrease
  180. assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
  181. assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID))
  182. // A refund log should be created
  183. log := getLastLog(t)
  184. require.NotNil(t, log)
  185. assert.Equal(t, model.LogTypeRefund, log.Type)
  186. assert.Equal(t, preConsumed, log.Quota)
  187. assert.Equal(t, "test-model", log.ModelName)
  188. }
  189. func TestRefundTaskQuota_Subscription(t *testing.T) {
  190. truncate(t)
  191. ctx := context.Background()
  192. const userID, tokenID, channelID, subID = 2, 2, 2, 1
  193. const preConsumed = 2000
  194. const subTotal, subUsed int64 = 100000, 50000
  195. const tokenRemain = 8000
  196. seedUser(t, userID, 0)
  197. seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain)
  198. seedChannel(t, channelID)
  199. seedSubscription(t, subID, userID, subTotal, subUsed)
  200. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
  201. RefundTaskQuota(ctx, task, "subscription task failed")
  202. // Subscription used should decrease by preConsumed
  203. assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID))
  204. // Token should also be refunded
  205. assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
  206. log := getLastLog(t)
  207. require.NotNil(t, log)
  208. assert.Equal(t, model.LogTypeRefund, log.Type)
  209. }
  210. func TestRefundTaskQuota_ZeroQuota(t *testing.T) {
  211. truncate(t)
  212. ctx := context.Background()
  213. const userID = 3
  214. seedUser(t, userID, 5000)
  215. task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0)
  216. RefundTaskQuota(ctx, task, "zero quota task")
  217. // No change to user quota
  218. assert.Equal(t, 5000, getUserQuota(t, userID))
  219. // No log created
  220. assert.Equal(t, int64(0), countLogs(t))
  221. }
  222. func TestRefundTaskQuota_NoToken(t *testing.T) {
  223. truncate(t)
  224. ctx := context.Background()
  225. const userID, channelID = 4, 4
  226. const initQuota, preConsumed = 10000, 1500
  227. seedUser(t, userID, initQuota)
  228. seedChannel(t, channelID)
  229. task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0
  230. RefundTaskQuota(ctx, task, "no token task failed")
  231. // User quota refunded
  232. assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
  233. // Log created
  234. log := getLastLog(t)
  235. require.NotNil(t, log)
  236. assert.Equal(t, model.LogTypeRefund, log.Type)
  237. }
  238. // ===========================================================================
  239. // RecalculateTaskQuota tests
  240. // ===========================================================================
  241. func TestRecalculate_PositiveDelta(t *testing.T) {
  242. truncate(t)
  243. ctx := context.Background()
  244. const userID, tokenID, channelID = 10, 10, 10
  245. const initQuota, preConsumed = 10000, 2000
  246. const actualQuota = 3000 // under-charged by 1000
  247. const tokenRemain = 5000
  248. seedUser(t, userID, initQuota)
  249. seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain)
  250. seedChannel(t, channelID)
  251. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  252. RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
  253. // User quota should decrease by the delta (1000 additional charge)
  254. assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID))
  255. // Token should also be charged the delta
  256. assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID))
  257. // task.Quota should be updated to actualQuota
  258. assert.Equal(t, actualQuota, task.Quota)
  259. // Log type should be Consume (additional charge)
  260. log := getLastLog(t)
  261. require.NotNil(t, log)
  262. assert.Equal(t, model.LogTypeConsume, log.Type)
  263. assert.Equal(t, actualQuota-preConsumed, log.Quota)
  264. }
  265. func TestRecalculate_NegativeDelta(t *testing.T) {
  266. truncate(t)
  267. ctx := context.Background()
  268. const userID, tokenID, channelID = 11, 11, 11
  269. const initQuota, preConsumed = 10000, 5000
  270. const actualQuota = 3000 // over-charged by 2000
  271. const tokenRemain = 5000
  272. seedUser(t, userID, initQuota)
  273. seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain)
  274. seedChannel(t, channelID)
  275. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  276. RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
  277. // User quota should increase by abs(delta) = 2000 (refund overpayment)
  278. assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
  279. // Token should be refunded the difference
  280. assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
  281. // task.Quota updated
  282. assert.Equal(t, actualQuota, task.Quota)
  283. // Log type should be Refund
  284. log := getLastLog(t)
  285. require.NotNil(t, log)
  286. assert.Equal(t, model.LogTypeRefund, log.Type)
  287. assert.Equal(t, preConsumed-actualQuota, log.Quota)
  288. }
  289. func TestRecalculate_ZeroDelta(t *testing.T) {
  290. truncate(t)
  291. ctx := context.Background()
  292. const userID = 12
  293. const initQuota, preConsumed = 10000, 3000
  294. seedUser(t, userID, initQuota)
  295. task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0)
  296. RecalculateTaskQuota(ctx, task, preConsumed, "exact match")
  297. // No change to user quota
  298. assert.Equal(t, initQuota, getUserQuota(t, userID))
  299. // No log created (delta is zero)
  300. assert.Equal(t, int64(0), countLogs(t))
  301. }
  302. func TestRecalculate_ActualQuotaZero(t *testing.T) {
  303. truncate(t)
  304. ctx := context.Background()
  305. const userID = 13
  306. const initQuota = 10000
  307. seedUser(t, userID, initQuota)
  308. task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0)
  309. RecalculateTaskQuota(ctx, task, 0, "zero actual")
  310. // No change (early return)
  311. assert.Equal(t, initQuota, getUserQuota(t, userID))
  312. assert.Equal(t, int64(0), countLogs(t))
  313. }
  314. func TestRecalculate_Subscription_NegativeDelta(t *testing.T) {
  315. truncate(t)
  316. ctx := context.Background()
  317. const userID, tokenID, channelID, subID = 14, 14, 14, 2
  318. const preConsumed = 5000
  319. const actualQuota = 2000 // over-charged by 3000
  320. const subTotal, subUsed int64 = 100000, 50000
  321. const tokenRemain = 8000
  322. seedUser(t, userID, 0)
  323. seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain)
  324. seedChannel(t, channelID)
  325. seedSubscription(t, subID, userID, subTotal, subUsed)
  326. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
  327. RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge")
  328. // Subscription used should decrease by delta (refund 3000)
  329. assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID))
  330. // Token refunded
  331. assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
  332. assert.Equal(t, actualQuota, task.Quota)
  333. log := getLastLog(t)
  334. require.NotNil(t, log)
  335. assert.Equal(t, model.LogTypeRefund, log.Type)
  336. }
  337. // ===========================================================================
  338. // CAS + Billing integration tests
  339. // Simulates the flow in updateVideoSingleTask (service/task_polling.go)
  340. // ===========================================================================
  341. // simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask.
  342. // It takes a persisted task (already in DB), applies the new status, and performs
  343. // the conditional update + billing exactly as the polling loop does.
  344. func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) {
  345. snap := task.Snapshot()
  346. shouldRefund := false
  347. shouldSettle := false
  348. quota := task.Quota
  349. task.Status = newStatus
  350. switch string(newStatus) {
  351. case model.TaskStatusSuccess:
  352. task.Progress = "100%"
  353. task.FinishTime = 9999
  354. shouldSettle = true
  355. case model.TaskStatusFailure:
  356. task.Progress = "100%"
  357. task.FinishTime = 9999
  358. task.FailReason = "upstream error"
  359. if quota != 0 {
  360. shouldRefund = true
  361. }
  362. default:
  363. task.Progress = "50%"
  364. }
  365. isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure)
  366. if isDone && snap.Status != task.Status {
  367. won, err := task.UpdateWithStatus(snap.Status)
  368. if err != nil {
  369. shouldRefund = false
  370. shouldSettle = false
  371. } else if !won {
  372. shouldRefund = false
  373. shouldSettle = false
  374. }
  375. } else if !snap.Equal(task.Snapshot()) {
  376. _, _ = task.UpdateWithStatus(snap.Status)
  377. }
  378. if shouldSettle && actualQuota > 0 {
  379. RecalculateTaskQuota(ctx, task, actualQuota, "test settle")
  380. }
  381. if shouldRefund {
  382. RefundTaskQuota(ctx, task, task.FailReason)
  383. }
  384. }
  385. func TestCASGuardedRefund_Win(t *testing.T) {
  386. truncate(t)
  387. ctx := context.Background()
  388. const userID, tokenID, channelID = 20, 20, 20
  389. const initQuota, preConsumed = 10000, 4000
  390. const tokenRemain = 6000
  391. seedUser(t, userID, initQuota)
  392. seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain)
  393. seedChannel(t, channelID)
  394. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  395. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  396. require.NoError(t, model.DB.Create(task).Error)
  397. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
  398. // CAS wins: task in DB should now be FAILURE
  399. var reloaded model.Task
  400. require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
  401. assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status)
  402. // Refund should have happened
  403. assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
  404. assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
  405. log := getLastLog(t)
  406. require.NotNil(t, log)
  407. assert.Equal(t, model.LogTypeRefund, log.Type)
  408. }
  409. func TestCASGuardedRefund_Lose(t *testing.T) {
  410. truncate(t)
  411. ctx := context.Background()
  412. const userID, tokenID, channelID = 21, 21, 21
  413. const initQuota, preConsumed = 10000, 4000
  414. const tokenRemain = 6000
  415. seedUser(t, userID, initQuota)
  416. seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain)
  417. seedChannel(t, channelID)
  418. // Create task with IN_PROGRESS in DB
  419. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  420. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  421. require.NoError(t, model.DB.Create(task).Error)
  422. // Simulate another process already transitioning to FAILURE
  423. model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure)
  424. // Our process still has the old in-memory state (IN_PROGRESS) and tries to transition
  425. // task.Status is still IN_PROGRESS in the snapshot
  426. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
  427. // CAS lost: user quota should NOT change (no double refund)
  428. assert.Equal(t, initQuota, getUserQuota(t, userID))
  429. assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
  430. // No billing log should be created
  431. assert.Equal(t, int64(0), countLogs(t))
  432. }
  433. func TestCASGuardedSettle_Win(t *testing.T) {
  434. truncate(t)
  435. ctx := context.Background()
  436. const userID, tokenID, channelID = 22, 22, 22
  437. const initQuota, preConsumed = 10000, 5000
  438. const actualQuota = 3000 // over-charged, should get partial refund
  439. const tokenRemain = 8000
  440. seedUser(t, userID, initQuota)
  441. seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain)
  442. seedChannel(t, channelID)
  443. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  444. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  445. require.NoError(t, model.DB.Create(task).Error)
  446. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota)
  447. // CAS wins: task should be SUCCESS
  448. var reloaded model.Task
  449. require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
  450. assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status)
  451. // Settlement should refund the over-charge (5000 - 3000 = 2000 back to user)
  452. assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
  453. assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
  454. // task.Quota should be updated to actualQuota
  455. assert.Equal(t, actualQuota, task.Quota)
  456. }
  457. func TestNonTerminalUpdate_NoBilling(t *testing.T) {
  458. truncate(t)
  459. ctx := context.Background()
  460. const userID, channelID = 23, 23
  461. const initQuota, preConsumed = 10000, 3000
  462. seedUser(t, userID, initQuota)
  463. seedChannel(t, channelID)
  464. task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0)
  465. task.Status = model.TaskStatus(model.TaskStatusInProgress)
  466. task.Progress = "20%"
  467. require.NoError(t, model.DB.Create(task).Error)
  468. // Simulate a non-terminal poll update (still IN_PROGRESS, progress changed)
  469. simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0)
  470. // User quota should NOT change
  471. assert.Equal(t, initQuota, getUserQuota(t, userID))
  472. // No billing log
  473. assert.Equal(t, int64(0), countLogs(t))
  474. // Task progress should be updated in DB
  475. var reloaded model.Task
  476. require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
  477. assert.Equal(t, "50%", reloaded.Progress)
  478. }
  479. // ===========================================================================
  480. // Mock adaptor for settleTaskBillingOnComplete tests
  481. // ===========================================================================
  482. type mockAdaptor struct {
  483. adjustReturn int
  484. }
  485. func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {}
  486. func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) {
  487. return nil, nil
  488. }
  489. func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil }
  490. func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
  491. return m.adjustReturn
  492. }
  493. // ===========================================================================
  494. // PerCallBilling tests — settleTaskBillingOnComplete
  495. // ===========================================================================
  496. func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
  497. truncate(t)
  498. ctx := context.Background()
  499. const userID, tokenID, channelID = 30, 30, 30
  500. const initQuota, preConsumed = 10000, 5000
  501. const tokenRemain = 8000
  502. seedUser(t, userID, initQuota)
  503. seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
  504. seedChannel(t, channelID)
  505. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  506. task.PrivateData.BillingContext.PerCallBilling = true
  507. adaptor := &mockAdaptor{adjustReturn: 2000}
  508. taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
  509. settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
  510. // Per-call: no adjustment despite adaptor returning 2000
  511. assert.Equal(t, initQuota, getUserQuota(t, userID))
  512. assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
  513. assert.Equal(t, preConsumed, task.Quota)
  514. assert.Equal(t, int64(0), countLogs(t))
  515. }
  516. func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
  517. truncate(t)
  518. ctx := context.Background()
  519. const userID, tokenID, channelID = 31, 31, 31
  520. const initQuota, preConsumed = 10000, 4000
  521. const tokenRemain = 7000
  522. seedUser(t, userID, initQuota)
  523. seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
  524. seedChannel(t, channelID)
  525. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  526. task.PrivateData.BillingContext.PerCallBilling = true
  527. adaptor := &mockAdaptor{adjustReturn: 0}
  528. taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
  529. settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
  530. // Per-call: no recalculation by tokens
  531. assert.Equal(t, initQuota, getUserQuota(t, userID))
  532. assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
  533. assert.Equal(t, preConsumed, task.Quota)
  534. assert.Equal(t, int64(0), countLogs(t))
  535. }
  536. func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
  537. truncate(t)
  538. ctx := context.Background()
  539. const userID, tokenID, channelID = 32, 32, 32
  540. const initQuota, preConsumed = 10000, 5000
  541. const adaptorQuota = 3000
  542. const tokenRemain = 8000
  543. seedUser(t, userID, initQuota)
  544. seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
  545. seedChannel(t, channelID)
  546. task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
  547. // PerCallBilling defaults to false
  548. adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
  549. taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
  550. settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
  551. // Non-per-call: adaptor adjustment applies (refund 2000)
  552. assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
  553. assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
  554. assert.Equal(t, adaptorQuota, task.Quota)
  555. log := getLastLog(t)
  556. require.NotNil(t, log)
  557. assert.Equal(t, model.LogTypeRefund, log.Type)
  558. }