price_test.go 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. package helper
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. "github.com/QuantumNous/new-api/common"
  7. "github.com/QuantumNous/new-api/pkg/billingexpr"
  8. relaycommon "github.com/QuantumNous/new-api/relay/common"
  9. "github.com/QuantumNous/new-api/setting/billing_setting"
  10. "github.com/QuantumNous/new-api/setting/config"
  11. "github.com/QuantumNous/new-api/types"
  12. "github.com/gin-gonic/gin"
  13. "github.com/stretchr/testify/require"
  14. )
  15. func TestModelPriceHelperTieredUsesPreloadedRequestInput(t *testing.T) {
  16. gin.SetMode(gin.TestMode)
  17. saved := map[string]string{}
  18. require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
  19. saved[key] = value
  20. return nil
  21. }))
  22. t.Cleanup(func() {
  23. require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
  24. })
  25. require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
  26. "billing_setting.billing_mode": `{"tiered-test-model":"tiered_expr"}`,
  27. "billing_setting.billing_expr": `{"tiered-test-model":"param(\"stream\") == true ? tier(\"stream\", p * 3) : tier(\"base\", p * 2)"}`,
  28. }))
  29. recorder := httptest.NewRecorder()
  30. ctx, _ := gin.CreateTestContext(recorder)
  31. req := httptest.NewRequest(http.MethodPost, "/api/channel/test/1", nil)
  32. req.Body = nil
  33. req.ContentLength = 0
  34. req.Header.Set("Content-Type", "application/json")
  35. ctx.Request = req
  36. ctx.Set("group", "default")
  37. info := &relaycommon.RelayInfo{
  38. OriginModelName: "tiered-test-model",
  39. UserGroup: "default",
  40. UsingGroup: "default",
  41. RequestHeaders: map[string]string{"Content-Type": "application/json"},
  42. BillingRequestInput: &billingexpr.RequestInput{
  43. Headers: map[string]string{"Content-Type": "application/json"},
  44. Body: []byte(`{"stream":true}`),
  45. },
  46. }
  47. priceData, err := ModelPriceHelper(ctx, info, 1000, &types.TokenCountMeta{})
  48. require.NoError(t, err)
  49. require.Equal(t, 1500, priceData.QuotaToPreConsume)
  50. require.NotNil(t, info.TieredBillingSnapshot)
  51. require.Equal(t, "stream", info.TieredBillingSnapshot.EstimatedTier)
  52. require.Equal(t, billing_setting.BillingModeTieredExpr, info.TieredBillingSnapshot.BillingMode)
  53. require.Equal(t, common.QuotaPerUnit, info.TieredBillingSnapshot.QuotaPerUnit)
  54. }