run.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. package billingexpr
  2. import (
  3. "fmt"
  4. "math"
  5. "strings"
  6. "time"
  7. "github.com/expr-lang/expr"
  8. "github.com/expr-lang/expr/vm"
  9. "github.com/tidwall/gjson"
  10. )
  11. // RunExpr compiles (with cache) and executes an expression string.
  12. // The environment exposes:
  13. // - p, c — prompt / completion tokens (auto-excluding separately-priced sub-categories)
  14. // - len — total input context length for tier conditions (never reduced by sub-category exclusion)
  15. // - cr, cc, cc1h — cache read / creation / creation-1h tokens
  16. // - tier(name, value) — trace callback that records which tier matched
  17. // - max, min, abs, ceil, floor — standard math helpers
  18. //
  19. // Returns the resulting float64 quota (before group ratio) and a TraceResult
  20. // with side-channel info captured by tier() during execution.
  21. func RunExpr(exprStr string, params TokenParams) (float64, TraceResult, error) {
  22. return RunExprWithRequest(exprStr, params, RequestInput{})
  23. }
  24. func RunExprWithRequest(exprStr string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  25. prog, err := CompileFromCache(exprStr)
  26. if err != nil {
  27. return 0, TraceResult{}, err
  28. }
  29. return runProgram(prog, params, request)
  30. }
  31. // RunExprByHash is like RunExpr but accepts a pre-computed hash for the cache
  32. // lookup, avoiding a redundant SHA-256 computation when the caller already
  33. // holds BillingSnapshot.ExprHash.
  34. func RunExprByHash(exprStr, hash string, params TokenParams) (float64, TraceResult, error) {
  35. return RunExprByHashWithRequest(exprStr, hash, params, RequestInput{})
  36. }
  37. func RunExprByHashWithRequest(exprStr, hash string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  38. prog, err := CompileFromCacheByHash(exprStr, hash)
  39. if err != nil {
  40. return 0, TraceResult{}, err
  41. }
  42. return runProgram(prog, params, request)
  43. }
  44. func runProgram(prog *vm.Program, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  45. trace := TraceResult{}
  46. headers := normalizeHeaders(request.Headers)
  47. env := map[string]interface{}{
  48. "p": params.P,
  49. "c": params.C,
  50. "len": params.Len,
  51. "cr": params.CR,
  52. "cc": params.CC,
  53. "cc1h": params.CC1h,
  54. "img": params.Img,
  55. "img_o": params.ImgO,
  56. "ai": params.AI,
  57. "ao": params.AO,
  58. "tier": func(name string, value float64) float64 {
  59. trace.MatchedTier = name
  60. trace.Cost = value
  61. return value
  62. },
  63. "header": func(key string) string {
  64. return headers[strings.ToLower(strings.TrimSpace(key))]
  65. },
  66. "param": func(path string) interface{} {
  67. path = strings.TrimSpace(path)
  68. if path == "" || len(request.Body) == 0 {
  69. return nil
  70. }
  71. result := gjson.GetBytes(request.Body, path)
  72. if !result.Exists() {
  73. return nil
  74. }
  75. return result.Value()
  76. },
  77. "has": func(source interface{}, substr string) bool {
  78. if source == nil || substr == "" {
  79. return false
  80. }
  81. return strings.Contains(fmt.Sprint(source), substr)
  82. },
  83. "hour": func(tz string) int { return timeInZone(tz).Hour() },
  84. "minute": func(tz string) int { return timeInZone(tz).Minute() },
  85. "weekday": func(tz string) int { return int(timeInZone(tz).Weekday()) },
  86. "month": func(tz string) int { return int(timeInZone(tz).Month()) },
  87. "day": func(tz string) int { return timeInZone(tz).Day() },
  88. "max": math.Max,
  89. "min": math.Min,
  90. "abs": math.Abs,
  91. "ceil": math.Ceil,
  92. "floor": math.Floor,
  93. }
  94. out, err := expr.Run(prog, env)
  95. if err != nil {
  96. return 0, trace, fmt.Errorf("expr run error: %w", err)
  97. }
  98. f, ok := out.(float64)
  99. if !ok {
  100. return 0, trace, fmt.Errorf("expr result is %T, want float64", out)
  101. }
  102. return f, trace, nil
  103. }
  104. func timeInZone(tz string) time.Time {
  105. tz = strings.TrimSpace(tz)
  106. if tz == "" {
  107. return time.Now().UTC()
  108. }
  109. loc, err := time.LoadLocation(tz)
  110. if err != nil {
  111. return time.Now().UTC()
  112. }
  113. return time.Now().In(loc)
  114. }
  115. func normalizeHeaders(headers map[string]string) map[string]string {
  116. if len(headers) == 0 {
  117. return map[string]string{}
  118. }
  119. normalized := make(map[string]string, len(headers))
  120. for key, value := range headers {
  121. k := strings.ToLower(strings.TrimSpace(key))
  122. v := strings.TrimSpace(value)
  123. if k == "" || v == "" {
  124. continue
  125. }
  126. normalized[k] = v
  127. }
  128. return normalized
  129. }