Просмотр исходного кода

refactor(task): enhance UpdateWithStatus for CAS updates and add integration tests

- Updated UpdateWithStatus method to use Model().Select("*").Updates() for conditional updates, preventing GORM's INSERT fallback.
- Introduced comprehensive integration tests for UpdateWithStatus, covering scenarios for winning and losing CAS updates, as well as concurrent updates.
- Added task_cas_test.go to validate the new behavior and ensure data integrity during concurrent state transitions.
CaIon 1 неделя назад
Родитель
Сommit
9976b311ef
4 измененных файлов с 831 добавлено и 2 удалено
  1. 3 1
      model/midjourney.go
  2. 5 1
      model/task.go
  3. 217 0
      model/task_cas_test.go
  4. 606 0
      service/task_billing_test.go

+ 3 - 1
model/midjourney.go

@@ -160,8 +160,10 @@ func (midjourney *Midjourney) Update() error {
 // UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
 // Returns (true, nil) if this caller won the update, (false, nil) if
 // another process already moved the task out of fromStatus.
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Uses Model().Select("*").Updates() to avoid GORM Save()'s INSERT fallback.
 func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) {
-	result := DB.Where("status = ?", fromStatus).Save(midjourney)
+	result := DB.Model(midjourney).Where("status = ?", fromStatus).Select("*").Updates(midjourney)
 	if result.Error != nil {
 		return false, result.Error
 	}

+ 5 - 1
model/task.go

@@ -388,8 +388,12 @@ func (Task *Task) Update() error {
 // UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
 // Returns (true, nil) if this caller won the update, (false, nil) if
 // another process already moved the task out of fromStatus.
+//
+// Uses Model().Select("*").Updates() instead of Save() because GORM's Save
+// falls back to INSERT ON CONFLICT when the WHERE-guarded UPDATE matches
+// zero rows, which silently bypasses the CAS guard.
 func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
-	result := DB.Where("status = ?", fromStatus).Save(t)
+	result := DB.Model(t).Where("status = ?", fromStatus).Select("*").Updates(t)
 	if result.Error != nil {
 		return false, result.Error
 	}

+ 217 - 0
model/task_cas_test.go

@@ -0,0 +1,217 @@
+package model
+
+import (
+	"encoding/json"
+	"os"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/glebarez/sqlite"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gorm.io/gorm"
+)
+
+func TestMain(m *testing.M) {
+	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+	if err != nil {
+		panic("failed to open test db: " + err.Error())
+	}
+	DB = db
+	LOG_DB = db
+
+	common.UsingSQLite = true
+	common.RedisEnabled = false
+	common.BatchUpdateEnabled = false
+	common.LogConsumeEnabled = true
+
+	sqlDB, err := db.DB()
+	if err != nil {
+		panic("failed to get sql.DB: " + err.Error())
+	}
+	sqlDB.SetMaxOpenConns(1)
+
+	if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil {
+		panic("failed to migrate: " + err.Error())
+	}
+
+	os.Exit(m.Run())
+}
+
+func truncateTables(t *testing.T) {
+	t.Helper()
+	t.Cleanup(func() {
+		DB.Exec("DELETE FROM tasks")
+		DB.Exec("DELETE FROM users")
+		DB.Exec("DELETE FROM tokens")
+		DB.Exec("DELETE FROM logs")
+		DB.Exec("DELETE FROM channels")
+	})
+}
+
+func insertTask(t *testing.T, task *Task) {
+	t.Helper()
+	task.CreatedAt = time.Now().Unix()
+	task.UpdatedAt = time.Now().Unix()
+	require.NoError(t, DB.Create(task).Error)
+}
+
+// ---------------------------------------------------------------------------
+// Snapshot / Equal — pure logic tests (no DB)
+// ---------------------------------------------------------------------------
+
+func TestSnapshotEqual_Same(t *testing.T) {
+	s := taskSnapshot{
+		Status:     TaskStatusInProgress,
+		Progress:   "50%",
+		StartTime:  1000,
+		FinishTime: 0,
+		FailReason: "",
+		ResultURL:  "",
+		Data:       json.RawMessage(`{"key":"value"}`),
+	}
+	assert.True(t, s.Equal(s))
+}
+
+func TestSnapshotEqual_DifferentStatus(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
+	b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
+	assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_DifferentProgress(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
+	b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
+	assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_DifferentData(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
+	b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
+	assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
+	b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
+	// bytes.Equal(nil, []byte{}) == true
+	assert.True(t, a.Equal(b))
+}
+
+func TestSnapshot_Roundtrip(t *testing.T) {
+	task := &Task{
+		Status:     TaskStatusInProgress,
+		Progress:   "42%",
+		StartTime:  1234,
+		FinishTime: 5678,
+		FailReason: "timeout",
+		PrivateData: TaskPrivateData{
+			ResultURL: "https://example.com/result.mp4",
+		},
+		Data: json.RawMessage(`{"model":"test-model"}`),
+	}
+	snap := task.Snapshot()
+	assert.Equal(t, task.Status, snap.Status)
+	assert.Equal(t, task.Progress, snap.Progress)
+	assert.Equal(t, task.StartTime, snap.StartTime)
+	assert.Equal(t, task.FinishTime, snap.FinishTime)
+	assert.Equal(t, task.FailReason, snap.FailReason)
+	assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
+	assert.JSONEq(t, string(task.Data), string(snap.Data))
+}
+
+// ---------------------------------------------------------------------------
+// UpdateWithStatus CAS — DB integration tests
+// ---------------------------------------------------------------------------
+
+func TestUpdateWithStatus_Win(t *testing.T) {
+	truncateTables(t)
+
+	task := &Task{
+		TaskID:   "task_cas_win",
+		Status:   TaskStatusInProgress,
+		Progress: "50%",
+		Data:     json.RawMessage(`{}`),
+	}
+	insertTask(t, task)
+
+	task.Status = TaskStatusSuccess
+	task.Progress = "100%"
+	won, err := task.UpdateWithStatus(TaskStatusInProgress)
+	require.NoError(t, err)
+	assert.True(t, won)
+
+	var reloaded Task
+	require.NoError(t, DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
+	assert.Equal(t, "100%", reloaded.Progress)
+}
+
+func TestUpdateWithStatus_Lose(t *testing.T) {
+	truncateTables(t)
+
+	task := &Task{
+		TaskID: "task_cas_lose",
+		Status: TaskStatusFailure,
+		Data:   json.RawMessage(`{}`),
+	}
+	insertTask(t, task)
+
+	task.Status = TaskStatusSuccess
+	won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
+	require.NoError(t, err)
+	assert.False(t, won)
+
+	var reloaded Task
+	require.NoError(t, DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
+}
+
+func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
+	truncateTables(t)
+
+	task := &Task{
+		TaskID: "task_cas_race",
+		Status: TaskStatusInProgress,
+		Quota:  1000,
+		Data:   json.RawMessage(`{}`),
+	}
+	insertTask(t, task)
+
+	const goroutines = 5
+	wins := make([]bool, goroutines)
+	var wg sync.WaitGroup
+	wg.Add(goroutines)
+
+	for i := 0; i < goroutines; i++ {
+		go func(idx int) {
+			defer wg.Done()
+			t := &Task{}
+			*t = Task{
+				ID:       task.ID,
+				TaskID:   task.TaskID,
+				Status:   TaskStatusSuccess,
+				Progress: "100%",
+				Quota:    task.Quota,
+				Data:     json.RawMessage(`{}`),
+			}
+			t.CreatedAt = task.CreatedAt
+			t.UpdatedAt = time.Now().Unix()
+			won, err := t.UpdateWithStatus(TaskStatusInProgress)
+			if err == nil {
+				wins[idx] = won
+			}
+		}(i)
+	}
+	wg.Wait()
+
+	winCount := 0
+	for _, w := range wins {
+		if w {
+			winCount++
+		}
+	}
+	assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")
+}

+ 606 - 0
service/task_billing_test.go

@@ -0,0 +1,606 @@
+package service
+
+import (
+	"context"
+	"encoding/json"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/model"
+	"github.com/glebarez/sqlite"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gorm.io/gorm"
+)
+
+func TestMain(m *testing.M) {
+	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+	if err != nil {
+		panic("failed to open test db: " + err.Error())
+	}
+	sqlDB, err := db.DB()
+	if err != nil {
+		panic("failed to get sql.DB: " + err.Error())
+	}
+	sqlDB.SetMaxOpenConns(1)
+
+	model.DB = db
+	model.LOG_DB = db
+
+	common.UsingSQLite = true
+	common.RedisEnabled = false
+	common.BatchUpdateEnabled = false
+	common.LogConsumeEnabled = true
+
+	if err := db.AutoMigrate(
+		&model.Task{},
+		&model.User{},
+		&model.Token{},
+		&model.Log{},
+		&model.Channel{},
+		&model.UserSubscription{},
+	); err != nil {
+		panic("failed to migrate: " + err.Error())
+	}
+
+	os.Exit(m.Run())
+}
+
+// ---------------------------------------------------------------------------
+// Seed helpers
+// ---------------------------------------------------------------------------
+
+func truncate(t *testing.T) {
+	t.Helper()
+	t.Cleanup(func() {
+		model.DB.Exec("DELETE FROM tasks")
+		model.DB.Exec("DELETE FROM users")
+		model.DB.Exec("DELETE FROM tokens")
+		model.DB.Exec("DELETE FROM logs")
+		model.DB.Exec("DELETE FROM channels")
+		model.DB.Exec("DELETE FROM user_subscriptions")
+	})
+}
+
+func seedUser(t *testing.T, id int, quota int) {
+	t.Helper()
+	user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled}
+	require.NoError(t, model.DB.Create(user).Error)
+}
+
+func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) {
+	t.Helper()
+	token := &model.Token{
+		Id:          id,
+		UserId:      userId,
+		Key:         key,
+		Name:        "test_token",
+		Status:      common.TokenStatusEnabled,
+		RemainQuota: remainQuota,
+		UsedQuota:   0,
+	}
+	require.NoError(t, model.DB.Create(token).Error)
+}
+
+func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) {
+	t.Helper()
+	sub := &model.UserSubscription{
+		Id:          id,
+		UserId:      userId,
+		AmountTotal: amountTotal,
+		AmountUsed:  amountUsed,
+		Status:      "active",
+		StartTime:   time.Now().Unix(),
+		EndTime:     time.Now().Add(30 * 24 * time.Hour).Unix(),
+	}
+	require.NoError(t, model.DB.Create(sub).Error)
+}
+
+func seedChannel(t *testing.T, id int) {
+	t.Helper()
+	ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled}
+	require.NoError(t, model.DB.Create(ch).Error)
+}
+
+func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task {
+	return &model.Task{
+		TaskID:    "task_" + time.Now().Format("150405.000"),
+		UserId:    userId,
+		ChannelId: channelId,
+		Quota:     quota,
+		Status:    model.TaskStatus(model.TaskStatusInProgress),
+		Group:     "default",
+		Data:      json.RawMessage(`{}`),
+		CreatedAt: time.Now().Unix(),
+		UpdatedAt: time.Now().Unix(),
+		Properties: model.Properties{
+			OriginModelName: "test-model",
+		},
+		PrivateData: model.TaskPrivateData{
+			BillingSource:  billingSource,
+			SubscriptionId: subscriptionId,
+			TokenId:        tokenId,
+			BillingContext: &model.TaskBillingContext{
+				ModelPrice: 0.02,
+				GroupRatio: 1.0,
+				ModelName:  "test-model",
+			},
+		},
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Read-back helpers
+// ---------------------------------------------------------------------------
+
+func getUserQuota(t *testing.T, id int) int {
+	t.Helper()
+	var user model.User
+	require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error)
+	return user.Quota
+}
+
+func getTokenRemainQuota(t *testing.T, id int) int {
+	t.Helper()
+	var token model.Token
+	require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error)
+	return token.RemainQuota
+}
+
+func getTokenUsedQuota(t *testing.T, id int) int {
+	t.Helper()
+	var token model.Token
+	require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error)
+	return token.UsedQuota
+}
+
+func getSubscriptionUsed(t *testing.T, id int) int64 {
+	t.Helper()
+	var sub model.UserSubscription
+	require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error)
+	return sub.AmountUsed
+}
+
+func getLastLog(t *testing.T) *model.Log {
+	t.Helper()
+	var log model.Log
+	err := model.LOG_DB.Order("id desc").First(&log).Error
+	if err != nil {
+		return nil
+	}
+	return &log
+}
+
+func countLogs(t *testing.T) int64 {
+	t.Helper()
+	var count int64
+	model.LOG_DB.Model(&model.Log{}).Count(&count)
+	return count
+}
+
+// ===========================================================================
+// RefundTaskQuota tests
+// ===========================================================================
+
+func TestRefundTaskQuota_Wallet(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 1, 1, 1
+	const initQuota, preConsumed = 10000, 3000
+	const tokenRemain = 5000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-test-key", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+	RefundTaskQuota(ctx, task, "task failed: upstream error")
+
+	// User quota should increase by preConsumed
+	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+
+	// Token remain_quota should increase, used_quota should decrease
+	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID))
+
+	// A refund log should be created
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+	assert.Equal(t, preConsumed, log.Quota)
+	assert.Equal(t, "test-model", log.ModelName)
+}
+
+func TestRefundTaskQuota_Subscription(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID, subID = 2, 2, 2, 1
+	const preConsumed = 2000
+	const subTotal, subUsed int64 = 100000, 50000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, 0)
+	seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain)
+	seedChannel(t, channelID)
+	seedSubscription(t, subID, userID, subTotal, subUsed)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
+
+	RefundTaskQuota(ctx, task, "subscription task failed")
+
+	// Subscription used should decrease by preConsumed
+	assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID))
+
+	// Token should also be refunded
+	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+func TestRefundTaskQuota_ZeroQuota(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID = 3
+	seedUser(t, userID, 5000)
+
+	task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0)
+
+	RefundTaskQuota(ctx, task, "zero quota task")
+
+	// No change to user quota
+	assert.Equal(t, 5000, getUserQuota(t, userID))
+
+	// No log created
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRefundTaskQuota_NoToken(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, channelID = 4, 4
+	const initQuota, preConsumed = 10000, 1500
+
+	seedUser(t, userID, initQuota)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0
+
+	RefundTaskQuota(ctx, task, "no token task failed")
+
+	// User quota refunded
+	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+
+	// Log created
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+// ===========================================================================
+// RecalculateTaskQuota tests
+// ===========================================================================
+
+func TestRecalculate_PositiveDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 10, 10, 10
+	const initQuota, preConsumed = 10000, 2000
+	const actualQuota = 3000 // under-charged by 1000
+	const tokenRemain = 5000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
+
+	// User quota should decrease by the delta (1000 additional charge)
+	assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID))
+
+	// Token should also be charged the delta
+	assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID))
+
+	// task.Quota should be updated to actualQuota
+	assert.Equal(t, actualQuota, task.Quota)
+
+	// Log type should be Consume (additional charge)
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeConsume, log.Type)
+	assert.Equal(t, actualQuota-preConsumed, log.Quota)
+}
+
+func TestRecalculate_NegativeDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 11, 11, 11
+	const initQuota, preConsumed = 10000, 5000
+	const actualQuota = 3000 // over-charged by 2000
+	const tokenRemain = 5000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
+
+	// User quota should increase by abs(delta) = 2000 (refund overpayment)
+	assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
+
+	// Token should be refunded the difference
+	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+	// task.Quota updated
+	assert.Equal(t, actualQuota, task.Quota)
+
+	// Log type should be Refund
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+	assert.Equal(t, preConsumed-actualQuota, log.Quota)
+}
+
+func TestRecalculate_ZeroDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID = 12
+	const initQuota, preConsumed = 10000, 3000
+
+	seedUser(t, userID, initQuota)
+
+	task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, preConsumed, "exact match")
+
+	// No change to user quota
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+
+	// No log created (delta is zero)
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRecalculate_ActualQuotaZero(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID = 13
+	const initQuota = 10000
+
+	seedUser(t, userID, initQuota)
+
+	task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, 0, "zero actual")
+
+	// No change (early return)
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRecalculate_Subscription_NegativeDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID, subID = 14, 14, 14, 2
+	const preConsumed = 5000
+	const actualQuota = 2000 // over-charged by 3000
+	const subTotal, subUsed int64 = 100000, 50000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, 0)
+	seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain)
+	seedChannel(t, channelID)
+	seedSubscription(t, subID, userID, subTotal, subUsed)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
+
+	RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge")
+
+	// Subscription used should decrease by delta (refund 3000)
+	assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID))
+
+	// Token refunded
+	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+	assert.Equal(t, actualQuota, task.Quota)
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+// ===========================================================================
+// CAS + Billing integration tests
+// Simulates the flow in updateVideoSingleTask (service/task_polling.go)
+// ===========================================================================
+
+// simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask.
+// It takes a persisted task (already in DB), applies the new status, and performs
+// the conditional update + billing exactly as the polling loop does.
+func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) {
+	snap := task.Snapshot()
+
+	shouldRefund := false
+	shouldSettle := false
+	quota := task.Quota
+
+	task.Status = newStatus
+	switch string(newStatus) {
+	case model.TaskStatusSuccess:
+		task.Progress = "100%"
+		task.FinishTime = 9999
+		shouldSettle = true
+	case model.TaskStatusFailure:
+		task.Progress = "100%"
+		task.FinishTime = 9999
+		task.FailReason = "upstream error"
+		if quota != 0 {
+			shouldRefund = true
+		}
+	default:
+		task.Progress = "50%"
+	}
+
+	isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure)
+	if isDone && snap.Status != task.Status {
+		won, err := task.UpdateWithStatus(snap.Status)
+		if err != nil {
+			shouldRefund = false
+			shouldSettle = false
+		} else if !won {
+			shouldRefund = false
+			shouldSettle = false
+		}
+	} else if !snap.Equal(task.Snapshot()) {
+		_, _ = task.UpdateWithStatus(snap.Status)
+	}
+
+	if shouldSettle && actualQuota > 0 {
+		RecalculateTaskQuota(ctx, task, actualQuota, "test settle")
+	}
+	if shouldRefund {
+		RefundTaskQuota(ctx, task, task.FailReason)
+	}
+}
+
+func TestCASGuardedRefund_Win(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 20, 20, 20
+	const initQuota, preConsumed = 10000, 4000
+	const tokenRemain = 6000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	require.NoError(t, model.DB.Create(task).Error)
+
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
+
+	// CAS wins: task in DB should now be FAILURE
+	var reloaded model.Task
+	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status)
+
+	// Refund should have happened
+	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+func TestCASGuardedRefund_Lose(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 21, 21, 21
+	const initQuota, preConsumed = 10000, 4000
+	const tokenRemain = 6000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain)
+	seedChannel(t, channelID)
+
+	// Create task with IN_PROGRESS in DB
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	require.NoError(t, model.DB.Create(task).Error)
+
+	// Simulate another process already transitioning to FAILURE
+	model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure)
+
+	// Our process still has the old in-memory state (IN_PROGRESS) and tries to transition
+	// task.Status is still IN_PROGRESS in the snapshot
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
+
+	// CAS lost: user quota should NOT change (no double refund)
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+
+	// No billing log should be created
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestCASGuardedSettle_Win(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 22, 22, 22
+	const initQuota, preConsumed = 10000, 5000
+	const actualQuota = 3000 // over-charged, should get partial refund
+	const tokenRemain = 8000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	require.NoError(t, model.DB.Create(task).Error)
+
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota)
+
+	// CAS wins: task should be SUCCESS
+	var reloaded model.Task
+	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status)
+
+	// Settlement should refund the over-charge (5000 - 3000 = 2000 back to user)
+	assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+	// task.Quota should be updated to actualQuota
+	assert.Equal(t, actualQuota, task.Quota)
+}
+
+func TestNonTerminalUpdate_NoBilling(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, channelID = 23, 23
+	const initQuota, preConsumed = 10000, 3000
+
+	seedUser(t, userID, initQuota)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	task.Progress = "20%"
+	require.NoError(t, model.DB.Create(task).Error)
+
+	// Simulate a non-terminal poll update (still IN_PROGRESS, progress changed)
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0)
+
+	// User quota should NOT change
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+
+	// No billing log
+	assert.Equal(t, int64(0), countLogs(t))
+
+	// Task progress should be updated in DB
+	var reloaded model.Task
+	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+	assert.Equal(t, "50%", reloaded.Progress)
+}