|
|
@@ -1,422 +0,0 @@
|
|
|
-"""
|
|
|
-在流程结束后写入**内容策略表格** JSON,并回写 MySQL。
|
|
|
-
|
|
|
-输出路径:{OUTPUT_DIR}/{trace_id}/process_trace.json
|
|
|
-每条策略行另按 (trace_id, aweme_id) 更新 demand_find_content_result.process_trace(TEXT)。
|
|
|
-"""
|
|
|
-
|
|
|
-from __future__ import annotations
|
|
|
-
|
|
|
-import json
|
|
|
-import logging
|
|
|
-import os
|
|
|
-from pathlib import Path
|
|
|
-from typing import Any, Dict, List, Optional, Tuple
|
|
|
-
|
|
|
-from agent.tools import tool, ToolResult
|
|
|
-from utils.tool_logging import format_tool_result_for_log, log_tool_call
|
|
|
-
|
|
|
-from db import update_process_trace_by_aweme_id
|
|
|
-
|
|
|
-_LOG_LABEL = "工具调用:exec_summary -> 写入过程 trace JSON"
|
|
|
-
|
|
|
-logger = logging.getLogger(__name__)
|
|
|
-
|
|
|
-
|
|
|
-def _output_dir_path() -> Path:
|
|
|
- # 与 store_results_mysql / output.json 目录约定一致
|
|
|
- return Path(os.getenv("OUTPUT_DIR", ".cache/output"))
|
|
|
-
|
|
|
-
|
|
|
-def _parse_payload(summary_json: str) -> Dict[str, Any]:
|
|
|
- """
|
|
|
- 解析并规范化 LLM 传入的表格数据。
|
|
|
-
|
|
|
- - 如果是数组:视为“表格行列表”,包成 {"rows": [...]}
|
|
|
- - 如果是对象:直接返回(用于后续扩展字段)
|
|
|
- """
|
|
|
- data = json.loads(summary_json)
|
|
|
- if isinstance(data, list):
|
|
|
- return {"rows": data}
|
|
|
- if not isinstance(data, dict):
|
|
|
- raise ValueError("summary_json 解析后必须是 JSON 对象或数组")
|
|
|
- return data
|
|
|
-
|
|
|
-
|
|
|
-def _split_input_features(raw: str) -> List[str]:
|
|
|
- s = (raw or "").strip()
|
|
|
- if not s:
|
|
|
- return []
|
|
|
- parts = s.replace(",", ",").split(",")
|
|
|
- out: List[str] = []
|
|
|
- for p in parts:
|
|
|
- t = p.strip()
|
|
|
- if t:
|
|
|
- out.append(t)
|
|
|
- return out
|
|
|
-
|
|
|
-
|
|
|
def _load_output_json(*, trace_id: str) -> Optional[Dict[str, Any]]:
    """Load {OUTPUT_DIR}/{trace_id}/output.json.

    Returns None when the file is missing, unreadable, or not a JSON object;
    read failures other than a missing file are logged with a stack trace.
    """
    path = _output_dir_path() / trace_id / "output.json"
    try:
        with path.open("r", encoding="utf-8") as fh:
            doc = json.load(fh)
    except FileNotFoundError:
        return None
    except Exception:
        logger.warning("读取 output.json 失败: %s", str(path), exc_info=True)
        return None
    if isinstance(doc, dict):
        return doc
    return None
|
|
|
-
|
|
|
-
|
|
|
def _extract_contents(*, trace_id: str) -> List[Dict[str, Any]]:
    """Return the finally-selected contents recorded in output.json.

    Convention: process_trace rows may only be generated/written for the
    aweme_ids that appear in output.json.contents; anything else is ignored.
    """
    doc = _load_output_json(trace_id=trace_id) or {}
    raw = doc.get("contents")
    if not isinstance(raw, list):
        return []
    return [entry for entry in raw if isinstance(entry, dict)]
|
|
|
-
|
|
|
-
|
|
|
-def _map_strategy_type(value: Any) -> str:
|
|
|
- v = str(value or "").strip()
|
|
|
- if v in ("case_based", "case", "case出发"):
|
|
|
- return "case出发"
|
|
|
- if v in ("feature_based", "feature", "特征出发"):
|
|
|
- return "特征出发"
|
|
|
- return v
|
|
|
-
|
|
|
-
|
|
|
-def _map_channel(value: Any) -> str:
|
|
|
- v = str(value or "").strip()
|
|
|
- mapping = {
|
|
|
- "search": "抖音搜索",
|
|
|
- "author": "订阅账号",
|
|
|
- "ranking": "榜单",
|
|
|
- "other": "其他",
|
|
|
- "抖音搜索": "抖音搜索",
|
|
|
- "订阅账号": "订阅账号",
|
|
|
- "榜单": "榜单",
|
|
|
- "其他": "其他",
|
|
|
- }
|
|
|
- return mapping.get(v, v)
|
|
|
-
|
|
|
-
|
|
|
-def _map_decision_basis(value: Any) -> str:
|
|
|
- v = str(value or "").strip()
|
|
|
- mapping = {
|
|
|
- "content_portrait": "内容画像匹配",
|
|
|
- "author_portrait": "作者画像匹配",
|
|
|
- "demand_filtering": "需求筛选",
|
|
|
- "other": "其他",
|
|
|
- "画像缺失": "画像缺失",
|
|
|
- "内容画像匹配": "内容画像匹配",
|
|
|
- "作者画像匹配": "作者画像匹配",
|
|
|
- "需求筛选": "需求筛选",
|
|
|
- "其他": "其他",
|
|
|
- }
|
|
|
- return mapping.get(v, v)
|
|
|
-
|
|
|
-
|
|
|
-def _infer_decision_basis_from_output_content(content: Dict[str, Any]) -> str:
|
|
|
- portrait = content.get("portrait_data") or {}
|
|
|
- source = str(portrait.get("source") or "").strip()
|
|
|
- if source == "content_like":
|
|
|
- return "内容画像匹配"
|
|
|
- if source == "account_fans":
|
|
|
- return "作者画像匹配"
|
|
|
- if source == "none":
|
|
|
- return "画像缺失"
|
|
|
- return ""
|
|
|
-
|
|
|
-
|
|
|
def _build_base_row(*, trace_id: str, content: Dict[str, Any], input_features: List[str], query: str) -> Dict[str, Any]:
    """Seed a strategy row from one output.json content entry.

    Strategy/source fields start empty (the payload may override them later);
    channel defaults to 「抖音搜索」 and decision_basis is inferred from the
    entry's portrait_data.
    """
    def _text(value: Any) -> str:
        return str(value or "").strip()

    return {
        "trace_id": trace_id,
        "aweme_id": _text(content.get("aweme_id")),
        "title": _text(content.get("title")),
        "author_nickname": _text(content.get("author_nickname")),
        "strategy_type": "",
        "from_case_aweme_id": "",
        "from_case_point": "",
        "from_feature": "",
        "search_keyword": _text(query),
        "channel": "抖音搜索",
        "decision_basis": _infer_decision_basis_from_output_content(content),
        "decision_notes": _text(content.get("reason")),
        "input_features": input_features,
    }
|
|
|
-
|
|
|
-
|
|
|
# Fixed column set (and order) for a process_trace strategy row;
# _sanitize_row projects every row onto exactly these keys.
_ROW_KEYS: Tuple[str, ...] = (
    "trace_id",
    "aweme_id",
    "title",
    "author_nickname",
    "strategy_type",
    "from_case_aweme_id",
    "from_case_point",
    "from_feature",
    "search_keyword",
    "channel",
    "decision_basis",
    "decision_notes",
    "input_features",
)
|
|
|
-
|
|
|
-
|
|
|
def _sanitize_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Project a row onto the fixed columns and coerce enum values to Chinese labels.

    ``input_features`` is normalized to ``list[str]``: a string is split on
    commas, a list is trimmed of blanks, anything else becomes ``[]``.
    """
    clean: Dict[str, Any] = {}
    for key in _ROW_KEYS:
        clean[key] = row.get(key, "")
    clean["strategy_type"] = _map_strategy_type(clean.get("strategy_type"))
    clean["channel"] = _map_channel(clean.get("channel"))
    clean["decision_basis"] = _map_decision_basis(clean.get("decision_basis"))
    features = clean.get("input_features")
    if isinstance(features, str):
        clean["input_features"] = _split_input_features(features)
    elif isinstance(features, list):
        trimmed: List[str] = []
        for item in features:
            token = str(item).strip()
            if token:
                trimmed.append(token)
        clean["input_features"] = trimmed
    else:
        clean["input_features"] = []
    return clean
|
|
|
-
|
|
|
-
|
|
|
def _normalize_payload(*, trace_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Merge LLM-provided rows onto the authoritative contents from output.json.

    The tool keeps a minimal responsibility here (filter / fill in / normalize);
    the complex reasoning happens in the skill that generated summary_json.
    Only aweme_ids present in output.json.contents yield rows, and identity
    fields (aweme_id / title / author_nickname) always come from output.json.
    Returns ``{"rows": [...]}`` with rows sanitized and stably ordered.
    """
    raw_rows = payload.get("rows")
    rows_in_payload: List[Dict[str, Any]] = []
    if isinstance(raw_rows, list):
        for item in raw_rows:
            if isinstance(item, dict):
                rows_in_payload.append(item)

    output_json = _load_output_json(trace_id=trace_id) or {}
    query = str(output_json.get("query") or "").strip()
    input_features = _split_input_features(query)
    contents = _extract_contents(trace_id=trace_id)
    contents_by_aweme_id: Dict[str, Dict[str, Any]] = {
        str(c.get("aweme_id") or "").strip(): c for c in contents if str(c.get("aweme_id") or "").strip()
    }

    # Index payload rows by aweme_id first (last duplicate wins).
    payload_by_aweme_id: Dict[str, Dict[str, Any]] = {}
    for r in rows_in_payload:
        aweme_id = str(r.get("aweme_id") or r.get("awemeId") or "").strip()
        if not aweme_id:
            continue
        payload_by_aweme_id[aweme_id] = dict(r)

    # The payload may only override "strategy / source / explanation" fields,
    # never the identity fields (title/author etc.) from output.json.contents.
    allowed_payload_keys: set[str] = {
        "strategy_type",
        "from_case_aweme_id",
        "from_case_point",
        "from_feature",
        "search_keyword",
        "channel",
        "decision_basis",
        "decision_notes",
        "input_features",
    }

    # Tolerate common aliases / camelCase keys in the payload (model output
    # is not always stable; try not to drop information).
    alias_map: Dict[str, Tuple[str, ...]] = {
        "strategy_type": ("strategy_type", "strategyType"),
        "from_case_aweme_id": ("from_case_aweme_id", "fromCaseAwemeId", "case_aweme_id", "caseAwemeId"),
        "from_case_point": ("from_case_point", "fromCasePoint", "case_point", "casePoint"),
        "from_feature": ("from_feature", "fromFeature", "feature", "from_feature_name"),
        "search_keyword": ("search_keyword", "searchKeyword", "keyword"),
        "channel": ("channel", "source_channel", "sourceChannel", "source"),
        "decision_basis": ("decision_basis", "decisionBasis"),
        "decision_notes": ("decision_notes", "decisionNotes", "notes"),
        "input_features": ("input_features", "inputFeatures"),
    }

    def _pick(provided: Dict[str, Any], key: str) -> Any:
        # First alias present wins; None when no alias key exists at all.
        for k in alias_map.get(key, (key,)):
            if k in provided:
                return provided.get(k)
        return None

    normalized: List[Dict[str, Any]] = []
    for aweme_id, content in contents_by_aweme_id.items():
        base = _build_base_row(trace_id=trace_id, content=content, input_features=input_features, query=query)
        provided = payload_by_aweme_id.get(aweme_id) or {}

        merged = dict(base)
        # Merge only the fields the payload is allowed to override.
        for k in allowed_payload_keys:
            v = _pick(provided, k)
            if v is not None:
                merged[k] = v

        # Identity fields are forced from output.json.contents
        # (ignored even if the payload supplied them).
        merged["aweme_id"] = str(content.get("aweme_id") or "").strip()
        merged["title"] = str(content.get("title") or "").strip()
        merged["author_nickname"] = str(content.get("author_nickname") or "").strip()

        # Backfill input_features from the split query when missing/empty.
        if "input_features" not in merged or not merged.get("input_features"):
            merged["input_features"] = input_features

        normalized.append(_sanitize_row(merged))

    # Stable ordering: by rank (when present and positive), then aweme_id.
    def _sort_key(r: Dict[str, Any]) -> Tuple[int, str]:
        c = contents_by_aweme_id.get(str(r.get("aweme_id") or "").strip()) or {}
        try:
            rank = int(c.get("rank") or 0)
        except Exception:
            rank = 0
        # Rows without a usable rank sort last (sentinel 10**9).
        return (rank if rank > 0 else 10**9, str(r.get("aweme_id") or ""))

    normalized.sort(key=_sort_key)
    return {"rows": normalized}
|
|
|
-
|
|
|
-
|
|
|
def _write_process_trace(*, trace_id: str, payload: Dict[str, Any]) -> Path:
    """Persist ``{"rows": [...]}`` to {OUTPUT_DIR}/{trace_id}/process_trace.json.

    The output schema is deliberately narrowed to just "rows"; any other
    payload keys are dropped. Returns the written file path.
    """
    target_dir = _output_dir_path() / trace_id
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / "process_trace.json"
    document = {"rows": payload.get("rows") or []}
    with target.open("w", encoding="utf-8") as fh:
        json.dump(document, fh, ensure_ascii=False, indent=2)
    return target
|
|
|
-
|
|
|
-
|
|
|
def _sync_process_trace_rows_to_mysql(*, trace_id: str, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Serialize each normalized strategy row to JSON text and update MySQL.

    Updates process_trace and channel by (trace_id, aweme_id). channel is
    written uniformly as 「抖音」 (distinct from the 「抖音搜索」 channel value
    inside process_trace.json). Rows without an aweme_id, and rows where the
    UPDATE matched nothing (rowcount 0), count as skipped; per-row failures
    are logged and collected into "errors" without aborting the batch.
    """
    stats: Dict[str, Any] = {"updated": 0, "skipped": 0, "errors": []}
    for row in rows:
        aweme_id = str(row.get("aweme_id") or "").strip()
        if not aweme_id:
            stats["skipped"] += 1
            continue
        serialized = json.dumps(row, ensure_ascii=False)
        try:
            affected = update_process_trace_by_aweme_id(
                trace_id=trace_id,
                aweme_id=aweme_id,
                process_trace_text=serialized,
                channel="抖音",
            )
            stats["updated" if affected > 0 else "skipped"] += 1
        except Exception as exc:
            logger.warning(
                "process_trace 回写 MySQL 失败 trace_id=%s aweme_id=%s: %s",
                trace_id,
                aweme_id,
                exc,
                exc_info=True,
            )
            stats["errors"].append(f"{aweme_id}: {exc}")
    return stats
|
|
|
-
|
|
|
-
|
|
|
@tool(
    description=(
        "在**全部流程执行完毕之后**调用:把每条最终入选内容的「选择策略」整理成表格 JSON,"
        "写入当前任务的 output 目录下的 process_trace.json,便于后续复盘;"
        "并将每一行策略 JSON 序列化为文本,按 trace_id + aweme_id 回写到 "
        "demand_find_content_result.process_trace,并同步将 channel 字段设为「抖音」。"
        "参数 summary_json 为 JSON 字符串,可以是数组或对象(对象需包含 rows)。"
        "可选参数 log_path/log_text 用于传入本次运行日志(便于复盘留档/未来扩展)。"
    ),
)
async def exec_summary(
    trace_id: str,
    summary_json: str,
    log_path: str = "",
    log_text: str = "",
) -> ToolResult:
    """Tool entry point: write process_trace.json for *trace_id* and sync rows to MySQL.

    Flow: validate trace_id -> parse summary_json -> normalize rows against
    output.json -> write process_trace.json -> best-effort per-row MySQL
    write-back. File-write failure aborts with an error result; MySQL failures
    are reported via metadata only. log_path/log_text are recorded for audit
    but not otherwise used here.
    """
    # Params echoed into the tool-call log; bulky fields are masked.
    call_params = {
        "trace_id": trace_id,
        "summary_json": "<json>",
        "log_path": (log_path or "").strip(),
        "log_text": "<text>",
    }
    tid = (trace_id or "").strip()
    if not tid:
        err = ToolResult(
            title="过程摘要",
            output="trace_id 不能为空",
            metadata={"ok": False, "error": "empty trace_id"},
        )
        log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
        return err

    try:
        payload = _parse_payload(summary_json)
    except json.JSONDecodeError as e:
        # summary_json was not valid JSON at all.
        err = ToolResult(
            title="过程摘要",
            output=f"summary_json 不是合法 JSON: {e}",
            metadata={"ok": False, "error": str(e)},
        )
        log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
        return err
    except ValueError as e:
        # Valid JSON, but neither an object nor an array.
        err = ToolResult(
            title="过程摘要",
            output=str(e),
            metadata={"ok": False, "error": str(e)},
        )
        log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
        return err

    payload = _normalize_payload(trace_id=tid, payload=payload)

    try:
        path = _write_process_trace(trace_id=tid, payload=payload)
    except OSError as e:
        msg = f"写入 process_trace.json 失败: {e}"
        logger.error(msg, exc_info=True)
        err = ToolResult(title="过程摘要", output=msg, metadata={"ok": False, "error": str(e)})
        log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
        return err

    rows = payload.get("rows") or []
    mysql_meta: Dict[str, Any]
    try:
        # Best effort: MySQL failure never fails the tool call, only metadata.
        mysql_meta = _sync_process_trace_rows_to_mysql(trace_id=tid, rows=rows if isinstance(rows, list) else [])
    except Exception as e:
        logger.warning("process_trace 批量回写 MySQL 异常: %s", e, exc_info=True)
        mysql_meta = {"updated": 0, "skipped": 0, "errors": [str(e)]}

    out = ToolResult(
        title="过程摘要",
        output=f"已写入 {path};MySQL process_trace 已更新 {mysql_meta.get('updated', 0)} 条",
        metadata={
            "ok": True,
            "trace_id": tid,
            "path": str(path),
            "log_path": (log_path or "").strip(),
            "log_text_len": len((log_text or "").strip()),
            "mysql_process_trace_updated": mysql_meta.get("updated", 0),
            "mysql_process_trace_skipped": mysql_meta.get("skipped", 0),
            "mysql_process_trace_errors": mysql_meta.get("errors") or [],
        },
    )
    # NOTE(review): the success log passes only {"trace_id": tid}, unlike the
    # error paths which pass call_params — confirm this asymmetry is intended.
    log_tool_call(_LOG_LABEL, {"trace_id": tid}, format_tool_result_for_log(out))
    return out
|