| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606 |
- package service
import (
	"context"
	"encoding/json"
	"errors"
	"os"
	"testing"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/model"
	"github.com/glebarez/sqlite"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/gorm"
)
- func TestMain(m *testing.M) {
- db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
- if err != nil {
- panic("failed to open test db: " + err.Error())
- }
- sqlDB, err := db.DB()
- if err != nil {
- panic("failed to get sql.DB: " + err.Error())
- }
- sqlDB.SetMaxOpenConns(1)
- model.DB = db
- model.LOG_DB = db
- common.UsingSQLite = true
- common.RedisEnabled = false
- common.BatchUpdateEnabled = false
- common.LogConsumeEnabled = true
- if err := db.AutoMigrate(
- &model.Task{},
- &model.User{},
- &model.Token{},
- &model.Log{},
- &model.Channel{},
- &model.UserSubscription{},
- ); err != nil {
- panic("failed to migrate: " + err.Error())
- }
- os.Exit(m.Run())
- }
- // ---------------------------------------------------------------------------
- // Seed helpers
- // ---------------------------------------------------------------------------
- func truncate(t *testing.T) {
- t.Helper()
- t.Cleanup(func() {
- model.DB.Exec("DELETE FROM tasks")
- model.DB.Exec("DELETE FROM users")
- model.DB.Exec("DELETE FROM tokens")
- model.DB.Exec("DELETE FROM logs")
- model.DB.Exec("DELETE FROM channels")
- model.DB.Exec("DELETE FROM user_subscriptions")
- })
- }
- func seedUser(t *testing.T, id int, quota int) {
- t.Helper()
- user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled}
- require.NoError(t, model.DB.Create(user).Error)
- }
- func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) {
- t.Helper()
- token := &model.Token{
- Id: id,
- UserId: userId,
- Key: key,
- Name: "test_token",
- Status: common.TokenStatusEnabled,
- RemainQuota: remainQuota,
- UsedQuota: 0,
- }
- require.NoError(t, model.DB.Create(token).Error)
- }
- func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) {
- t.Helper()
- sub := &model.UserSubscription{
- Id: id,
- UserId: userId,
- AmountTotal: amountTotal,
- AmountUsed: amountUsed,
- Status: "active",
- StartTime: time.Now().Unix(),
- EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(),
- }
- require.NoError(t, model.DB.Create(sub).Error)
- }
- func seedChannel(t *testing.T, id int) {
- t.Helper()
- ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled}
- require.NoError(t, model.DB.Create(ch).Error)
- }
- func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task {
- return &model.Task{
- TaskID: "task_" + time.Now().Format("150405.000"),
- UserId: userId,
- ChannelId: channelId,
- Quota: quota,
- Status: model.TaskStatus(model.TaskStatusInProgress),
- Group: "default",
- Data: json.RawMessage(`{}`),
- CreatedAt: time.Now().Unix(),
- UpdatedAt: time.Now().Unix(),
- Properties: model.Properties{
- OriginModelName: "test-model",
- },
- PrivateData: model.TaskPrivateData{
- BillingSource: billingSource,
- SubscriptionId: subscriptionId,
- TokenId: tokenId,
- BillingContext: &model.TaskBillingContext{
- ModelPrice: 0.02,
- GroupRatio: 1.0,
- ModelName: "test-model",
- },
- },
- }
- }
- // ---------------------------------------------------------------------------
- // Read-back helpers
- // ---------------------------------------------------------------------------
- func getUserQuota(t *testing.T, id int) int {
- t.Helper()
- var user model.User
- require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error)
- return user.Quota
- }
- func getTokenRemainQuota(t *testing.T, id int) int {
- t.Helper()
- var token model.Token
- require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error)
- return token.RemainQuota
- }
- func getTokenUsedQuota(t *testing.T, id int) int {
- t.Helper()
- var token model.Token
- require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error)
- return token.UsedQuota
- }
- func getSubscriptionUsed(t *testing.T, id int) int64 {
- t.Helper()
- var sub model.UserSubscription
- require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error)
- return sub.AmountUsed
- }
- func getLastLog(t *testing.T) *model.Log {
- t.Helper()
- var log model.Log
- err := model.LOG_DB.Order("id desc").First(&log).Error
- if err != nil {
- return nil
- }
- return &log
- }
- func countLogs(t *testing.T) int64 {
- t.Helper()
- var count int64
- model.LOG_DB.Model(&model.Log{}).Count(&count)
- return count
- }
- // ===========================================================================
- // RefundTaskQuota tests
- // ===========================================================================
- func TestRefundTaskQuota_Wallet(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID, tokenID, channelID = 1, 1, 1
- const initQuota, preConsumed = 10000, 3000
- const tokenRemain = 5000
- seedUser(t, userID, initQuota)
- seedToken(t, tokenID, userID, "sk-test-key", tokenRemain)
- seedChannel(t, channelID)
- task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
- RefundTaskQuota(ctx, task, "task failed: upstream error")
- // User quota should increase by preConsumed
- assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
- // Token remain_quota should increase, used_quota should decrease
- assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
- assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID))
- // A refund log should be created
- log := getLastLog(t)
- require.NotNil(t, log)
- assert.Equal(t, model.LogTypeRefund, log.Type)
- assert.Equal(t, preConsumed, log.Quota)
- assert.Equal(t, "test-model", log.ModelName)
- }
- func TestRefundTaskQuota_Subscription(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID, tokenID, channelID, subID = 2, 2, 2, 1
- const preConsumed = 2000
- const subTotal, subUsed int64 = 100000, 50000
- const tokenRemain = 8000
- seedUser(t, userID, 0)
- seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain)
- seedChannel(t, channelID)
- seedSubscription(t, subID, userID, subTotal, subUsed)
- task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
- RefundTaskQuota(ctx, task, "subscription task failed")
- // Subscription used should decrease by preConsumed
- assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID))
- // Token should also be refunded
- assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
- log := getLastLog(t)
- require.NotNil(t, log)
- assert.Equal(t, model.LogTypeRefund, log.Type)
- }
- func TestRefundTaskQuota_ZeroQuota(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID = 3
- seedUser(t, userID, 5000)
- task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0)
- RefundTaskQuota(ctx, task, "zero quota task")
- // No change to user quota
- assert.Equal(t, 5000, getUserQuota(t, userID))
- // No log created
- assert.Equal(t, int64(0), countLogs(t))
- }
- func TestRefundTaskQuota_NoToken(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID, channelID = 4, 4
- const initQuota, preConsumed = 10000, 1500
- seedUser(t, userID, initQuota)
- seedChannel(t, channelID)
- task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0
- RefundTaskQuota(ctx, task, "no token task failed")
- // User quota refunded
- assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
- // Log created
- log := getLastLog(t)
- require.NotNil(t, log)
- assert.Equal(t, model.LogTypeRefund, log.Type)
- }
- // ===========================================================================
- // RecalculateTaskQuota tests
- // ===========================================================================
- func TestRecalculate_PositiveDelta(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID, tokenID, channelID = 10, 10, 10
- const initQuota, preConsumed = 10000, 2000
- const actualQuota = 3000 // under-charged by 1000
- const tokenRemain = 5000
- seedUser(t, userID, initQuota)
- seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain)
- seedChannel(t, channelID)
- task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
- RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
- // User quota should decrease by the delta (1000 additional charge)
- assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID))
- // Token should also be charged the delta
- assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID))
- // task.Quota should be updated to actualQuota
- assert.Equal(t, actualQuota, task.Quota)
- // Log type should be Consume (additional charge)
- log := getLastLog(t)
- require.NotNil(t, log)
- assert.Equal(t, model.LogTypeConsume, log.Type)
- assert.Equal(t, actualQuota-preConsumed, log.Quota)
- }
- func TestRecalculate_NegativeDelta(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID, tokenID, channelID = 11, 11, 11
- const initQuota, preConsumed = 10000, 5000
- const actualQuota = 3000 // over-charged by 2000
- const tokenRemain = 5000
- seedUser(t, userID, initQuota)
- seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain)
- seedChannel(t, channelID)
- task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
- RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
- // User quota should increase by abs(delta) = 2000 (refund overpayment)
- assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
- // Token should be refunded the difference
- assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
- // task.Quota updated
- assert.Equal(t, actualQuota, task.Quota)
- // Log type should be Refund
- log := getLastLog(t)
- require.NotNil(t, log)
- assert.Equal(t, model.LogTypeRefund, log.Type)
- assert.Equal(t, preConsumed-actualQuota, log.Quota)
- }
- func TestRecalculate_ZeroDelta(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID = 12
- const initQuota, preConsumed = 10000, 3000
- seedUser(t, userID, initQuota)
- task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0)
- RecalculateTaskQuota(ctx, task, preConsumed, "exact match")
- // No change to user quota
- assert.Equal(t, initQuota, getUserQuota(t, userID))
- // No log created (delta is zero)
- assert.Equal(t, int64(0), countLogs(t))
- }
- func TestRecalculate_ActualQuotaZero(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID = 13
- const initQuota = 10000
- seedUser(t, userID, initQuota)
- task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0)
- RecalculateTaskQuota(ctx, task, 0, "zero actual")
- // No change (early return)
- assert.Equal(t, initQuota, getUserQuota(t, userID))
- assert.Equal(t, int64(0), countLogs(t))
- }
- func TestRecalculate_Subscription_NegativeDelta(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID, tokenID, channelID, subID = 14, 14, 14, 2
- const preConsumed = 5000
- const actualQuota = 2000 // over-charged by 3000
- const subTotal, subUsed int64 = 100000, 50000
- const tokenRemain = 8000
- seedUser(t, userID, 0)
- seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain)
- seedChannel(t, channelID)
- seedSubscription(t, subID, userID, subTotal, subUsed)
- task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
- RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge")
- // Subscription used should decrease by delta (refund 3000)
- assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID))
- // Token refunded
- assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
- assert.Equal(t, actualQuota, task.Quota)
- log := getLastLog(t)
- require.NotNil(t, log)
- assert.Equal(t, model.LogTypeRefund, log.Type)
- }
- // ===========================================================================
- // CAS + Billing integration tests
- // Simulates the flow in updateVideoSingleTask (service/task_polling.go)
- // ===========================================================================
- // simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask.
- // It takes a persisted task (already in DB), applies the new status, and performs
- // the conditional update + billing exactly as the polling loop does.
- func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) {
- snap := task.Snapshot()
- shouldRefund := false
- shouldSettle := false
- quota := task.Quota
- task.Status = newStatus
- switch string(newStatus) {
- case model.TaskStatusSuccess:
- task.Progress = "100%"
- task.FinishTime = 9999
- shouldSettle = true
- case model.TaskStatusFailure:
- task.Progress = "100%"
- task.FinishTime = 9999
- task.FailReason = "upstream error"
- if quota != 0 {
- shouldRefund = true
- }
- default:
- task.Progress = "50%"
- }
- isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure)
- if isDone && snap.Status != task.Status {
- won, err := task.UpdateWithStatus(snap.Status)
- if err != nil {
- shouldRefund = false
- shouldSettle = false
- } else if !won {
- shouldRefund = false
- shouldSettle = false
- }
- } else if !snap.Equal(task.Snapshot()) {
- _, _ = task.UpdateWithStatus(snap.Status)
- }
- if shouldSettle && actualQuota > 0 {
- RecalculateTaskQuota(ctx, task, actualQuota, "test settle")
- }
- if shouldRefund {
- RefundTaskQuota(ctx, task, task.FailReason)
- }
- }
- func TestCASGuardedRefund_Win(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID, tokenID, channelID = 20, 20, 20
- const initQuota, preConsumed = 10000, 4000
- const tokenRemain = 6000
- seedUser(t, userID, initQuota)
- seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain)
- seedChannel(t, channelID)
- task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
- task.Status = model.TaskStatus(model.TaskStatusInProgress)
- require.NoError(t, model.DB.Create(task).Error)
- simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
- // CAS wins: task in DB should now be FAILURE
- var reloaded model.Task
- require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
- assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status)
- // Refund should have happened
- assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
- assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
- log := getLastLog(t)
- require.NotNil(t, log)
- assert.Equal(t, model.LogTypeRefund, log.Type)
- }
- func TestCASGuardedRefund_Lose(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID, tokenID, channelID = 21, 21, 21
- const initQuota, preConsumed = 10000, 4000
- const tokenRemain = 6000
- seedUser(t, userID, initQuota)
- seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain)
- seedChannel(t, channelID)
- // Create task with IN_PROGRESS in DB
- task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
- task.Status = model.TaskStatus(model.TaskStatusInProgress)
- require.NoError(t, model.DB.Create(task).Error)
- // Simulate another process already transitioning to FAILURE
- model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure)
- // Our process still has the old in-memory state (IN_PROGRESS) and tries to transition
- // task.Status is still IN_PROGRESS in the snapshot
- simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
- // CAS lost: user quota should NOT change (no double refund)
- assert.Equal(t, initQuota, getUserQuota(t, userID))
- assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
- // No billing log should be created
- assert.Equal(t, int64(0), countLogs(t))
- }
- func TestCASGuardedSettle_Win(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID, tokenID, channelID = 22, 22, 22
- const initQuota, preConsumed = 10000, 5000
- const actualQuota = 3000 // over-charged, should get partial refund
- const tokenRemain = 8000
- seedUser(t, userID, initQuota)
- seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain)
- seedChannel(t, channelID)
- task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
- task.Status = model.TaskStatus(model.TaskStatusInProgress)
- require.NoError(t, model.DB.Create(task).Error)
- simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota)
- // CAS wins: task should be SUCCESS
- var reloaded model.Task
- require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
- assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status)
- // Settlement should refund the over-charge (5000 - 3000 = 2000 back to user)
- assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
- assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
- // task.Quota should be updated to actualQuota
- assert.Equal(t, actualQuota, task.Quota)
- }
- func TestNonTerminalUpdate_NoBilling(t *testing.T) {
- truncate(t)
- ctx := context.Background()
- const userID, channelID = 23, 23
- const initQuota, preConsumed = 10000, 3000
- seedUser(t, userID, initQuota)
- seedChannel(t, channelID)
- task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0)
- task.Status = model.TaskStatus(model.TaskStatusInProgress)
- task.Progress = "20%"
- require.NoError(t, model.DB.Create(task).Error)
- // Simulate a non-terminal poll update (still IN_PROGRESS, progress changed)
- simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0)
- // User quota should NOT change
- assert.Equal(t, initQuota, getUserQuota(t, userID))
- // No billing log
- assert.Equal(t, int64(0), countLogs(t))
- // Task progress should be updated in DB
- var reloaded model.Task
- require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
- assert.Equal(t, "50%", reloaded.Progress)
- }
|