Просмотр исходного кода

fix(mode_workflow): query_score 成本回退(token估)+ /api/query_score sel 正则守卫(补 import re)

刘文武 1 неделя назад
Родитель
Сommit
eb612ece8c
2 измененных файлов с 18 добавлено и 2 удалено
  1. 5 1
      examples/mode_workflow/server.py
  2. 13 1
      examples/mode_workflow/stages/query_score.py

+ 5 - 1
examples/mode_workflow/server.py

@@ -15,6 +15,7 @@
 import hashlib
 import json
 import os
+import re
 import subprocess
 import sys
 import threading
@@ -472,7 +473,10 @@ class Handler(BaseHTTPRequestHandler):
             elif u.path == "/api/category_tree":
                 self._json_etag(_category_tree(qs.get("source_type", "实质")))
             elif u.path == "/api/query_score":
-                cache = SCORE_CACHE_DIR / f"{qs.get('sel', '')}.json"
+                sel = qs.get("sel", "")
+                if not re.fullmatch(r"[0-9a-f]{16}", sel):   # 防路径穿越:sel 必为 16 位十六进制
+                    return self._err("bad sel", 400)
+                cache = SCORE_CACHE_DIR / f"{sel}.json"
                 if cache.is_file():
                     self._json_etag(json.loads(cache.read_text(encoding="utf-8")))
                 else:

+ 13 - 1
examples/mode_workflow/stages/query_score.py

@@ -72,7 +72,19 @@ async def _call_with_retry(llm_call, messages, model, task_name, max_retries=3):
         try:
             resp = await llm_call(messages=cur_messages, model=model,
                                   temperature=0.1, max_tokens=4000)
-            cost = resp.get("cost") or 0.0
+            # 成本:优先用 provider 自带 cost;缺省时按 token 用量估(同 llm_helper 口径)
+            provider_cost = resp.get("cost")
+            if isinstance(provider_cost, (int, float)) and provider_cost > 0:
+                cost = provider_cost
+            else:
+                usage = resp.get("usage") or {}
+                if hasattr(usage, "__dict__"):
+                    it = getattr(usage, "input_tokens", 0) or getattr(usage, "prompt_tokens", 0)
+                    ot = getattr(usage, "output_tokens", 0) or getattr(usage, "completion_tokens", 0)
+                else:
+                    it = usage.get("input_tokens", 0) or usage.get("prompt_tokens", 0)
+                    ot = usage.get("output_tokens", 0) or usage.get("completion_tokens", 0)
+                cost = (it / 1e6 * 3.0) + (ot / 1e6 * 15.0)
             total_cost += cost
             content = resp.get("content", "")
             if isinstance(content, list):