# exec_summary.py
  1. """
  2. 在流程结束后写入**内容策略表格** JSON。
  3. 输出路径:{OUTPUT_DIR}/{trace_id}/process_trace.json
  4. """
  5. from __future__ import annotations
  6. import json
  7. import logging
  8. import os
  9. from datetime import datetime, timezone
  10. from pathlib import Path
  11. from typing import Any, Dict, List, Optional, Sequence, Tuple
  12. from agent.tools import tool, ToolResult
  13. from utils.tool_logging import format_tool_result_for_log, log_tool_call
  14. _LOG_LABEL = "工具调用:exec_summary -> 写入过程 trace JSON"
  15. logger = logging.getLogger(__name__)
  16. def _output_dir_path() -> Path:
  17. # 与 store_results_mysql / output.json 目录约定一致
  18. return Path(os.getenv("OUTPUT_DIR", ".cache/output"))
  19. def _parse_payload(summary_json: str) -> Dict[str, Any]:
  20. """
  21. 解析并规范化 LLM 传入的表格数据。
  22. - 如果是数组:视为“表格行列表”,包成 {"rows": [...]}
  23. - 如果是对象:直接返回(用于后续扩展字段)
  24. """
  25. data = json.loads(summary_json)
  26. if isinstance(data, list):
  27. return {"rows": data}
  28. if not isinstance(data, dict):
  29. raise ValueError("summary_json 解析后必须是 JSON 对象或数组")
  30. return data
  31. def _split_input_features(raw: str) -> List[str]:
  32. s = (raw or "").strip()
  33. if not s:
  34. return []
  35. parts = s.replace(",", ",").split(",")
  36. out: List[str] = []
  37. for p in parts:
  38. t = p.strip()
  39. if t:
  40. out.append(t)
  41. return out
  42. def _load_output_json(*, trace_id: str) -> Optional[Dict[str, Any]]:
  43. path = _output_dir_path() / trace_id / "output.json"
  44. try:
  45. with path.open("r", encoding="utf-8") as f:
  46. data = json.load(f)
  47. except FileNotFoundError:
  48. return None
  49. except Exception:
  50. logger.warning("读取 output.json 失败: %s", str(path), exc_info=True)
  51. return None
  52. return data if isinstance(data, dict) else None
  53. def _extract_get_video_topic_videos(*, trace_id: str) -> List[Dict[str, Any]]:
  54. """
  55. 从 log.txt 中提取 get_video_topic 的返回 metadata.videos(原始选题点)。
  56. 期望日志片段形态(render_log_html 同源格式):
  57. [FOLD:🔧 工具调用:get_video_topic ...]
  58. ...
  59. [FOLD:📤 返回内容]
  60. <json array>
  61. [/FOLD]
  62. """
  63. log_path = _output_dir_path() / trace_id / "log.txt"
  64. try:
  65. text = log_path.read_text(encoding="utf-8")
  66. except FileNotFoundError:
  67. return []
  68. except Exception:
  69. logger.warning("读取 log.txt 失败: %s", str(log_path), exc_info=True)
  70. return []
  71. marker = "[FOLD:🔧 工具调用:get_video_topic"
  72. start = text.find(marker)
  73. if start < 0:
  74. return []
  75. snippet = text[start:]
  76. out_marker = "[FOLD:📤 返回内容]"
  77. out_start = snippet.find(out_marker)
  78. if out_start < 0:
  79. return []
  80. after = snippet[out_start + len(out_marker) :]
  81. json_start = after.find("[")
  82. if json_start < 0:
  83. return []
  84. json_end = after.find("[/FOLD]")
  85. if json_end < 0:
  86. return []
  87. raw = after[json_start:json_end].strip()
  88. try:
  89. parsed = json.loads(raw)
  90. except Exception:
  91. logger.warning("解析 get_video_topic 返回 JSON 失败", exc_info=True)
  92. return []
  93. return parsed if isinstance(parsed, list) else []
  94. def _flatten_case_points_text(video: Dict[str, Any]) -> str:
  95. tp = video.get("选题点")
  96. if not isinstance(tp, dict):
  97. return ""
  98. tokens: List[str] = []
  99. for k in ("灵感点", "目的点", "关键点"):
  100. v = tp.get(k)
  101. if isinstance(v, list):
  102. for x in v:
  103. if isinstance(x, str) and x.strip():
  104. tokens.append(x.strip())
  105. return " ".join(tokens)
  106. def _score_match(*, row_text: str, candidate_text: str) -> int:
  107. """
  108. 简单可控的匹配评分:按“子串命中次数”计分,避免引入分词依赖。
  109. """
  110. rt = (row_text or "").strip()
  111. ct = (candidate_text or "").strip()
  112. if not rt or not ct:
  113. return 0
  114. score = 0
  115. for token in _split_input_features(rt):
  116. if token and token in ct:
  117. score += 2
  118. # 再做一次整体包含(更强信号)
  119. if rt and rt in ct:
  120. score += 3
  121. return score
  122. def _pick_best_case_video(
  123. *, row: Dict[str, Any], case_videos: Sequence[Dict[str, Any]]
  124. ) -> Optional[Dict[str, Any]]:
  125. if not case_videos:
  126. return None
  127. row_text = " ".join(
  128. [
  129. str(row.get("from_case_point") or ""),
  130. str(row.get("search_keyword") or ""),
  131. str(row.get("title") or ""),
  132. ]
  133. ).strip()
  134. scored: List[Tuple[int, int]] = []
  135. for i, v in enumerate(case_videos):
  136. scored.append((_score_match(row_text=row_text, candidate_text=_flatten_case_points_text(v)), i))
  137. scored.sort(reverse=True)
  138. best_score, best_idx = scored[0]
  139. # 低于 1 视为“不确定”,但仍给出一个稳定的默认(第一个)
  140. if best_score <= 0:
  141. return case_videos[0]
  142. return case_videos[best_idx]
  143. def _normalize_payload(*, trace_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
  144. rows = payload.get("rows")
  145. if not isinstance(rows, list):
  146. return payload
  147. output_json = _load_output_json(trace_id=trace_id) or {}
  148. input_features = _split_input_features(str(output_json.get("query") or ""))
  149. case_videos = _extract_get_video_topic_videos(trace_id=trace_id)
  150. normalized_rows: List[Any] = []
  151. for item in rows:
  152. if not isinstance(item, dict):
  153. normalized_rows.append(item)
  154. continue
  155. row = dict(item)
  156. # 1) 每条视频都体现原始输入特征词
  157. if "input_features" not in row:
  158. row["input_features"] = input_features
  159. # 2) from_case_point:尽量输出“原始选题点信息”,而不是联想词
  160. if "from_case_point" in row and case_videos:
  161. original = _pick_best_case_video(row=row, case_videos=case_videos)
  162. if isinstance(original, dict) and isinstance(original.get("选题点"), dict):
  163. # 保留模型原先写的联想/归类结果,便于排查,但不作为主字段
  164. if isinstance(row.get("from_case_point"), str) and row.get("from_case_point"):
  165. row["from_case_point_guess"] = row["from_case_point"]
  166. row["from_case_point"] = original.get("选题点")
  167. if "from_case_aweme_id" not in row:
  168. row["from_case_aweme_id"] = str(original.get("id") or "").strip() or None
  169. normalized_rows.append(row)
  170. out = dict(payload)
  171. out["rows"] = normalized_rows
  172. return out
  173. def _write_process_trace(*, trace_id: str, payload: Dict[str, Any]) -> Path:
  174. out_dir = _output_dir_path() / trace_id
  175. out_dir.mkdir(parents=True, exist_ok=True)
  176. path = out_dir / "process_trace.json"
  177. doc = {
  178. **payload,
  179. "schema_version": "1.0",
  180. "trace_id": trace_id,
  181. "generated_at": datetime.now(timezone.utc).isoformat(),
  182. }
  183. with path.open("w", encoding="utf-8") as f:
  184. json.dump(doc, f, ensure_ascii=False, indent=2)
  185. return path
  186. @tool(
  187. description=(
  188. "在**全部流程执行完毕之后**调用:把每条入选/候选内容的「选择策略」整理成表格 JSON,"
  189. "写入当前任务的 output 目录下的 process_trace.json,便于后续复盘。"
  190. "参数 summary_json 为 JSON 字符串,可以是:"
  191. "1)数组:每一项是一行记录;会被包成 {\"rows\": [...]};"
  192. "2)对象:应包含 rows 字段,rows 为行列表。"
  193. "建议每行至少包含:strategy_type(\"case_based\" | \"feature_based\")、"
  194. "from_case_aweme_id / from_feature(来源 case 的选题点或特征)、"
  195. "search_keyword(使用的搜索词)、"
  196. "channel(\"search\" | \"author\" | \"ranking\" | \"other\" 等)、"
  197. "decision_basis(如 \"demand_filtering\" | \"content_portrait\" | \"author_portrait\" | \"other\")、"
  198. "decision_notes(自由文本补充原因)。"
  199. ),
  200. )
  201. async def exec_summary(trace_id: str, summary_json: str) -> ToolResult:
  202. """
  203. Args:
  204. trace_id: 本次任务 trace_id(与 output.json 同目录)。
  205. summary_json: JSON 字符串。对象或数组均可;数组会包成 {\"rows\": [...] }。
  206. """
  207. call_params = {"trace_id": trace_id, "summary_json": "<json>"}
  208. tid = (trace_id or "").strip()
  209. if not tid:
  210. err = ToolResult(
  211. title="过程摘要",
  212. output="trace_id 不能为空",
  213. metadata={"ok": False, "error": "empty trace_id"},
  214. )
  215. log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
  216. return err
  217. try:
  218. payload = _parse_payload(summary_json)
  219. except json.JSONDecodeError as e:
  220. err = ToolResult(
  221. title="过程摘要",
  222. output=f"summary_json 不是合法 JSON: {e}",
  223. metadata={"ok": False, "error": str(e)},
  224. )
  225. log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
  226. return err
  227. except ValueError as e:
  228. err = ToolResult(
  229. title="过程摘要",
  230. output=str(e),
  231. metadata={"ok": False, "error": str(e)},
  232. )
  233. log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
  234. return err
  235. payload = _normalize_payload(trace_id=tid, payload=payload)
  236. try:
  237. path = _write_process_trace(trace_id=tid, payload=payload)
  238. except OSError as e:
  239. msg = f"写入 process_trace.json 失败: {e}"
  240. logger.error(msg, exc_info=True)
  241. err = ToolResult(title="过程摘要", output=msg, metadata={"ok": False, "error": str(e)})
  242. log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
  243. return err
  244. out = ToolResult(
  245. title="过程摘要",
  246. output=f"已写入 {path}",
  247. metadata={"ok": True, "trace_id": tid, "path": str(path)},
  248. )
  249. log_tool_call(_LOG_LABEL, {"trace_id": tid}, format_tool_result_for_log(out))
  250. return out