Procházet zdrojové kódy

增加dt和channel字段

jihuaqiang před 1 dnem
rodič
revize
e2189b8d4b

+ 2 - 0
examples/content_finder/db/__init__.py

@@ -19,6 +19,7 @@ from .schedule import (
     update_task_on_complete,
 )
 from .store_results import (
+    fetch_demand_content_dt,
     upsert_good_authors,
     insert_contents,
     update_content_plan_ids,
@@ -37,6 +38,7 @@ __all__ = [
     "fetch_trace_ids_created_after",
     "update_task_status",
     "update_task_on_complete",
+    "fetch_demand_content_dt",
     "upsert_good_authors",
     "insert_contents",
     "update_content_plan_ids",

+ 32 - 6
examples/content_finder/db/store_results.py

@@ -89,15 +89,29 @@ def upsert_good_authors(
         return rows
 
 
+def fetch_demand_content_dt(conn, demand_content_id: int) -> Optional[Any]:
+    """按 demand_content.id 查询 dt(与 schedule 约定一致,多为 YYYYMMDD 整数);要求游标按 dict 返回行(如 DictCursor),否则 row.get 会失败。"""
+    sql = "SELECT dt FROM demand_content WHERE id = %s LIMIT 1"
+    with conn.cursor() as cur:
+        cur.execute(sql, (demand_content_id,))
+        row = cur.fetchone()
+    if not row:
+        return None
+    return row.get("dt")
+
+
 def insert_contents(
     conn,
     trace_id: str,
     query: str,
     demand_content_id: int,
     contents: List[Dict[str, Any]],
+    dt: Optional[Any] = None,
 ) -> int:
     """
     将 contents 列表写入 demand_find_content_result 表。
+
+    dt 来自 demand_content.dt,与 demand_content_id 对应;未查到时可传 None。
     """
     if not contents:
         return 0
@@ -107,12 +121,12 @@ def insert_contents(
       trace_id, query, rank_no, aweme_id, video_url, title, author_name, author_id, author_link,
       digg_count, comment_count, share_count,
       portrait_source, elderly_ratio, elderly_tgi, recommendation_reason,
-      demand_content_id
+      demand_content_id, dt
     ) VALUES (
       %s, %s, %s, %s, %s, %s, %s, %s, %s,
       %s, %s, %s,
       %s, %s, %s, %s,
-      %s
+      %s, %s
     )
     """
     with conn.cursor() as cur:
@@ -146,6 +160,7 @@ def insert_contents(
                     str(age_50_plus_tgi) if age_50_plus_tgi != "" else "",
                     item.get("reason") or "",
                     demand_content_id,
+                    dt,
                 ),
             )
             rows += cur.rowcount
@@ -228,30 +243,41 @@ def update_web_html_url(trace_id: str, web_html_url: str) -> int:
         conn.close()
 
 
-def update_process_trace_by_aweme_id(*, trace_id: str, aweme_id: str, process_trace_text: str) -> int:
+def update_process_trace_by_aweme_id(
+    *,
+    trace_id: str,
+    aweme_id: str,
+    process_trace_text: str,
+    channel: str = "抖音",
+) -> int:
     """
-    根据 (trace_id, aweme_id) 回写 demand_find_content_result.process_trace(TEXT)。
+    根据 (trace_id, aweme_id) 回写 demand_find_content_result.process_trace(TEXT)与 channel。
 
     约定:
     - trace_id 为 output 子目录名
     - aweme_id 为内容唯一 id(表中 demand_find_content_result.aweme_id)
     - process_trace_text 为 JSON 序列化后的字符串(或原始文本)
+    - channel 默认「抖音」;当前业务仅抖音搜索场景,后续可按行区分时再传入
     """
     t = (trace_id or "").strip()
     a = (aweme_id or "").strip()
     text = (process_trace_text or "").strip()
+    ch = (channel or "").strip()
     if not t or not a or not text:
         return 0
+    if not ch:
+        ch = "抖音"
 
     sql = """
     UPDATE demand_find_content_result
-    SET process_trace = %s
+    SET process_trace = %s,
+        channel = %s
     WHERE trace_id = %s AND aweme_id = %s
     """
     conn = get_connection()
     try:
         with conn.cursor() as cur:
-            cur.execute(sql, (text, t, a))
+            cur.execute(sql, (text, ch, t, a))
             return cur.rowcount
     finally:
         conn.close()

+ 58 - 3
examples/content_finder/tools/exec_summary.py

@@ -1,7 +1,8 @@
 """
-在流程结束后写入**内容策略表格** JSON。
+在流程结束后写入**内容策略表格** JSON,并回写 MySQL。
 
 输出路径:{OUTPUT_DIR}/{trace_id}/process_trace.json
+每条策略行另按 (trace_id, aweme_id) 更新 demand_find_content_result.process_trace(TEXT)。
 """
 
 from __future__ import annotations
@@ -15,6 +16,8 @@ from typing import Any, Dict, List, Optional, Tuple
 from agent.tools import tool, ToolResult
 from utils.tool_logging import format_tool_result_for_log, log_tool_call
 
+from db import update_process_trace_by_aweme_id
+
 _LOG_LABEL = "工具调用:exec_summary -> 写入过程 trace JSON"
 
 logger = logging.getLogger(__name__)
@@ -292,10 +295,51 @@ def _write_process_trace(*, trace_id: str, payload: Dict[str, Any]) -> Path:
     return path
 
 
+def _sync_process_trace_rows_to_mysql(*, trace_id: str, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
+    """
+    将每条归一化后的策略行序列化为 JSON 文本,按 (trace_id, aweme_id) 更新 process_trace 与 channel。
+
+    channel 当前统一为「抖音」(与 process_trace.json 内 channel「抖音搜索」区分)。
+    表中无匹配行时 rowcount 为 0,计入 skipped。
+    """
+    updated = 0
+    skipped = 0
+    errors: List[str] = []
+    for row in rows:
+        aweme_id = str(row.get("aweme_id") or "").strip()
+        if not aweme_id:
+            skipped += 1
+            continue
+        text = json.dumps(row, ensure_ascii=False)
+        try:
+            n = update_process_trace_by_aweme_id(
+                trace_id=trace_id,
+                aweme_id=aweme_id,
+                process_trace_text=text,
+                channel="抖音",
+            )
+            if n > 0:
+                updated += 1
+            else:
+                skipped += 1
+        except Exception as e:
+            logger.warning(
+                "process_trace 回写 MySQL 失败 trace_id=%s aweme_id=%s: %s",
+                trace_id,
+                aweme_id,
+                e,
+                exc_info=True,
+            )
+            errors.append(f"{aweme_id}: {e}")
+    return {"updated": updated, "skipped": skipped, "errors": errors}
+
+
 @tool(
     description=(
         "在**全部流程执行完毕之后**调用:把每条最终入选内容的「选择策略」整理成表格 JSON,"
-        "写入当前任务的 output 目录下的 process_trace.json,便于后续复盘。"
+        "写入当前任务的 output 目录下的 process_trace.json,便于后续复盘;"
+        "并将每一行策略 JSON 序列化为文本,按 trace_id + aweme_id 回写到 "
+        "demand_find_content_result.process_trace,并同步将 channel 字段设为「抖音」。"
         "参数 summary_json 为 JSON 字符串,可以是数组或对象(对象需包含 rows)。"
         "可选参数 log_path/log_text 用于传入本次运行日志(便于复盘留档/未来扩展)。"
     ),
@@ -352,15 +396,26 @@ async def exec_summary(
         log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
         return err
 
+    rows = payload.get("rows") or []
+    mysql_meta: Dict[str, Any]
+    try:
+        mysql_meta = _sync_process_trace_rows_to_mysql(trace_id=tid, rows=rows if isinstance(rows, list) else [])
+    except Exception as e:
+        logger.warning("process_trace 批量回写 MySQL 异常: %s", e, exc_info=True)
+        mysql_meta = {"updated": 0, "skipped": 0, "errors": [str(e)]}
+
     out = ToolResult(
         title="过程摘要",
-        output=f"已写入 {path}",
+        output=f"已写入 {path};MySQL process_trace 已更新 {mysql_meta.get('updated', 0)} 条",
         metadata={
             "ok": True,
             "trace_id": tid,
             "path": str(path),
             "log_path": (log_path or "").strip(),
             "log_text_len": len((log_text or "").strip()),
+            "mysql_process_trace_updated": mysql_meta.get("updated", 0),
+            "mysql_process_trace_skipped": mysql_meta.get("skipped", 0),
+            "mysql_process_trace_errors": mysql_meta.get("errors") or [],
         },
     )
     log_tool_call(_LOG_LABEL, {"trace_id": tid}, format_tool_result_for_log(out))

+ 14 - 3
examples/content_finder/tools/store_results_mysql.py

@@ -16,7 +16,7 @@ from typing import Any, Dict
 from agent.tools import tool, ToolResult
 from utils.tool_logging import format_tool_result_for_log, log_tool_call
 
-from db import get_connection, insert_contents, upsert_good_authors
+from db import fetch_demand_content_dt, get_connection, insert_contents, upsert_good_authors
 
 _LOG_LABEL = "工具调用:store_results_mysql -> 推荐结果写入MySQL"
 
@@ -40,6 +40,7 @@ async def store_results_mysql(trace_id: str) -> ToolResult:
     """
     根据 trace_id 读取 output.json,并写入 MySQL。
     demand_content_id 从 output 的 demand_id 字段获取,需在 output_schema 中输出。
+    写入内容结果时按 demand_content_id 查询 demand_content.dt,并写入每条 demand_find_content_result。
     """
     call_params = {"trace_id": trace_id}
     try:
@@ -63,7 +64,6 @@ async def store_results_mysql(trace_id: str) -> ToolResult:
         err = ToolResult(title="存储推荐结果", output=msg, metadata={"ok": False, "error": msg})
         log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
         return err
-
     conn = None
     try:
         conn = get_connection()
@@ -71,8 +71,17 @@ async def store_results_mysql(trace_id: str) -> ToolResult:
         contents = data.get("contents") or []
         query = data.get("query") or ""
 
+        dc_dt = fetch_demand_content_dt(conn, demand_content_id)
+        if dc_dt is None:
+            logger.warning(
+                "demand_content 无对应记录或 dt 为空: id=%s,demand_find_content_result.dt 将写入 NULL",
+                demand_content_id,
+            )
+
         authors_rows = upsert_good_authors(conn, trace_id, good_block)
-        contents_rows = insert_contents(conn, trace_id, query, demand_content_id, contents)
+        contents_rows = insert_contents(
+            conn, trace_id, query, demand_content_id, contents, dt=dc_dt
+        )
 
         output = (
             f"MySQL 写入完成:demand_find_author 影响行数={authors_rows}, "
@@ -85,6 +94,8 @@ async def store_results_mysql(trace_id: str) -> ToolResult:
             metadata={
                 "ok": True,
                 "trace_id": trace_id,
+                "demand_content_id": demand_content_id,
+                "demand_content_dt": dc_dt,
                 "good_authors_affected": authors_rows,
                 "contents_inserted": contents_rows,
             },