run.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package billingexpr
  2. import (
  3. "fmt"
  4. "math"
  5. "strings"
  6. "time"
  7. "github.com/expr-lang/expr"
  8. "github.com/expr-lang/expr/vm"
  9. "github.com/tidwall/gjson"
  10. )
  11. // RunExpr compiles (with cache) and executes an expression string.
  12. // The environment exposes:
  13. // - p, c — prompt / completion tokens
  14. // - cr, cc, cc1h — cache read / creation / creation-1h tokens
  15. // - tier(name, value) — trace callback that records which tier matched
  16. // - max, min, abs, ceil, floor — standard math helpers
  17. //
  18. // Returns the resulting float64 quota (before group ratio) and a TraceResult
  19. // with side-channel info captured by tier() during execution.
  20. func RunExpr(exprStr string, params TokenParams) (float64, TraceResult, error) {
  21. return RunExprWithRequest(exprStr, params, RequestInput{})
  22. }
  23. func RunExprWithRequest(exprStr string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  24. prog, err := CompileFromCache(exprStr)
  25. if err != nil {
  26. return 0, TraceResult{}, err
  27. }
  28. return runProgram(prog, params, request)
  29. }
  30. // RunExprByHash is like RunExpr but accepts a pre-computed hash for the cache
  31. // lookup, avoiding a redundant SHA-256 computation when the caller already
  32. // holds BillingSnapshot.ExprHash.
  33. func RunExprByHash(exprStr, hash string, params TokenParams) (float64, TraceResult, error) {
  34. return RunExprByHashWithRequest(exprStr, hash, params, RequestInput{})
  35. }
  36. func RunExprByHashWithRequest(exprStr, hash string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  37. prog, err := CompileFromCacheByHash(exprStr, hash)
  38. if err != nil {
  39. return 0, TraceResult{}, err
  40. }
  41. return runProgram(prog, params, request)
  42. }
  43. func runProgram(prog *vm.Program, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  44. trace := TraceResult{}
  45. headers := normalizeHeaders(request.Headers)
  46. env := map[string]interface{}{
  47. "p": params.P,
  48. "c": params.C,
  49. "cr": params.CR,
  50. "cc": params.CC,
  51. "cc1h": params.CC1h,
  52. "prompt_tokens": params.P,
  53. "completion_tokens": params.C,
  54. "cache_read_tokens": params.CR,
  55. "cache_create_tokens": params.CC,
  56. "cache_create_1h_tokens": params.CC1h,
  57. "tier": func(name string, value float64) float64 {
  58. trace.MatchedTier = name
  59. trace.Cost = value
  60. return value
  61. },
  62. "header": func(key string) string {
  63. return headers[strings.ToLower(strings.TrimSpace(key))]
  64. },
  65. "param": func(path string) interface{} {
  66. path = strings.TrimSpace(path)
  67. if path == "" || len(request.Body) == 0 {
  68. return nil
  69. }
  70. result := gjson.GetBytes(request.Body, path)
  71. if !result.Exists() {
  72. return nil
  73. }
  74. return result.Value()
  75. },
  76. "has": func(source interface{}, substr string) bool {
  77. if source == nil || substr == "" {
  78. return false
  79. }
  80. return strings.Contains(fmt.Sprint(source), substr)
  81. },
  82. "hour": func(tz string) int { return timeInZone(tz).Hour() },
  83. "minute": func(tz string) int { return timeInZone(tz).Minute() },
  84. "weekday": func(tz string) int { return int(timeInZone(tz).Weekday()) },
  85. "month": func(tz string) int { return int(timeInZone(tz).Month()) },
  86. "day": func(tz string) int { return timeInZone(tz).Day() },
  87. "max": math.Max,
  88. "min": math.Min,
  89. "abs": math.Abs,
  90. "ceil": math.Ceil,
  91. "floor": math.Floor,
  92. }
  93. out, err := expr.Run(prog, env)
  94. if err != nil {
  95. return 0, trace, fmt.Errorf("expr run error: %w", err)
  96. }
  97. f, ok := out.(float64)
  98. if !ok {
  99. return 0, trace, fmt.Errorf("expr result is %T, want float64", out)
  100. }
  101. return f, trace, nil
  102. }
  // timeInZone returns the current time expressed in the zone named by tz, as
  // resolved by time.LoadLocation. Surrounding whitespace in tz is ignored. A
  // blank name, or one LoadLocation rejects, falls back to UTC instead of
  // reporting an error.
  func timeInZone(tz string) time.Time {
  	now := time.Now()
  	name := strings.TrimSpace(tz)
  	if name == "" {
  		return now.UTC()
  	}
  	if loc, err := time.LoadLocation(name); err == nil {
  		return now.In(loc)
  	}
  	return now.UTC()
  }
  // normalizeHeaders returns a copy of headers suitable for case-insensitive
  // lookup: names are trimmed and lower-cased, values are trimmed, and entries
  // whose name or value is blank after trimming are dropped. The result is
  // always a non-nil map, even for nil or empty input.
  func normalizeHeaders(headers map[string]string) map[string]string {
  	out := make(map[string]string, len(headers))
  	for rawName, rawValue := range headers {
  		name := strings.ToLower(strings.TrimSpace(rawName))
  		value := strings.TrimSpace(rawValue)
  		if name != "" && value != "" {
  			out[name] = value
  		}
  	}
  	return out
  }