tiered_billing.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. package billing_setting
  import (
  	"fmt"
  	"math"

  	"github.com/QuantumNous/new-api/pkg/billingexpr"
  	"github.com/QuantumNous/new-api/setting/config"
  	"github.com/samber/lo"
  )
  // Billing mode identifiers and the map keys under which billing data is
  // exported by GetPricingSyncData.
  const (
  	// BillingModeRatio is the mode GetBillingMode reports for any model that
  	// has no explicit entry.
  	BillingModeRatio = "ratio"
  	// BillingModeTieredExpr is the identifier of the expression-based mode.
  	// NOTE(review): presumably paired with an entry in BillingExpr — confirm
  	// against the callers that branch on the mode.
  	BillingModeTieredExpr = "tiered_expr"

  	// BillingModeField is the key under which GetPricingSyncData stores the
  	// per-model billing-mode map.
  	BillingModeField = "billing_mode"
  	// BillingExprField is the key under which GetPricingSyncData stores the
  	// per-model billing-expression map.
  	BillingExprField = "billing_expr"
  )
  // BillingSetting holds the per-model billing configuration.
  // It is managed by config.GlobalConfig.Register.
  // DB keys: billing_setting.billing_mode, billing_setting.billing_expr
  type BillingSetting struct {
  	// BillingMode maps a model name to its billing mode string.
  	BillingMode map[string]string `json:"billing_mode"`
  	// BillingExpr maps a model name to its billing expression string.
  	BillingExpr map[string]string `json:"billing_expr"`
  }
  20. var billingSetting = BillingSetting{
  21. BillingMode: make(map[string]string),
  22. BillingExpr: make(map[string]string),
  23. }
  24. func init() {
  25. config.GlobalConfig.Register("billing_setting", &billingSetting)
  26. }
  27. // ---------------------------------------------------------------------------
  28. // Read accessors (hot path, must be fast)
  29. // ---------------------------------------------------------------------------
  30. func GetBillingMode(model string) string {
  31. if mode, ok := billingSetting.BillingMode[model]; ok {
  32. return mode
  33. }
  34. return BillingModeRatio
  35. }
  36. func GetBillingExpr(model string) (string, bool) {
  37. expr, ok := billingSetting.BillingExpr[model]
  38. return expr, ok
  39. }
  40. func GetBillingModeCopy() map[string]string {
  41. return lo.Assign(billingSetting.BillingMode)
  42. }
  43. func GetBillingExprCopy() map[string]string {
  44. return lo.Assign(billingSetting.BillingExpr)
  45. }
  46. func GetPricingSyncData(base map[string]any) map[string]any {
  47. extra := make(map[string]any, 2)
  48. if modes := GetBillingModeCopy(); len(modes) > 0 {
  49. extra[BillingModeField] = modes
  50. }
  51. if exprs := GetBillingExprCopy(); len(exprs) > 0 {
  52. extra[BillingExprField] = exprs
  53. }
  54. return lo.Assign(base, extra)
  55. }
  56. // ---------------------------------------------------------------------------
  57. // Smoke test (called externally for validation before save)
  58. // ---------------------------------------------------------------------------
  59. func SmokeTestExpr(exprStr string) error {
  60. return smokeTestExpr(exprStr)
  61. }
  62. func smokeTestExpr(exprStr string) error {
  63. vectors := []billingexpr.TokenParams{
  64. {P: 0, C: 0, Len: 0},
  65. {P: 1000, C: 1000, Len: 1000},
  66. {P: 100000, C: 100000, Len: 100000},
  67. {P: 1000000, C: 1000000, Len: 1000000},
  68. }
  69. requests := []billingexpr.RequestInput{
  70. {},
  71. {
  72. Headers: map[string]string{
  73. "anthropic-beta": "fast-mode-2026-02-01",
  74. },
  75. Body: []byte(`{"service_tier":"fast","stream_options":{"include_usage":true},"messages":[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]}`),
  76. },
  77. }
  78. for _, v := range vectors {
  79. for _, request := range requests {
  80. result, _, err := billingexpr.RunExprWithRequest(exprStr, v, request)
  81. if err != nil {
  82. return fmt.Errorf("vector {p=%g, c=%g}: run failed: %w", v.P, v.C, err)
  83. }
  84. if result < 0 {
  85. return fmt.Errorf("vector {p=%g, c=%g}: result %f < 0", v.P, v.C, result)
  86. }
  87. }
  88. }
  89. return nil
  90. }