|
|
@@ -3,12 +3,14 @@ package service
|
|
|
import (
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
+ "net/http"
|
|
|
"os"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
"github.com/QuantumNous/new-api/model"
|
|
|
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
|
"github.com/glebarez/sqlite"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
@@ -125,7 +127,7 @@ func makeTask(userId, channelId, quota, tokenId int, billingSource string, subsc
|
|
|
BillingContext: &model.TaskBillingContext{
|
|
|
ModelPrice: 0.02,
|
|
|
GroupRatio: 1.0,
|
|
|
- ModelName: "test-model",
|
|
|
+ OriginModelName: "test-model",
|
|
|
},
|
|
|
},
|
|
|
}
|
|
|
@@ -604,3 +606,107 @@ func TestNonTerminalUpdate_NoBilling(t *testing.T) {
|
|
|
require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
|
|
|
assert.Equal(t, "50%", reloaded.Progress)
|
|
|
}
|
|
|
+
|
|
|
+// ===========================================================================
|
|
|
+// Mock adaptor for settleTaskBillingOnComplete tests
|
|
|
+// ===========================================================================
|
|
|
+
|
|
|
+type mockAdaptor struct {
|
|
|
+ adjustReturn int
|
|
|
+}
|
|
|
+
|
|
|
+func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {}
|
|
|
+func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil }
|
|
|
+func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil }
|
|
|
+func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
|
|
|
+ return m.adjustReturn
|
|
|
+}
|
|
|
+
|
|
|
+// ===========================================================================
|
|
|
+// PerCallBilling tests — settleTaskBillingOnComplete
|
|
|
+// ===========================================================================
|
|
|
+
|
|
|
+func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
|
|
|
+ truncate(t)
|
|
|
+ ctx := context.Background()
|
|
|
+
|
|
|
+ const userID, tokenID, channelID = 30, 30, 30
|
|
|
+ const initQuota, preConsumed = 10000, 5000
|
|
|
+ const tokenRemain = 8000
|
|
|
+
|
|
|
+ seedUser(t, userID, initQuota)
|
|
|
+ seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
|
|
|
+ seedChannel(t, channelID)
|
|
|
+
|
|
|
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
|
|
+ task.PrivateData.BillingContext.PerCallBilling = true
|
|
|
+
|
|
|
+ adaptor := &mockAdaptor{adjustReturn: 2000}
|
|
|
+ taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
|
|
|
+
|
|
|
+ settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
|
|
+
|
|
|
+ // Per-call: no adjustment despite adaptor returning 2000
|
|
|
+ assert.Equal(t, initQuota, getUserQuota(t, userID))
|
|
|
+ assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
|
|
|
+ assert.Equal(t, preConsumed, task.Quota)
|
|
|
+ assert.Equal(t, int64(0), countLogs(t))
|
|
|
+}
|
|
|
+
|
|
|
+func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
|
|
|
+ truncate(t)
|
|
|
+ ctx := context.Background()
|
|
|
+
|
|
|
+ const userID, tokenID, channelID = 31, 31, 31
|
|
|
+ const initQuota, preConsumed = 10000, 4000
|
|
|
+ const tokenRemain = 7000
|
|
|
+
|
|
|
+ seedUser(t, userID, initQuota)
|
|
|
+ seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
|
|
|
+ seedChannel(t, channelID)
|
|
|
+
|
|
|
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
|
|
+ task.PrivateData.BillingContext.PerCallBilling = true
|
|
|
+
|
|
|
+ adaptor := &mockAdaptor{adjustReturn: 0}
|
|
|
+ taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
|
|
|
+
|
|
|
+ settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
|
|
+
|
|
|
+ // Per-call: no recalculation by tokens
|
|
|
+ assert.Equal(t, initQuota, getUserQuota(t, userID))
|
|
|
+ assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
|
|
|
+ assert.Equal(t, preConsumed, task.Quota)
|
|
|
+ assert.Equal(t, int64(0), countLogs(t))
|
|
|
+}
|
|
|
+
|
|
|
+func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
|
|
|
+ truncate(t)
|
|
|
+ ctx := context.Background()
|
|
|
+
|
|
|
+ const userID, tokenID, channelID = 32, 32, 32
|
|
|
+ const initQuota, preConsumed = 10000, 5000
|
|
|
+ const adaptorQuota = 3000
|
|
|
+ const tokenRemain = 8000
|
|
|
+
|
|
|
+ seedUser(t, userID, initQuota)
|
|
|
+ seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
|
|
|
+ seedChannel(t, channelID)
|
|
|
+
|
|
|
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
|
|
+ // PerCallBilling defaults to false
|
|
|
+
|
|
|
+ adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
|
|
|
+ taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
|
|
|
+
|
|
|
+ settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
|
|
+
|
|
|
+ // Non-per-call: adaptor adjustment applies (refund 2000)
|
|
|
+ assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
|
|
|
+ assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
|
|
|
+ assert.Equal(t, adaptorQuota, task.Quota)
|
|
|
+
|
|
|
+ log := getLastLog(t)
|
|
|
+ require.NotNil(t, log)
|
|
|
+ assert.Equal(t, model.LogTypeRefund, log.Type)
|
|
|
+}
|