| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- package helper
- import (
- "net/http"
- "net/http/httptest"
- "testing"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/pkg/billingexpr"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/setting/billing_setting"
- "github.com/QuantumNous/new-api/setting/config"
- "github.com/QuantumNous/new-api/types"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
- )
- func TestModelPriceHelperTieredUsesPreloadedRequestInput(t *testing.T) {
- gin.SetMode(gin.TestMode)
- saved := map[string]string{}
- require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
- saved[key] = value
- return nil
- }))
- t.Cleanup(func() {
- require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
- })
- require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
- "billing_setting.billing_mode": `{"tiered-test-model":"tiered_expr"}`,
- "billing_setting.billing_expr": `{"tiered-test-model":"param(\"stream\") == true ? tier(\"stream\", p * 3) : tier(\"base\", p * 2)"}`,
- }))
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- req := httptest.NewRequest(http.MethodPost, "/api/channel/test/1", nil)
- req.Body = nil
- req.ContentLength = 0
- req.Header.Set("Content-Type", "application/json")
- ctx.Request = req
- ctx.Set("group", "default")
- info := &relaycommon.RelayInfo{
- OriginModelName: "tiered-test-model",
- UserGroup: "default",
- UsingGroup: "default",
- RequestHeaders: map[string]string{"Content-Type": "application/json"},
- BillingRequestInput: &billingexpr.RequestInput{
- Headers: map[string]string{"Content-Type": "application/json"},
- Body: []byte(`{"stream":true}`),
- },
- }
- priceData, err := ModelPriceHelper(ctx, info, 1000, &types.TokenCountMeta{})
- require.NoError(t, err)
- require.Equal(t, 1500, priceData.QuotaToPreConsume)
- require.NotNil(t, info.TieredBillingSnapshot)
- require.Equal(t, "stream", info.TieredBillingSnapshot.EstimatedTier)
- require.Equal(t, billing_setting.BillingModeTieredExpr, info.TieredBillingSnapshot.BillingMode)
- require.Equal(t, common.QuotaPerUnit, info.TieredBillingSnapshot.QuotaPerUnit)
- }
|