Explorar el Código

feat: 增加token消耗量统计和入库

jihuaqiang hace 1 mes
padre
commit
71c1ded5ac

+ 4 - 0
examples/content_finder/db/__init__.py

@@ -13,6 +13,8 @@ from .schedule import (
     get_daily_unprocessed_pool,
     get_daily_unprocessed_pool,
     get_first_running_task,
     get_first_running_task,
     get_latest_demand_task_oprate_is_open,
     get_latest_demand_task_oprate_is_open,
+    get_latest_day_limit_coast,
+    get_total_token_coast_between,
     get_one_today_unprocessed_demand,
     get_one_today_unprocessed_demand,
     create_task_record,
     create_task_record,
     fetch_trace_ids_created_after,
     fetch_trace_ids_created_after,
@@ -36,6 +38,8 @@ __all__ = [
     "get_daily_unprocessed_pool",
     "get_daily_unprocessed_pool",
     "get_first_running_task",
     "get_first_running_task",
     "get_latest_demand_task_oprate_is_open",
     "get_latest_demand_task_oprate_is_open",
+    "get_latest_day_limit_coast",
+    "get_total_token_coast_between",
     "get_one_today_unprocessed_demand",
     "get_one_today_unprocessed_demand",
     "create_task_record",
     "create_task_record",
     "fetch_trace_ids_created_after",
     "fetch_trace_ids_created_after",

+ 89 - 6
examples/content_finder/db/schedule.py

@@ -7,6 +7,7 @@ demand_find_task: 执行记录表,通过 demand_content_id 关联
 
 
 import logging
 import logging
 from datetime import datetime
 from datetime import datetime
+from decimal import Decimal
 from typing import Any, Dict, List, Optional
 from typing import Any, Dict, List, Optional
 
 
 from .connection import get_connection
 from .connection import get_connection
@@ -127,6 +128,69 @@ def get_latest_demand_task_oprate_is_open() -> Optional[int]:
             conn.close()
             conn.close()
 
 
 
 
def get_latest_day_limit_coast() -> Optional[Decimal]:
    """
    Return ``day_limit_coast`` from the newest ``demand_task_oprate`` row.

    "Newest" is the first row when ordering by ``update_time DESC, id DESC``.

    Returns:
        The stored value converted to ``Decimal`` through its string form
        (no rounding is applied here), or ``None`` when the table has no
        rows or the column is NULL.

    Raises:
        Exception: any connection/query error is logged with its traceback
            and re-raised unchanged.
    """
    sql = """
    SELECT day_limit_coast
    FROM demand_task_oprate
    ORDER BY update_time DESC, id DESC
    LIMIT 1
    """
    conn = None
    try:
        conn = get_connection()
        with conn.cursor() as cur:
            cur.execute(sql)
            row = cur.fetchone()
        # The row is read by column name, so the cursor is expected to
        # return mapping-style rows.
        limit = row.get("day_limit_coast") if row else None
        # Going through str() avoids binary-float artifacts should the
        # driver hand back a float instead of a Decimal.
        return None if limit is None else Decimal(str(limit))
    except Exception as e:
        # Lazy %-args, matching the logging style used by
        # update_task_on_complete in this module.
        logger.error("get_latest_day_limit_coast 失败: %s", e, exc_info=True)
        raise
    finally:
        if conn:
            conn.close()
+
+
def get_total_token_coast_between(start_time: datetime, end_time: datetime) -> Decimal:
    """
    Sum ``token_coast`` over ``demand_find_task`` rows created in a time window.

    The window is half-open on ``created_at``:
    ``start_time <= created_at < end_time``.

    Args:
        start_time: inclusive lower bound for ``created_at``.
        end_time: exclusive upper bound for ``created_at``.

    Returns:
        The total as a ``Decimal``; ``Decimal("0")`` when nothing matches
        (the query itself coalesces an empty sum to 0).

    Raises:
        Exception: any connection/query error is logged with its traceback
            and re-raised unchanged.
    """
    sql = """
    SELECT COALESCE(SUM(token_coast), 0) AS total_coast
    FROM demand_find_task
    WHERE created_at >= %s
      AND created_at < %s
    """
    conn = None
    try:
        conn = get_connection()
        with conn.cursor() as cur:
            cur.execute(sql, (start_time, end_time))
            row = cur.fetchone()
        # An aggregate query normally yields exactly one row; a missing row
        # or NULL total is still treated as zero spend.
        total = row.get("total_coast") if row else None
        return Decimal("0") if total is None else Decimal(str(total))
    except Exception as e:
        # Lazy %-args, matching the logging style used by
        # update_task_on_complete in this module.
        logger.error("get_total_token_coast_between 失败: %s", e, exc_info=True)
        raise
    finally:
        if conn:
            conn.close()
+
+
 def get_one_today_unprocessed_demand(*, dt: int) -> Optional[Dict[str, Any]]:
 def get_one_today_unprocessed_demand(*, dt: int) -> Optional[Dict[str, Any]]:
     """
     """
     从 demand_content 中取「当天 dt」且尚未在 demand_find_task 中出现过的 1 条需求。
     从 demand_content 中取「当天 dt」且尚未在 demand_find_task 中出现过的 1 条需求。
@@ -250,24 +314,43 @@ def create_task_record(demand_content_id: int, trace_id: str = "", status: int =
             conn.close()
             conn.close()
 
 
 
 
-def update_task_on_complete(demand_content_id: int, trace_id: str, status: int) -> None:
+def update_task_on_complete(
+    demand_content_id: int,
+    trace_id: str,
+    status: int,
+    token_coast: Optional[Decimal] = None,
+) -> None:
     """
     """
     任务完成后更新 trace_id 和 status。
     任务完成后更新 trace_id 和 status。
     匹配 trace_id 为空字符串的记录(初始创建时的占位)。
     匹配 trace_id 为空字符串的记录(初始创建时的占位)。
     """
     """
     sql = """
     sql = """
     UPDATE demand_find_task
     UPDATE demand_find_task
-    SET trace_id = %s, status = %s
-    WHERE demand_content_id = %s AND trace_id = ''
+    SET trace_id = %s,
+        status = %s
     """
     """
+    params: list[Any] = [trace_id, status]
+
+    if token_coast is not None:
+        sql += ", token_coast = %s\n"
+        params.append(token_coast)
+
+    sql += "WHERE demand_content_id = %s AND trace_id = ''"
+    params.append(demand_content_id)
     conn = None
     conn = None
     try:
     try:
         conn = get_connection()
         conn = get_connection()
         with conn.cursor() as cur:
         with conn.cursor() as cur:
-            cur.execute(sql, (trace_id, status, demand_content_id))
-        logger.info(f"更新任务完成: demand_content_id={demand_content_id}, trace_id={trace_id}, status={status}")
+            cur.execute(sql, tuple(params))
+        logger.info(
+            "更新任务完成: demand_content_id=%s, trace_id=%s, status=%s, token_coast=%s",
+            demand_content_id,
+            trace_id,
+            status,
+            token_coast,
+        )
     except Exception as e:
     except Exception as e:
-        logger.error(f"update_task_on_complete 失败: {e}", exc_info=True)
+        logger.error("update_task_on_complete 失败: %s", e, exc_info=True)
         raise
         raise
     finally:
     finally:
         if conn:
         if conn:

+ 70 - 5
examples/content_finder/server.py

@@ -11,13 +11,15 @@
 """
 """
 
 
 import asyncio
 import asyncio
+import json
 import logging
 import logging
 import os
 import os
+import sys
 import uuid
 import uuid
 from datetime import datetime
 from datetime import datetime
+from decimal import Decimal, ROUND_HALF_UP
 from pathlib import Path
 from pathlib import Path
 from typing import Optional
 from typing import Optional
-import sys
 
 
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
 
@@ -38,7 +40,13 @@ from db import (
     update_task_status,
     update_task_status,
     update_task_on_complete,
     update_task_on_complete,
 )
 )
-from db.schedule import STATUS_RUNNING, STATUS_SUCCESS, STATUS_FAILED
+from db.schedule import (
+    STATUS_RUNNING,
+    STATUS_SUCCESS,
+    STATUS_FAILED,
+    get_latest_day_limit_coast,
+    get_total_token_coast_between,
+)
 
 
 # 配置日志
 # 配置日志
 log_dir = Path(__file__).parent / '.cache'
 log_dir = Path(__file__).parent / '.cache'
@@ -106,12 +114,53 @@ class TaskResponse(BaseModel):
 
 
 # ============ 核心函数 ============
 # ============ 核心函数 ============
 
 
+
def _load_token_coast_from_meta(trace_id: str) -> Optional[Decimal]:
    """
    Read a task's token cost from ``<TRACE_DIR>/<trace_id>/meta.json``.

    ``TRACE_DIR`` comes from the environment and defaults to
    ``.cache/traces``. The ``total_cost`` key is preferred; ``total_coast``
    is accepted as a fallback spelling.

    Args:
        trace_id: name of the trace directory that holds ``meta.json``.

    Returns:
        The cost rounded half-up to two decimal places, or ``None`` when the
        file is missing or unreadable, is not a JSON object, lacks both
        keys, or holds a value that is not a finite number. This function
        never raises for bad file content, so a malformed ``meta.json``
        cannot abort the caller's status update.
    """
    trace_dir = Path(os.getenv("TRACE_DIR", ".cache/traces"))
    meta_path = trace_dir / trace_id / "meta.json"
    if not meta_path.exists():
        logger.warning("未找到 meta.json,trace_id=%s, path=%s", trace_id, meta_path)
        return None

    try:
        with meta_path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        logger.warning("读取 meta.json 失败: trace_id=%s, error=%s", trace_id, e)
        return None

    # Valid JSON is not necessarily an object; a list or scalar has no .get().
    if not isinstance(data, dict):
        logger.warning(
            "meta.json 内容不是 JSON 对象: trace_id=%s, type=%s",
            trace_id,
            type(data).__name__,
        )
        return None

    raw_cost = data.get("total_cost")
    if raw_cost is None:
        raw_cost = data.get("total_coast")
    if raw_cost is None:
        logger.warning("meta.json 中未找到 total_cost/total_coast 字段: trace_id=%s", trace_id)
        return None

    try:
        cost = Decimal(str(raw_cost))
        # quantize() passes a quiet NaN through silently, so reject
        # non-finite values explicitly instead of returning them.
        if not cost.is_finite():
            raise ValueError("token cost is not a finite number")
        return cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    except Exception as e:
        logger.warning("解析 token 费用失败: trace_id=%s, raw=%s, error=%s", trace_id, raw_cost, e)
        return None
+
+
 def _update_scheduled_task_complete(demand_id: int, trace_id: str, status: int) -> None:
 def _update_scheduled_task_complete(demand_id: int, trace_id: str, status: int) -> None:
-    """定时任务完成时更新 trace_id 和 status,静默处理异常"""
+    """
+    定时任务完成时更新 trace_id、status 以及 token_coast(若能从 meta.json 成功解析)。
+    静默处理异常,不影响整体调度流程。
+    """
     try:
     try:
-        update_task_on_complete(demand_id, trace_id, status)
+        token_coast: Optional[Decimal] = None
+        if trace_id:
+            token_coast = _load_token_coast_from_meta(trace_id)
+
+        update_task_on_complete(demand_id, trace_id, status, token_coast)
     except Exception as e:
     except Exception as e:
-        logger.warning(f"更新任务状态失败: {e}")
+        logger.warning("更新任务状态或 token_coast 失败: %s", e)
 
 
 
 
 async def execute_task(
 async def execute_task(
@@ -223,6 +272,22 @@ async def scheduled_tick():
         logger.info("定时任务跳过:demand_task_oprate 最新记录 is_open=0")
         logger.info("定时任务跳过:demand_task_oprate 最新记录 is_open=0")
         return
         return
 
 
+    # 检查当日 token_coast 是否已超出日预算(以 SCHEDULER_TZ 当地时间为准)
+    day_limit_coast = get_latest_day_limit_coast()
+    if day_limit_coast is not None:
+        now_local = datetime.now(SCHEDULER_TZ)
+        day_start = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
+        day_end = day_start.replace(day=day_start.day + 1)
+
+        used_today = get_total_token_coast_between(day_start, day_end)
+        if used_today >= day_limit_coast:
+            logger.info(
+                "定时任务跳过:当日 token_coast 已达上限,used=%s, limit=%s",
+                used_today,
+                day_limit_coast,
+            )
+            return
+
     # 无空闲并发槽则不派发;保持 tick 很快返回,避免阻塞调度器。
     # 无空闲并发槽则不派发;保持 tick 很快返回,避免阻塞调度器。
     if task_semaphore._value <= 0:
     if task_semaphore._value <= 0:
         logger.info("定时任务跳过:无空闲并发槽")
         logger.info("定时任务跳过:无空闲并发槽")