exec_summary.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. """
  2. 在流程结束后写入**内容策略表格** JSON,并回写 MySQL。
  3. 输出路径:{OUTPUT_DIR}/{trace_id}/process_trace.json
  4. 每条策略行另按 (trace_id, aweme_id) 更新 demand_find_content_result.process_trace(TEXT)。
  5. """
  6. from __future__ import annotations
  7. import json
  8. import logging
  9. import os
  10. from pathlib import Path
  11. from typing import Any, Dict, List, Optional, Tuple
  12. from agent.tools import tool, ToolResult
  13. from utils.tool_logging import format_tool_result_for_log, log_tool_call
  14. from db import update_process_trace_by_aweme_id
# Label prefixed to every log_tool_call entry emitted by this tool.
_LOG_LABEL = "工具调用:exec_summary -> 写入过程 trace JSON"
logger = logging.getLogger(__name__)
  17. def _output_dir_path() -> Path:
  18. # 与 store_results_mysql / output.json 目录约定一致
  19. return Path(os.getenv("OUTPUT_DIR", ".cache/output"))
  20. def _parse_payload(summary_json: str) -> Dict[str, Any]:
  21. """
  22. 解析并规范化 LLM 传入的表格数据。
  23. - 如果是数组:视为“表格行列表”,包成 {"rows": [...]}
  24. - 如果是对象:直接返回(用于后续扩展字段)
  25. """
  26. data = json.loads(summary_json)
  27. if isinstance(data, list):
  28. return {"rows": data}
  29. if not isinstance(data, dict):
  30. raise ValueError("summary_json 解析后必须是 JSON 对象或数组")
  31. return data
  32. def _split_input_features(raw: str) -> List[str]:
  33. s = (raw or "").strip()
  34. if not s:
  35. return []
  36. parts = s.replace(",", ",").split(",")
  37. out: List[str] = []
  38. for p in parts:
  39. t = p.strip()
  40. if t:
  41. out.append(t)
  42. return out
  43. def _load_output_json(*, trace_id: str) -> Optional[Dict[str, Any]]:
  44. path = _output_dir_path() / trace_id / "output.json"
  45. try:
  46. with path.open("r", encoding="utf-8") as f:
  47. data = json.load(f)
  48. except FileNotFoundError:
  49. return None
  50. except Exception:
  51. logger.warning("读取 output.json 失败: %s", str(path), exc_info=True)
  52. return None
  53. return data if isinstance(data, dict) else None
  54. def _extract_contents(*, trace_id: str) -> List[Dict[str, Any]]:
  55. """
  56. 从 output.json 读取最终入选 contents。
  57. 约定:
  58. - 只允许对 output.json.contents 内的 aweme_id 生成/写入 process_trace rows
  59. """
  60. output_json = _load_output_json(trace_id=trace_id) or {}
  61. contents = output_json.get("contents")
  62. if not isinstance(contents, list):
  63. return []
  64. out: List[Dict[str, Any]] = []
  65. for item in contents:
  66. if isinstance(item, dict):
  67. out.append(item)
  68. return out
  69. def _map_strategy_type(value: Any) -> str:
  70. v = str(value or "").strip()
  71. if v in ("case_based", "case", "case出发"):
  72. return "case出发"
  73. if v in ("feature_based", "feature", "特征出发"):
  74. return "特征出发"
  75. return v
  76. def _map_channel(value: Any) -> str:
  77. v = str(value or "").strip()
  78. mapping = {
  79. "search": "抖音搜索",
  80. "author": "订阅账号",
  81. "ranking": "榜单",
  82. "other": "其他",
  83. "抖音搜索": "抖音搜索",
  84. "订阅账号": "订阅账号",
  85. "榜单": "榜单",
  86. "其他": "其他",
  87. }
  88. return mapping.get(v, v)
  89. def _map_decision_basis(value: Any) -> str:
  90. v = str(value or "").strip()
  91. mapping = {
  92. "content_portrait": "内容画像匹配",
  93. "author_portrait": "作者画像匹配",
  94. "demand_filtering": "需求筛选",
  95. "other": "其他",
  96. "画像缺失": "画像缺失",
  97. "内容画像匹配": "内容画像匹配",
  98. "作者画像匹配": "作者画像匹配",
  99. "需求筛选": "需求筛选",
  100. "其他": "其他",
  101. }
  102. return mapping.get(v, v)
  103. def _infer_decision_basis_from_output_content(content: Dict[str, Any]) -> str:
  104. portrait = content.get("portrait_data") or {}
  105. source = str(portrait.get("source") or "").strip()
  106. if source == "content_like":
  107. return "内容画像匹配"
  108. if source == "account_fans":
  109. return "作者画像匹配"
  110. if source == "none":
  111. return "画像缺失"
  112. return ""
  113. def _build_base_row(*, trace_id: str, content: Dict[str, Any], input_features: List[str], query: str) -> Dict[str, Any]:
  114. return {
  115. "trace_id": trace_id,
  116. "aweme_id": str(content.get("aweme_id") or "").strip(),
  117. "title": str(content.get("title") or "").strip(),
  118. "author_nickname": str(content.get("author_nickname") or "").strip(),
  119. "strategy_type": "",
  120. "from_case_aweme_id": "",
  121. "from_case_point": "",
  122. "from_feature": "",
  123. "search_keyword": str(query or "").strip(),
  124. "channel": "抖音搜索",
  125. "decision_basis": _infer_decision_basis_from_output_content(content),
  126. "decision_notes": str(content.get("reason") or "").strip(),
  127. "input_features": input_features,
  128. }
# Fixed column set (and order) of a process-trace row; _sanitize_row keeps exactly these keys.
_ROW_KEYS: Tuple[str, ...] = (
    "trace_id",
    "aweme_id",
    "title",
    "author_nickname",
    "strategy_type",
    "from_case_aweme_id",
    "from_case_point",
    "from_feature",
    "search_keyword",
    "channel",
    "decision_basis",
    "decision_notes",
    "input_features",
)
  144. def _sanitize_row(row: Dict[str, Any]) -> Dict[str, Any]:
  145. """只保留固定字段,并把枚举值规范成中文。"""
  146. out: Dict[str, Any] = {k: row.get(k, "") for k in _ROW_KEYS}
  147. out["strategy_type"] = _map_strategy_type(out.get("strategy_type"))
  148. out["channel"] = _map_channel(out.get("channel"))
  149. out["decision_basis"] = _map_decision_basis(out.get("decision_basis"))
  150. # input_features 规范为 list[str]
  151. feats = out.get("input_features")
  152. if isinstance(feats, list):
  153. out["input_features"] = [str(x).strip() for x in feats if str(x).strip()]
  154. elif isinstance(feats, str):
  155. out["input_features"] = _split_input_features(feats)
  156. else:
  157. out["input_features"] = []
  158. return out
def _normalize_payload(*, trace_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize the parsed payload into the final {"rows": [...]} document.

    One row is produced per aweme_id in output.json.contents (and only those).
    Payload rows may override the strategy/source/explanation fields but never
    the identity fields; missing input_features are backfilled from the query.
    """
    # The tool keeps a minimal responsibility here: filter / fill / normalize.
    # Any complex reasoning happens upstream in the skill that builds summary_json.
    raw_rows = payload.get("rows")
    rows_in_payload: List[Dict[str, Any]] = []
    if isinstance(raw_rows, list):
        for item in raw_rows:
            if isinstance(item, dict):
                rows_in_payload.append(item)
    output_json = _load_output_json(trace_id=trace_id) or {}
    query = str(output_json.get("query") or "").strip()
    input_features = _split_input_features(query)
    contents = _extract_contents(trace_id=trace_id)
    contents_by_aweme_id: Dict[str, Dict[str, Any]] = {
        str(c.get("aweme_id") or "").strip(): c for c in contents if str(c.get("aweme_id") or "").strip()
    }
    # Group payload rows by aweme_id first (camelCase alias accepted for the id).
    payload_by_aweme_id: Dict[str, Dict[str, Any]] = {}
    for r in rows_in_payload:
        aweme_id = str(r.get("aweme_id") or r.get("awemeId") or "").strip()
        if not aweme_id:
            continue
        payload_by_aweme_id[aweme_id] = dict(r)
    # The payload may only override strategy/source/explanation fields, so it can
    # never clobber the identity fields (title/author/...) from output.json.contents.
    allowed_payload_keys: set[str] = {
        "strategy_type",
        "from_case_aweme_id",
        "from_case_point",
        "from_feature",
        "search_keyword",
        "channel",
        "decision_basis",
        "decision_notes",
        "input_features",
    }
    # Accept common aliases / camelCase keys so unstable model output loses as little info as possible.
    alias_map: Dict[str, Tuple[str, ...]] = {
        "strategy_type": ("strategy_type", "strategyType"),
        "from_case_aweme_id": ("from_case_aweme_id", "fromCaseAwemeId", "case_aweme_id", "caseAwemeId"),
        "from_case_point": ("from_case_point", "fromCasePoint", "case_point", "casePoint"),
        "from_feature": ("from_feature", "fromFeature", "feature", "from_feature_name"),
        "search_keyword": ("search_keyword", "searchKeyword", "keyword"),
        "channel": ("channel", "source_channel", "sourceChannel", "source"),
        "decision_basis": ("decision_basis", "decisionBasis"),
        "decision_notes": ("decision_notes", "decisionNotes", "notes"),
        "input_features": ("input_features", "inputFeatures"),
    }
    def _pick(provided: Dict[str, Any], key: str) -> Any:
        # First alias present in the payload wins; None when none of them exist.
        for k in alias_map.get(key, (key,)):
            if k in provided:
                return provided.get(k)
        return None
    normalized: List[Dict[str, Any]] = []
    for aweme_id, content in contents_by_aweme_id.items():
        base = _build_base_row(trace_id=trace_id, content=content, input_features=input_features, query=query)
        provided = payload_by_aweme_id.get(aweme_id) or {}
        merged = dict(base)
        # Merge only the fields the payload is allowed to override.
        for k in allowed_payload_keys:
            v = _pick(provided, k)
            if v is not None:
                merged[k] = v
        # Identity fields always come from output.json.contents (payload values are ignored).
        merged["aweme_id"] = str(content.get("aweme_id") or "").strip()
        merged["title"] = str(content.get("title") or "").strip()
        merged["author_nickname"] = str(content.get("author_nickname") or "").strip()
        # Backfill input_features from the split query when missing or empty.
        if "input_features" not in merged or not merged.get("input_features"):
            merged["input_features"] = input_features
        normalized.append(_sanitize_row(merged))
    # Stable ordering: by content rank when present (>0), otherwise pushed to the end and sorted by aweme_id.
    def _sort_key(r: Dict[str, Any]) -> Tuple[int, str]:
        c = contents_by_aweme_id.get(str(r.get("aweme_id") or "").strip()) or {}
        try:
            rank = int(c.get("rank") or 0)
        except Exception:
            rank = 0
        return (rank if rank > 0 else 10**9, str(r.get("aweme_id") or ""))
    normalized.sort(key=_sort_key)
    return {"rows": normalized}
  238. def _write_process_trace(*, trace_id: str, payload: Dict[str, Any]) -> Path:
  239. out_dir = _output_dir_path() / trace_id
  240. out_dir.mkdir(parents=True, exist_ok=True)
  241. path = out_dir / "process_trace.json"
  242. # 输出格式收敛:只允许 {"rows": [...]}
  243. doc = {"rows": payload.get("rows") or []}
  244. with path.open("w", encoding="utf-8") as f:
  245. json.dump(doc, f, ensure_ascii=False, indent=2)
  246. return path
  247. def _sync_process_trace_rows_to_mysql(*, trace_id: str, rows: List[Dict[str, Any]]) -> Dict[str, Any]:
  248. """
  249. 将每条归一化后的策略行序列化为 JSON 文本,按 (trace_id, aweme_id) 更新 process_trace 与 channel。
  250. channel 当前统一为「抖音」(与 process_trace.json 内 channel「抖音搜索」区分)。
  251. 表中无匹配行时 rowcount 为 0,计入 skipped。
  252. """
  253. updated = 0
  254. skipped = 0
  255. errors: List[str] = []
  256. for row in rows:
  257. aweme_id = str(row.get("aweme_id") or "").strip()
  258. if not aweme_id:
  259. skipped += 1
  260. continue
  261. text = json.dumps(row, ensure_ascii=False)
  262. try:
  263. n = update_process_trace_by_aweme_id(
  264. trace_id=trace_id,
  265. aweme_id=aweme_id,
  266. process_trace_text=text,
  267. channel="抖音",
  268. )
  269. if n > 0:
  270. updated += 1
  271. else:
  272. skipped += 1
  273. except Exception as e:
  274. logger.warning(
  275. "process_trace 回写 MySQL 失败 trace_id=%s aweme_id=%s: %s",
  276. trace_id,
  277. aweme_id,
  278. e,
  279. exc_info=True,
  280. )
  281. errors.append(f"{aweme_id}: {e}")
  282. return {"updated": updated, "skipped": skipped, "errors": errors}
@tool(
    description=(
        "在**全部流程执行完毕之后**调用:把每条最终入选内容的「选择策略」整理成表格 JSON,"
        "写入当前任务的 output 目录下的 process_trace.json,便于后续复盘;"
        "并将每一行策略 JSON 序列化为文本,按 trace_id + aweme_id 回写到 "
        "demand_find_content_result.process_trace,并同步将 channel 字段设为「抖音」。"
        "参数 summary_json 为 JSON 字符串,可以是数组或对象(对象需包含 rows)。"
        "可选参数 log_path/log_text 用于传入本次运行日志(便于复盘留档/未来扩展)。"
    ),
)
async def exec_summary(
    trace_id: str,
    summary_json: str,
    log_path: str = "",
    log_text: str = "",
) -> ToolResult:
    """Write the content-strategy table to process_trace.json and sync rows to MySQL.

    Parameters:
        trace_id: task id; selects the {OUTPUT_DIR}/{trace_id} directory.
        summary_json: JSON string — an array of rows, or an object with "rows".
        log_path / log_text: optional run-log info; only recorded in the result
            metadata (log_text as its stripped length), not persisted here.

    Returns:
        ToolResult whose metadata carries ok/path plus MySQL update counters;
        every failure path returns a ToolResult with metadata["ok"] = False
        rather than raising.
    """
    # Redacted view of the call for logging: raw summary_json / log_text can be large.
    call_params = {
        "trace_id": trace_id,
        "summary_json": "<json>",
        "log_path": (log_path or "").strip(),
        "log_text": "<text>",
    }
    tid = (trace_id or "").strip()
    if not tid:
        err = ToolResult(
            title="过程摘要",
            output="trace_id 不能为空",
            metadata={"ok": False, "error": "empty trace_id"},
        )
        log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
        return err
    try:
        payload = _parse_payload(summary_json)
    except json.JSONDecodeError as e:
        # JSONDecodeError is a ValueError subclass, so it must be caught first.
        err = ToolResult(
            title="过程摘要",
            output=f"summary_json 不是合法 JSON: {e}",
            metadata={"ok": False, "error": str(e)},
        )
        log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
        return err
    except ValueError as e:
        # Raised by _parse_payload when the JSON top level is neither object nor array.
        err = ToolResult(
            title="过程摘要",
            output=str(e),
            metadata={"ok": False, "error": str(e)},
        )
        log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
        return err
    # Re-anchor payload rows on output.json.contents and normalize enum values.
    payload = _normalize_payload(trace_id=tid, payload=payload)
    try:
        path = _write_process_trace(trace_id=tid, payload=payload)
    except OSError as e:
        msg = f"写入 process_trace.json 失败: {e}"
        logger.error(msg, exc_info=True)
        err = ToolResult(title="过程摘要", output=msg, metadata={"ok": False, "error": str(e)})
        log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
        return err
    rows = payload.get("rows") or []
    mysql_meta: Dict[str, Any]
    try:
        mysql_meta = _sync_process_trace_rows_to_mysql(trace_id=tid, rows=rows if isinstance(rows, list) else [])
    except Exception as e:
        # MySQL sync is best-effort: failure is surfaced in metadata, not raised.
        logger.warning("process_trace 批量回写 MySQL 异常: %s", e, exc_info=True)
        mysql_meta = {"updated": 0, "skipped": 0, "errors": [str(e)]}
    out = ToolResult(
        title="过程摘要",
        output=f"已写入 {path};MySQL process_trace 已更新 {mysql_meta.get('updated', 0)} 条",
        metadata={
            "ok": True,
            "trace_id": tid,
            "path": str(path),
            "log_path": (log_path or "").strip(),
            "log_text_len": len((log_text or "").strip()),
            "mysql_process_trace_updated": mysql_meta.get("updated", 0),
            "mysql_process_trace_skipped": mysql_meta.get("skipped", 0),
            "mysql_process_trace_errors": mysql_meta.get("errors") or [],
        },
    )
    log_tool_call(_LOG_LABEL, {"trace_id": tid}, format_tool_result_for_log(out))
    return out