exec_summary.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. """
  2. 在流程结束后写入**内容策略表格** JSON。
  3. 输出路径:{OUTPUT_DIR}/{trace_id}/process_trace.json
  4. """
  5. from __future__ import annotations
  6. import json
  7. import logging
  8. import os
  9. from pathlib import Path
  10. from typing import Any, Dict, List, Optional, Tuple
  11. from agent.tools import tool, ToolResult
  12. from utils.tool_logging import format_tool_result_for_log, log_tool_call
  13. _LOG_LABEL = "工具调用:exec_summary -> 写入过程 trace JSON"
  14. logger = logging.getLogger(__name__)
  15. def _output_dir_path() -> Path:
  16. # 与 store_results_mysql / output.json 目录约定一致
  17. return Path(os.getenv("OUTPUT_DIR", ".cache/output"))
  18. def _parse_payload(summary_json: str) -> Dict[str, Any]:
  19. """
  20. 解析并规范化 LLM 传入的表格数据。
  21. - 如果是数组:视为“表格行列表”,包成 {"rows": [...]}
  22. - 如果是对象:直接返回(用于后续扩展字段)
  23. """
  24. data = json.loads(summary_json)
  25. if isinstance(data, list):
  26. return {"rows": data}
  27. if not isinstance(data, dict):
  28. raise ValueError("summary_json 解析后必须是 JSON 对象或数组")
  29. return data
  30. def _split_input_features(raw: str) -> List[str]:
  31. s = (raw or "").strip()
  32. if not s:
  33. return []
  34. parts = s.replace(",", ",").split(",")
  35. out: List[str] = []
  36. for p in parts:
  37. t = p.strip()
  38. if t:
  39. out.append(t)
  40. return out
  41. def _load_output_json(*, trace_id: str) -> Optional[Dict[str, Any]]:
  42. path = _output_dir_path() / trace_id / "output.json"
  43. try:
  44. with path.open("r", encoding="utf-8") as f:
  45. data = json.load(f)
  46. except FileNotFoundError:
  47. return None
  48. except Exception:
  49. logger.warning("读取 output.json 失败: %s", str(path), exc_info=True)
  50. return None
  51. return data if isinstance(data, dict) else None
  52. def _extract_contents(*, trace_id: str) -> List[Dict[str, Any]]:
  53. """
  54. 从 output.json 读取最终入选 contents。
  55. 约定:
  56. - 只允许对 output.json.contents 内的 aweme_id 生成/写入 process_trace rows
  57. """
  58. output_json = _load_output_json(trace_id=trace_id) or {}
  59. contents = output_json.get("contents")
  60. if not isinstance(contents, list):
  61. return []
  62. out: List[Dict[str, Any]] = []
  63. for item in contents:
  64. if isinstance(item, dict):
  65. out.append(item)
  66. return out
  67. def _map_strategy_type(value: Any) -> str:
  68. v = str(value or "").strip()
  69. if v in ("case_based", "case", "case出发"):
  70. return "case出发"
  71. if v in ("feature_based", "feature", "特征出发"):
  72. return "特征出发"
  73. return v
  74. def _map_channel(value: Any) -> str:
  75. v = str(value or "").strip()
  76. mapping = {
  77. "search": "抖音搜索",
  78. "author": "订阅账号",
  79. "ranking": "榜单",
  80. "other": "其他",
  81. "抖音搜索": "抖音搜索",
  82. "订阅账号": "订阅账号",
  83. "榜单": "榜单",
  84. "其他": "其他",
  85. }
  86. return mapping.get(v, v)
  87. def _map_decision_basis(value: Any) -> str:
  88. v = str(value or "").strip()
  89. mapping = {
  90. "content_portrait": "内容画像匹配",
  91. "author_portrait": "作者画像匹配",
  92. "demand_filtering": "需求筛选",
  93. "other": "其他",
  94. "画像缺失": "画像缺失",
  95. "内容画像匹配": "内容画像匹配",
  96. "作者画像匹配": "作者画像匹配",
  97. "需求筛选": "需求筛选",
  98. "其他": "其他",
  99. }
  100. return mapping.get(v, v)
  101. def _infer_decision_basis_from_output_content(content: Dict[str, Any]) -> str:
  102. portrait = content.get("portrait_data") or {}
  103. source = str(portrait.get("source") or "").strip()
  104. if source == "content_like":
  105. return "内容画像匹配"
  106. if source == "account_fans":
  107. return "作者画像匹配"
  108. if source == "none":
  109. return "画像缺失"
  110. return ""
  111. def _build_base_row(*, trace_id: str, content: Dict[str, Any], input_features: List[str], query: str) -> Dict[str, Any]:
  112. return {
  113. "trace_id": trace_id,
  114. "aweme_id": str(content.get("aweme_id") or "").strip(),
  115. "title": str(content.get("title") or "").strip(),
  116. "author_nickname": str(content.get("author_nickname") or "").strip(),
  117. "strategy_type": "",
  118. "from_case_aweme_id": "",
  119. "from_case_point": "",
  120. "from_feature": "",
  121. "search_keyword": str(query or "").strip(),
  122. "channel": "抖音搜索",
  123. "decision_basis": _infer_decision_basis_from_output_content(content),
  124. "decision_notes": str(content.get("reason") or "").strip(),
  125. "input_features": input_features,
  126. }
  127. _ROW_KEYS: Tuple[str, ...] = (
  128. "trace_id",
  129. "aweme_id",
  130. "title",
  131. "author_nickname",
  132. "strategy_type",
  133. "from_case_aweme_id",
  134. "from_case_point",
  135. "from_feature",
  136. "search_keyword",
  137. "channel",
  138. "decision_basis",
  139. "decision_notes",
  140. "input_features",
  141. )
  142. def _sanitize_row(row: Dict[str, Any]) -> Dict[str, Any]:
  143. """只保留固定字段,并把枚举值规范成中文。"""
  144. out: Dict[str, Any] = {k: row.get(k, "") for k in _ROW_KEYS}
  145. out["strategy_type"] = _map_strategy_type(out.get("strategy_type"))
  146. out["channel"] = _map_channel(out.get("channel"))
  147. out["decision_basis"] = _map_decision_basis(out.get("decision_basis"))
  148. # input_features 规范为 list[str]
  149. feats = out.get("input_features")
  150. if isinstance(feats, list):
  151. out["input_features"] = [str(x).strip() for x in feats if str(x).strip()]
  152. elif isinstance(feats, str):
  153. out["input_features"] = _split_input_features(feats)
  154. else:
  155. out["input_features"] = []
  156. return out
  157. def _normalize_payload(*, trace_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
  158. # tool 只做最小职责:过滤/补全/规范化;复杂推理由 skill 生成 summary_json 来完成
  159. raw_rows = payload.get("rows")
  160. rows_in_payload: List[Dict[str, Any]] = []
  161. if isinstance(raw_rows, list):
  162. for item in raw_rows:
  163. if isinstance(item, dict):
  164. rows_in_payload.append(item)
  165. output_json = _load_output_json(trace_id=trace_id) or {}
  166. query = str(output_json.get("query") or "").strip()
  167. input_features = _split_input_features(query)
  168. contents = _extract_contents(trace_id=trace_id)
  169. contents_by_aweme_id: Dict[str, Dict[str, Any]] = {
  170. str(c.get("aweme_id") or "").strip(): c for c in contents if str(c.get("aweme_id") or "").strip()
  171. }
  172. # 先把 payload rows 归并到 aweme_id
  173. payload_by_aweme_id: Dict[str, Dict[str, Any]] = {}
  174. for r in rows_in_payload:
  175. aweme_id = str(r.get("aweme_id") or r.get("awemeId") or "").strip()
  176. if not aweme_id:
  177. continue
  178. payload_by_aweme_id[aweme_id] = dict(r)
  179. # 只允许 payload 覆盖“策略/来源/解释”字段,避免覆盖 output.json.contents 的身份字段(title/author 等)
  180. allowed_payload_keys: set[str] = {
  181. "strategy_type",
  182. "from_case_aweme_id",
  183. "from_case_point",
  184. "from_feature",
  185. "search_keyword",
  186. "channel",
  187. "decision_basis",
  188. "decision_notes",
  189. "input_features",
  190. }
  191. # 兼容 payload 的常见别名/驼峰 key(模型输出不稳定时,尽量不丢信息)
  192. alias_map: Dict[str, Tuple[str, ...]] = {
  193. "strategy_type": ("strategy_type", "strategyType"),
  194. "from_case_aweme_id": ("from_case_aweme_id", "fromCaseAwemeId", "case_aweme_id", "caseAwemeId"),
  195. "from_case_point": ("from_case_point", "fromCasePoint", "case_point", "casePoint"),
  196. "from_feature": ("from_feature", "fromFeature", "feature", "from_feature_name"),
  197. "search_keyword": ("search_keyword", "searchKeyword", "keyword"),
  198. "channel": ("channel", "source_channel", "sourceChannel", "source"),
  199. "decision_basis": ("decision_basis", "decisionBasis"),
  200. "decision_notes": ("decision_notes", "decisionNotes", "notes"),
  201. "input_features": ("input_features", "inputFeatures"),
  202. }
  203. def _pick(provided: Dict[str, Any], key: str) -> Any:
  204. for k in alias_map.get(key, (key,)):
  205. if k in provided:
  206. return provided.get(k)
  207. return None
  208. normalized: List[Dict[str, Any]] = []
  209. for aweme_id, content in contents_by_aweme_id.items():
  210. base = _build_base_row(trace_id=trace_id, content=content, input_features=input_features, query=query)
  211. provided = payload_by_aweme_id.get(aweme_id) or {}
  212. merged = dict(base)
  213. # 只合并允许覆盖的字段
  214. for k in allowed_payload_keys:
  215. v = _pick(provided, k)
  216. if v is not None:
  217. merged[k] = v
  218. # 身份字段强制以 output.json.contents 为准(即使 payload 传了也不采纳)
  219. merged["aweme_id"] = str(content.get("aweme_id") or "").strip()
  220. merged["title"] = str(content.get("title") or "").strip()
  221. merged["author_nickname"] = str(content.get("author_nickname") or "").strip()
  222. # 如果缺失 input_features,用 query 拆分补齐
  223. if "input_features" not in merged or not merged.get("input_features"):
  224. merged["input_features"] = input_features
  225. normalized.append(_sanitize_row(merged))
  226. # 保持稳定顺序:按 rank(若有)或 aweme_id
  227. def _sort_key(r: Dict[str, Any]) -> Tuple[int, str]:
  228. c = contents_by_aweme_id.get(str(r.get("aweme_id") or "").strip()) or {}
  229. try:
  230. rank = int(c.get("rank") or 0)
  231. except Exception:
  232. rank = 0
  233. return (rank if rank > 0 else 10**9, str(r.get("aweme_id") or ""))
  234. normalized.sort(key=_sort_key)
  235. return {"rows": normalized}
  236. def _write_process_trace(*, trace_id: str, payload: Dict[str, Any]) -> Path:
  237. out_dir = _output_dir_path() / trace_id
  238. out_dir.mkdir(parents=True, exist_ok=True)
  239. path = out_dir / "process_trace.json"
  240. # 输出格式收敛:只允许 {"rows": [...]}
  241. doc = {"rows": payload.get("rows") or []}
  242. with path.open("w", encoding="utf-8") as f:
  243. json.dump(doc, f, ensure_ascii=False, indent=2)
  244. return path
  245. @tool(
  246. description=(
  247. "在**全部流程执行完毕之后**调用:把每条最终入选内容的「选择策略」整理成表格 JSON,"
  248. "写入当前任务的 output 目录下的 process_trace.json,便于后续复盘。"
  249. "参数 summary_json 为 JSON 字符串,可以是数组或对象(对象需包含 rows)。"
  250. "可选参数 log_path/log_text 用于传入本次运行日志(便于复盘留档/未来扩展)。"
  251. ),
  252. )
  253. async def exec_summary(
  254. trace_id: str,
  255. summary_json: str,
  256. log_path: str = "",
  257. log_text: str = "",
  258. ) -> ToolResult:
  259. call_params = {
  260. "trace_id": trace_id,
  261. "summary_json": "<json>",
  262. "log_path": (log_path or "").strip(),
  263. "log_text": "<text>",
  264. }
  265. tid = (trace_id or "").strip()
  266. if not tid:
  267. err = ToolResult(
  268. title="过程摘要",
  269. output="trace_id 不能为空",
  270. metadata={"ok": False, "error": "empty trace_id"},
  271. )
  272. log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
  273. return err
  274. try:
  275. payload = _parse_payload(summary_json)
  276. except json.JSONDecodeError as e:
  277. err = ToolResult(
  278. title="过程摘要",
  279. output=f"summary_json 不是合法 JSON: {e}",
  280. metadata={"ok": False, "error": str(e)},
  281. )
  282. log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
  283. return err
  284. except ValueError as e:
  285. err = ToolResult(
  286. title="过程摘要",
  287. output=str(e),
  288. metadata={"ok": False, "error": str(e)},
  289. )
  290. log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
  291. return err
  292. payload = _normalize_payload(trace_id=tid, payload=payload)
  293. try:
  294. path = _write_process_trace(trace_id=tid, payload=payload)
  295. except OSError as e:
  296. msg = f"写入 process_trace.json 失败: {e}"
  297. logger.error(msg, exc_info=True)
  298. err = ToolResult(title="过程摘要", output=msg, metadata={"ok": False, "error": str(e)})
  299. log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
  300. return err
  301. out = ToolResult(
  302. title="过程摘要",
  303. output=f"已写入 {path}",
  304. metadata={
  305. "ok": True,
  306. "trace_id": tid,
  307. "path": str(path),
  308. "log_path": (log_path or "").strip(),
  309. "log_text_len": len((log_text or "").strip()),
  310. },
  311. )
  312. log_tool_call(_LOG_LABEL, {"trace_id": tid}, format_tool_result_for_log(out))
  313. return out