package service import ( "context" "encoding/json" "net/http" "os" "testing" "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/glebarez/sqlite" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" ) func TestMain(m *testing.M) { db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) if err != nil { panic("failed to open test db: " + err.Error()) } sqlDB, err := db.DB() if err != nil { panic("failed to get sql.DB: " + err.Error()) } sqlDB.SetMaxOpenConns(1) model.DB = db model.LOG_DB = db common.UsingSQLite = true common.RedisEnabled = false common.BatchUpdateEnabled = false common.LogConsumeEnabled = true if err := db.AutoMigrate( &model.Task{}, &model.User{}, &model.Token{}, &model.Log{}, &model.Channel{}, &model.UserSubscription{}, ); err != nil { panic("failed to migrate: " + err.Error()) } os.Exit(m.Run()) } // --------------------------------------------------------------------------- // Seed helpers // --------------------------------------------------------------------------- func truncate(t *testing.T) { t.Helper() t.Cleanup(func() { model.DB.Exec("DELETE FROM tasks") model.DB.Exec("DELETE FROM users") model.DB.Exec("DELETE FROM tokens") model.DB.Exec("DELETE FROM logs") model.DB.Exec("DELETE FROM channels") model.DB.Exec("DELETE FROM user_subscriptions") }) } func seedUser(t *testing.T, id int, quota int) { t.Helper() user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled} require.NoError(t, model.DB.Create(user).Error) } func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) { t.Helper() token := &model.Token{ Id: id, UserId: userId, Key: key, Name: "test_token", Status: common.TokenStatusEnabled, RemainQuota: remainQuota, UsedQuota: 0, } require.NoError(t, model.DB.Create(token).Error) } func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) { t.Helper() sub := &model.UserSubscription{ Id: id, UserId: userId, AmountTotal: amountTotal, AmountUsed: amountUsed, Status: "active", StartTime: time.Now().Unix(), EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(), } require.NoError(t, model.DB.Create(sub).Error) } func seedChannel(t *testing.T, id int) { t.Helper() ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled} require.NoError(t, model.DB.Create(ch).Error) } func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task { return &model.Task{ TaskID: "task_" + time.Now().Format("150405.000"), UserId: userId, ChannelId: channelId, Quota: quota, Status: model.TaskStatus(model.TaskStatusInProgress), Group: "default", Data: json.RawMessage(`{}`), CreatedAt: time.Now().Unix(), UpdatedAt: time.Now().Unix(), Properties: model.Properties{ OriginModelName: "test-model", }, PrivateData: model.TaskPrivateData{ BillingSource: billingSource, SubscriptionId: subscriptionId, TokenId: tokenId, BillingContext: &model.TaskBillingContext{ ModelPrice: 0.02, GroupRatio: 1.0, OriginModelName: "test-model", }, }, } } // --------------------------------------------------------------------------- // Read-back helpers // --------------------------------------------------------------------------- func getUserQuota(t *testing.T, id int) int { t.Helper() var user model.User require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error) return user.Quota } func getTokenRemainQuota(t *testing.T, id int) int { t.Helper() var token model.Token require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error) return token.RemainQuota } func getTokenUsedQuota(t *testing.T, id int) int { t.Helper() var token model.Token require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error) return token.UsedQuota } func getSubscriptionUsed(t *testing.T, id int) int64 { t.Helper() var sub model.UserSubscription require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error) return sub.AmountUsed } func getLastLog(t *testing.T) *model.Log { t.Helper() var log model.Log err := model.LOG_DB.Order("id desc").First(&log).Error if err != nil { return nil } return &log } func countLogs(t *testing.T) int64 { t.Helper() var count int64 model.LOG_DB.Model(&model.Log{}).Count(&count) return count } // =========================================================================== // RefundTaskQuota tests // =========================================================================== func TestRefundTaskQuota_Wallet(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID = 1, 1, 1 const initQuota, preConsumed = 10000, 3000 const tokenRemain = 5000 seedUser(t, userID, initQuota) seedToken(t, tokenID, userID, "sk-test-key", tokenRemain) seedChannel(t, channelID) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) RefundTaskQuota(ctx, task, "task failed: upstream error") // User quota should increase by preConsumed assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) // Token remain_quota should increase, used_quota should decrease assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID)) // A refund log should be created log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeRefund, log.Type) assert.Equal(t, preConsumed, log.Quota) assert.Equal(t, "test-model", log.ModelName) } func TestRefundTaskQuota_Subscription(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID, subID = 2, 2, 2, 1 const preConsumed = 2000 const subTotal, subUsed int64 = 100000, 50000 const tokenRemain = 8000 seedUser(t, userID, 0) seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain) seedChannel(t, channelID) seedSubscription(t, subID, userID, subTotal, subUsed) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) RefundTaskQuota(ctx, task, "subscription task failed") // Subscription used should decrease by preConsumed assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID)) // Token should also be refunded assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeRefund, log.Type) } func TestRefundTaskQuota_ZeroQuota(t *testing.T) { truncate(t) ctx := context.Background() const userID = 3 seedUser(t, userID, 5000) task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0) RefundTaskQuota(ctx, task, "zero quota task") // No change to user quota assert.Equal(t, 5000, getUserQuota(t, userID)) // No log created assert.Equal(t, int64(0), countLogs(t)) } func TestRefundTaskQuota_NoToken(t *testing.T) { truncate(t) ctx := context.Background() const userID, channelID = 4, 4 const initQuota, preConsumed = 10000, 1500 seedUser(t, userID, initQuota) seedChannel(t, channelID) task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0 RefundTaskQuota(ctx, task, "no token task failed") // User quota refunded assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) // Log created log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeRefund, log.Type) } // =========================================================================== // RecalculateTaskQuota tests // =========================================================================== func TestRecalculate_PositiveDelta(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID = 10, 10, 10 const initQuota, preConsumed = 10000, 2000 const actualQuota = 3000 // under-charged by 1000 const tokenRemain = 5000 seedUser(t, userID, initQuota) seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain) seedChannel(t, channelID) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") // User quota should decrease by the delta (1000 additional charge) assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID)) // Token should also be charged the delta assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID)) // task.Quota should be updated to actualQuota assert.Equal(t, actualQuota, task.Quota) // Log type should be Consume (additional charge) log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeConsume, log.Type) assert.Equal(t, actualQuota-preConsumed, log.Quota) } func TestRecalculate_NegativeDelta(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID = 11, 11, 11 const initQuota, preConsumed = 10000, 5000 const actualQuota = 3000 // over-charged by 2000 const tokenRemain = 5000 seedUser(t, userID, initQuota) seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain) seedChannel(t, channelID) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") // User quota should increase by abs(delta) = 2000 (refund overpayment) assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) // Token should be refunded the difference assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) // task.Quota updated assert.Equal(t, actualQuota, task.Quota) // Log type should be Refund log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeRefund, log.Type) assert.Equal(t, preConsumed-actualQuota, log.Quota) } func TestRecalculate_ZeroDelta(t *testing.T) { truncate(t) ctx := context.Background() const userID = 12 const initQuota, preConsumed = 10000, 3000 seedUser(t, userID, initQuota) task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0) RecalculateTaskQuota(ctx, task, preConsumed, "exact match") // No change to user quota assert.Equal(t, initQuota, getUserQuota(t, userID)) // No log created (delta is zero) assert.Equal(t, int64(0), countLogs(t)) } func TestRecalculate_ActualQuotaZero(t *testing.T) { truncate(t) ctx := context.Background() const userID = 13 const initQuota = 10000 seedUser(t, userID, initQuota) task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0) RecalculateTaskQuota(ctx, task, 0, "zero actual") // No change (early return) assert.Equal(t, initQuota, getUserQuota(t, userID)) assert.Equal(t, int64(0), countLogs(t)) } func TestRecalculate_Subscription_NegativeDelta(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID, subID = 14, 14, 14, 2 const preConsumed = 5000 const actualQuota = 2000 // over-charged by 3000 const subTotal, subUsed int64 = 100000, 50000 const tokenRemain = 8000 seedUser(t, userID, 0) seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain) seedChannel(t, channelID) seedSubscription(t, subID, userID, subTotal, subUsed) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge") // Subscription used should decrease by delta (refund 3000) assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID)) // Token refunded assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) assert.Equal(t, actualQuota, task.Quota) log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeRefund, log.Type) } // =========================================================================== // CAS + Billing integration tests // Simulates the flow in updateVideoSingleTask (service/task_polling.go) // =========================================================================== // simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask. // It takes a persisted task (already in DB), applies the new status, and performs // the conditional update + billing exactly as the polling loop does. func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) { snap := task.Snapshot() shouldRefund := false shouldSettle := false quota := task.Quota task.Status = newStatus switch string(newStatus) { case model.TaskStatusSuccess: task.Progress = "100%" task.FinishTime = 9999 shouldSettle = true case model.TaskStatusFailure: task.Progress = "100%" task.FinishTime = 9999 task.FailReason = "upstream error" if quota != 0 { shouldRefund = true } default: task.Progress = "50%" } isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure) if isDone && snap.Status != task.Status { won, err := task.UpdateWithStatus(snap.Status) if err != nil { shouldRefund = false shouldSettle = false } else if !won { shouldRefund = false shouldSettle = false } } else if !snap.Equal(task.Snapshot()) { _, _ = task.UpdateWithStatus(snap.Status) } if shouldSettle && actualQuota > 0 { RecalculateTaskQuota(ctx, task, actualQuota, "test settle") } if shouldRefund { RefundTaskQuota(ctx, task, task.FailReason) } } func TestCASGuardedRefund_Win(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID = 20, 20, 20 const initQuota, preConsumed = 10000, 4000 const tokenRemain = 6000 seedUser(t, userID, initQuota) seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain) seedChannel(t, channelID) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) task.Status = model.TaskStatus(model.TaskStatusInProgress) require.NoError(t, model.DB.Create(task).Error) simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) // CAS wins: task in DB should now be FAILURE var reloaded model.Task require.NoError(t, model.DB.First(&reloaded, task.ID).Error) assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status) // Refund should have happened assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeRefund, log.Type) } func TestCASGuardedRefund_Lose(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID = 21, 21, 21 const initQuota, preConsumed = 10000, 4000 const tokenRemain = 6000 seedUser(t, userID, initQuota) seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain) seedChannel(t, channelID) // Create task with IN_PROGRESS in DB task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) task.Status = model.TaskStatus(model.TaskStatusInProgress) require.NoError(t, model.DB.Create(task).Error) // Simulate another process already transitioning to FAILURE model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure) // Our process still has the old in-memory state (IN_PROGRESS) and tries to transition // task.Status is still IN_PROGRESS in the snapshot simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) // CAS lost: user quota should NOT change (no double refund) assert.Equal(t, initQuota, getUserQuota(t, userID)) assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) // No billing log should be created assert.Equal(t, int64(0), countLogs(t)) } func TestCASGuardedSettle_Win(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID = 22, 22, 22 const initQuota, preConsumed = 10000, 5000 const actualQuota = 3000 // over-charged, should get partial refund const tokenRemain = 8000 seedUser(t, userID, initQuota) seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain) seedChannel(t, channelID) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) task.Status = model.TaskStatus(model.TaskStatusInProgress) require.NoError(t, model.DB.Create(task).Error) simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota) // CAS wins: task should be SUCCESS var reloaded model.Task require.NoError(t, model.DB.First(&reloaded, task.ID).Error) assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status) // Settlement should refund the over-charge (5000 - 3000 = 2000 back to user) assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) // task.Quota should be updated to actualQuota assert.Equal(t, actualQuota, task.Quota) } func TestNonTerminalUpdate_NoBilling(t *testing.T) { truncate(t) ctx := context.Background() const userID, channelID = 23, 23 const initQuota, preConsumed = 10000, 3000 seedUser(t, userID, initQuota) seedChannel(t, channelID) task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) task.Status = model.TaskStatus(model.TaskStatusInProgress) task.Progress = "20%" require.NoError(t, model.DB.Create(task).Error) // Simulate a non-terminal poll update (still IN_PROGRESS, progress changed) simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0) // User quota should NOT change assert.Equal(t, initQuota, getUserQuota(t, userID)) // No billing log assert.Equal(t, int64(0), countLogs(t)) // Task progress should be updated in DB var reloaded model.Task require.NoError(t, model.DB.First(&reloaded, task.ID).Error) assert.Equal(t, "50%", reloaded.Progress) } // =========================================================================== // Mock adaptor for settleTaskBillingOnComplete tests // =========================================================================== type mockAdaptor struct { adjustReturn int } func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {} func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil } func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil } func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { return m.adjustReturn } // =========================================================================== // PerCallBilling tests — settleTaskBillingOnComplete // =========================================================================== func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID = 30, 30, 30 const initQuota, preConsumed = 10000, 5000 const tokenRemain = 8000 seedUser(t, userID, initQuota) seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain) seedChannel(t, channelID) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) task.PrivateData.BillingContext.PerCallBilling = true adaptor := &mockAdaptor{adjustReturn: 2000} taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) // Per-call: no adjustment despite adaptor returning 2000 assert.Equal(t, initQuota, getUserQuota(t, userID)) assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) assert.Equal(t, preConsumed, task.Quota) assert.Equal(t, int64(0), countLogs(t)) } func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID = 31, 31, 31 const initQuota, preConsumed = 10000, 4000 const tokenRemain = 7000 seedUser(t, userID, initQuota) seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain) seedChannel(t, channelID) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) task.PrivateData.BillingContext.PerCallBilling = true adaptor := &mockAdaptor{adjustReturn: 0} taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999} settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) // Per-call: no recalculation by tokens assert.Equal(t, initQuota, getUserQuota(t, userID)) assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) assert.Equal(t, preConsumed, task.Quota) assert.Equal(t, int64(0), countLogs(t)) } func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) { truncate(t) ctx := context.Background() const userID, tokenID, channelID = 32, 32, 32 const initQuota, preConsumed = 10000, 5000 const adaptorQuota = 3000 const tokenRemain = 8000 seedUser(t, userID, initQuota) seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain) seedChannel(t, channelID) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) // PerCallBilling defaults to false adaptor := &mockAdaptor{adjustReturn: adaptorQuota} taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) // Non-per-call: adaptor adjustment applies (refund 2000) assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID)) assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID)) assert.Equal(t, adaptorQuota, task.Quota) log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeRefund, log.Type) }