| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- """
- 在流程结束后写入**内容策略表格** JSON。
- 输出路径:{OUTPUT_DIR}/{trace_id}/process_trace.json
- """
- from __future__ import annotations
- import json
- import logging
- import os
- from datetime import datetime, timezone
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Sequence, Tuple
- from agent.tools import tool, ToolResult
- from utils.tool_logging import format_tool_result_for_log, log_tool_call
- _LOG_LABEL = "工具调用:exec_summary -> 写入过程 trace JSON"
- logger = logging.getLogger(__name__)
- def _output_dir_path() -> Path:
- # 与 store_results_mysql / output.json 目录约定一致
- return Path(os.getenv("OUTPUT_DIR", ".cache/output"))
- def _parse_payload(summary_json: str) -> Dict[str, Any]:
- """
- 解析并规范化 LLM 传入的表格数据。
- - 如果是数组:视为“表格行列表”,包成 {"rows": [...]}
- - 如果是对象:直接返回(用于后续扩展字段)
- """
- data = json.loads(summary_json)
- if isinstance(data, list):
- return {"rows": data}
- if not isinstance(data, dict):
- raise ValueError("summary_json 解析后必须是 JSON 对象或数组")
- return data
- def _split_input_features(raw: str) -> List[str]:
- s = (raw or "").strip()
- if not s:
- return []
- parts = s.replace(",", ",").split(",")
- out: List[str] = []
- for p in parts:
- t = p.strip()
- if t:
- out.append(t)
- return out
def _load_output_json(*, trace_id: str) -> Optional[Dict[str, Any]]:
    """Load ``output.json`` for *trace_id*.

    Returns None when the file is missing, unreadable, invalid JSON, or
    does not decode to a dict (non-FileNotFound failures are logged).
    """
    path = _output_dir_path() / trace_id / "output.json"
    try:
        loaded = json.loads(path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        # Missing output.json is expected; stay quiet.
        return None
    except Exception:
        logger.warning("读取 output.json 失败: %s", str(path), exc_info=True)
        return None
    return loaded if isinstance(loaded, dict) else None
def _extract_get_video_topic_videos(*, trace_id: str) -> List[Dict[str, Any]]:
    """Extract the original topic points (metadata.videos) that the
    ``get_video_topic`` tool returned, by scraping this trace's ``log.txt``.

    Expected log fragment (same format render_log_html produces):

        [FOLD:🔧 工具调用:get_video_topic ...]
        ...
        [FOLD:📤 返回内容]
        <json array>
        [/FOLD]

    Returns an empty list whenever the file, the markers, or a parseable
    JSON array cannot be found.
    """
    log_path = _output_dir_path() / trace_id / "log.txt"
    try:
        content = log_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return []
    except Exception:
        logger.warning("读取 log.txt 失败: %s", str(log_path), exc_info=True)
        return []

    call_pos = content.find("[FOLD:🔧 工具调用:get_video_topic")
    if call_pos < 0:
        return []
    return_marker = "[FOLD:📤 返回内容]"
    return_pos = content.find(return_marker, call_pos)
    if return_pos < 0:
        return []

    # Slice between the first "[" after the return marker and the closing
    # [/FOLD] tag; that span should be the JSON array payload.
    tail = content[return_pos + len(return_marker):]
    array_start = tail.find("[")
    if array_start < 0:
        return []
    fold_end = tail.find("[/FOLD]")
    if fold_end < 0:
        return []
    raw = tail[array_start:fold_end].strip()

    try:
        videos = json.loads(raw)
    except Exception:
        logger.warning("解析 get_video_topic 返回 JSON 失败", exc_info=True)
        return []
    return videos if isinstance(videos, list) else []
- def _flatten_case_points_text(video: Dict[str, Any]) -> str:
- tp = video.get("选题点")
- if not isinstance(tp, dict):
- return ""
- tokens: List[str] = []
- for k in ("灵感点", "目的点", "关键点"):
- v = tp.get(k)
- if isinstance(v, list):
- for x in v:
- if isinstance(x, str) and x.strip():
- tokens.append(x.strip())
- return " ".join(tokens)
def _score_match(*, row_text: str, candidate_text: str) -> int:
    """Score how well *row_text* matches *candidate_text*.

    Deliberately simple and dependency-free (no tokenizer): +2 for every
    comma-separated token of the row text found as a substring of the
    candidate, plus +3 when the whole row text is contained (a stronger
    signal). Blank input on either side scores 0.
    """
    needle = (row_text or "").strip()
    haystack = (candidate_text or "").strip()
    if not needle or not haystack:
        return 0
    score = 2 * sum(1 for tok in _split_input_features(needle) if tok and tok in haystack)
    # Whole-string containment counts extra on top of token hits.
    if needle in haystack:
        score += 3
    return score
def _pick_best_case_video(
    *, row: Dict[str, Any], case_videos: Sequence[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Pick the case video whose topic points best match *row*.

    The row text is assembled from ``from_case_point`` / ``search_keyword`` /
    ``title`` and scored against every candidate via :func:`_score_match`.

    Returns:
        The best-scoring candidate; the FIRST candidate as a stable default
        when no candidate scores above 0; None when *case_videos* is empty.
    """
    if not case_videos:
        return None
    row_text = " ".join(
        str(row.get(key) or "")
        for key in ("from_case_point", "search_keyword", "title")
    ).strip()

    best_score, best_idx = -1, 0
    for i, video in enumerate(case_videos):
        score = _score_match(
            row_text=row_text, candidate_text=_flatten_case_points_text(video)
        )
        # Strictly greater: ties keep the EARLIEST candidate. The previous
        # ``sort(reverse=True)`` on (score, index) broke ties toward the
        # LAST candidate, contradicting the "stable default" intent below.
        if score > best_score:
            best_score, best_idx = score, i

    # Score 0 means "no confident match": fall back to a stable default.
    if best_score <= 0:
        return case_videos[0]
    return case_videos[best_idx]
def _normalize_payload(*, trace_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Enrich each table row with the original input features and the
    originating case topic point.

    For every dict row:
      1. ``input_features`` is filled (when absent) from output.json's query.
      2. ``from_case_point`` is replaced by the best-matching case video's
         original topic points ("选题点"); the model's own guess is kept
         under ``from_case_point_guess`` and ``from_case_aweme_id`` is
         filled when absent.

    A payload without a list-valued ``rows`` field is returned untouched.
    """
    rows = payload.get("rows")
    if not isinstance(rows, list):
        return payload

    output_json = _load_output_json(trace_id=trace_id) or {}
    input_features = _split_input_features(str(output_json.get("query") or ""))
    case_videos = _extract_get_video_topic_videos(trace_id=trace_id)

    def _enrich(item: Any) -> Any:
        # Non-dict rows pass through untouched.
        if not isinstance(item, dict):
            return item
        row = dict(item)
        # 1) Surface the original input feature words on every row.
        row.setdefault("input_features", input_features)
        # 2) Prefer the ORIGINAL case topic points over the model's guess.
        if "from_case_point" in row and case_videos:
            original = _pick_best_case_video(row=row, case_videos=case_videos)
            if isinstance(original, dict) and isinstance(original.get("选题点"), dict):
                guess = row.get("from_case_point")
                # Keep the model's guess for debugging, demoted to a side field.
                if isinstance(guess, str) and guess:
                    row["from_case_point_guess"] = guess
                row["from_case_point"] = original.get("选题点")
                if "from_case_aweme_id" not in row:
                    row["from_case_aweme_id"] = str(original.get("id") or "").strip() or None
        return row

    normalized = dict(payload)
    normalized["rows"] = [_enrich(item) for item in rows]
    return normalized
def _write_process_trace(*, trace_id: str, payload: Dict[str, Any]) -> Path:
    """Write *payload* plus trace metadata to
    ``{OUTPUT_DIR}/{trace_id}/process_trace.json`` and return the path.

    The fixed keys (schema_version / trace_id / generated_at) override any
    same-named keys in *payload*. Raises OSError on write failure.
    """
    target_dir = _output_dir_path() / trace_id
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / "process_trace.json"

    document = dict(payload)
    document.update(
        schema_version="1.0",
        trace_id=trace_id,
        generated_at=datetime.now(timezone.utc).isoformat(),
    )
    target.write_text(
        json.dumps(document, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    return target
@tool(
    description=(
        "在**全部流程执行完毕之后**调用:把每条入选/候选内容的「选择策略」整理成表格 JSON,"
        "写入当前任务的 output 目录下的 process_trace.json,便于后续复盘。"
        "参数 summary_json 为 JSON 字符串,可以是:"
        "1)数组:每一项是一行记录;会被包成 {\"rows\": [...]};"
        "2)对象:应包含 rows 字段,rows 为行列表。"
        "建议每行至少包含:strategy_type(\"case_based\" | \"feature_based\")、"
        "from_case_aweme_id / from_feature(来源 case 的选题点或特征)、"
        "search_keyword(使用的搜索词)、"
        "channel(\"search\" | \"author\" | \"ranking\" | \"other\" 等)、"
        "decision_basis(如 \"demand_filtering\" | \"content_portrait\" | \"author_portrait\" | \"other\")、"
        "decision_notes(自由文本补充原因)。"
    ),
)
async def exec_summary(trace_id: str, summary_json: str) -> ToolResult:
    """Persist the content-strategy table for this run as process_trace.json.

    Args:
        trace_id: Trace id of the current task (same directory as output.json).
        summary_json: JSON string; object or array (arrays are wrapped
            as {"rows": [...]}).
    """
    call_params = {"trace_id": trace_id, "summary_json": "<json>"}

    def _fail(output: str, error: str) -> ToolResult:
        # Build, log and return a failure ToolResult in one place.
        result = ToolResult(
            title="过程摘要",
            output=output,
            metadata={"ok": False, "error": error},
        )
        log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(result))
        return result

    tid = (trace_id or "").strip()
    if not tid:
        return _fail("trace_id 不能为空", "empty trace_id")

    try:
        payload = _parse_payload(summary_json)
    except json.JSONDecodeError as e:
        return _fail(f"summary_json 不是合法 JSON: {e}", str(e))
    except ValueError as e:
        return _fail(str(e), str(e))

    payload = _normalize_payload(trace_id=tid, payload=payload)

    try:
        path = _write_process_trace(trace_id=tid, payload=payload)
    except OSError as e:
        msg = f"写入 process_trace.json 失败: {e}"
        logger.error(msg, exc_info=True)
        return _fail(msg, str(e))

    success = ToolResult(
        title="过程摘要",
        output=f"已写入 {path}",
        metadata={"ok": True, "trace_id": tid, "path": str(path)},
    )
    log_tool_call(_LOG_LABEL, {"trace_id": tid}, format_tool_result_for_log(success))
    return success
|