run.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package billingexpr
import (
	"fmt"
	"math"
	"strings"
	"sync"
	"time"

	"github.com/expr-lang/expr"
	"github.com/expr-lang/expr/vm"
	"github.com/tidwall/gjson"
)
  11. // RunExpr compiles (with cache) and executes an expression string.
  12. // The environment exposes:
  13. // - p, c — prompt / completion tokens
  14. // - cr, cc, cc1h — cache read / creation / creation-1h tokens
  15. // - tier(name, value) — trace callback that records which tier matched
  16. // - max, min, abs, ceil, floor — standard math helpers
  17. //
  18. // Returns the resulting float64 quota (before group ratio) and a TraceResult
  19. // with side-channel info captured by tier() during execution.
  20. func RunExpr(exprStr string, params TokenParams) (float64, TraceResult, error) {
  21. return RunExprWithRequest(exprStr, params, RequestInput{})
  22. }
  23. func RunExprWithRequest(exprStr string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  24. prog, err := CompileFromCache(exprStr)
  25. if err != nil {
  26. return 0, TraceResult{}, err
  27. }
  28. return runProgram(prog, params, request)
  29. }
  30. // RunExprByHash is like RunExpr but accepts a pre-computed hash for the cache
  31. // lookup, avoiding a redundant SHA-256 computation when the caller already
  32. // holds BillingSnapshot.ExprHash.
  33. func RunExprByHash(exprStr, hash string, params TokenParams) (float64, TraceResult, error) {
  34. return RunExprByHashWithRequest(exprStr, hash, params, RequestInput{})
  35. }
  36. func RunExprByHashWithRequest(exprStr, hash string, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  37. prog, err := CompileFromCacheByHash(exprStr, hash)
  38. if err != nil {
  39. return 0, TraceResult{}, err
  40. }
  41. return runProgram(prog, params, request)
  42. }
  43. func runProgram(prog *vm.Program, params TokenParams, request RequestInput) (float64, TraceResult, error) {
  44. trace := TraceResult{}
  45. headers := normalizeHeaders(request.Headers)
  46. env := map[string]interface{}{
  47. "p": params.P,
  48. "c": params.C,
  49. "cr": params.CR,
  50. "cc": params.CC,
  51. "cc1h": params.CC1h,
  52. "prompt_tokens": params.P,
  53. "completion_tokens": params.C,
  54. "cache_read_tokens": params.CR,
  55. "cache_create_tokens": params.CC,
  56. "cache_create_1h_tokens": params.CC1h,
  57. "img": params.Img,
  58. "ai": params.AI,
  59. "ao": params.AO,
  60. "image_tokens": params.Img,
  61. "audio_input_tokens": params.AI,
  62. "audio_output_tokens": params.AO,
  63. "tier": func(name string, value float64) float64 {
  64. trace.MatchedTier = name
  65. trace.Cost = value
  66. return value
  67. },
  68. "header": func(key string) string {
  69. return headers[strings.ToLower(strings.TrimSpace(key))]
  70. },
  71. "param": func(path string) interface{} {
  72. path = strings.TrimSpace(path)
  73. if path == "" || len(request.Body) == 0 {
  74. return nil
  75. }
  76. result := gjson.GetBytes(request.Body, path)
  77. if !result.Exists() {
  78. return nil
  79. }
  80. return result.Value()
  81. },
  82. "has": func(source interface{}, substr string) bool {
  83. if source == nil || substr == "" {
  84. return false
  85. }
  86. return strings.Contains(fmt.Sprint(source), substr)
  87. },
  88. "hour": func(tz string) int { return timeInZone(tz).Hour() },
  89. "minute": func(tz string) int { return timeInZone(tz).Minute() },
  90. "weekday": func(tz string) int { return int(timeInZone(tz).Weekday()) },
  91. "month": func(tz string) int { return int(timeInZone(tz).Month()) },
  92. "day": func(tz string) int { return timeInZone(tz).Day() },
  93. "max": math.Max,
  94. "min": math.Min,
  95. "abs": math.Abs,
  96. "ceil": math.Ceil,
  97. "floor": math.Floor,
  98. }
  99. out, err := expr.Run(prog, env)
  100. if err != nil {
  101. return 0, trace, fmt.Errorf("expr run error: %w", err)
  102. }
  103. f, ok := out.(float64)
  104. if !ok {
  105. return 0, trace, fmt.Errorf("expr result is %T, want float64", out)
  106. }
  107. return f, trace, nil
  108. }
  109. func timeInZone(tz string) time.Time {
  110. tz = strings.TrimSpace(tz)
  111. if tz == "" {
  112. return time.Now().UTC()
  113. }
  114. loc, err := time.LoadLocation(tz)
  115. if err != nil {
  116. return time.Now().UTC()
  117. }
  118. return time.Now().In(loc)
  119. }
  120. func normalizeHeaders(headers map[string]string) map[string]string {
  121. if len(headers) == 0 {
  122. return map[string]string{}
  123. }
  124. normalized := make(map[string]string, len(headers))
  125. for key, value := range headers {
  126. k := strings.ToLower(strings.TrimSpace(key))
  127. v := strings.TrimSpace(value)
  128. if k == "" || v == "" {
  129. continue
  130. }
  131. normalized[k] = v
  132. }
  133. return normalized
  134. }