"""
Write the **content strategy table** JSON after the whole pipeline finishes.

Output path: {OUTPUT_DIR}/{trace_id}/process_trace.json
"""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from agent.tools import tool, ToolResult
from utils.tool_logging import format_tool_result_for_log, log_tool_call

# Label attached to every log_tool_call emitted by this module.
_LOG_LABEL = "工具调用:exec_summary -> 写入过程 trace JSON"
logger = logging.getLogger(__name__)
def _output_dir_path() -> Path:
    """Return the root output directory.

    Follows the same directory convention as store_results_mysql / output.json.
    """
    return Path(os.environ.get("OUTPUT_DIR", ".cache/output"))
def _parse_payload(summary_json: str) -> Dict[str, Any]:
    """Parse and normalize the table data passed in by the LLM.

    - A JSON array is treated as a list of table rows and wrapped as {"rows": [...]}.
    - A JSON object is returned as-is (leaves room for future extension fields).

    Raises:
        json.JSONDecodeError: if *summary_json* is not valid JSON.
        ValueError: if the parsed value is neither an object nor an array.
    """
    parsed = json.loads(summary_json)
    if isinstance(parsed, dict):
        return parsed
    if isinstance(parsed, list):
        return {"rows": parsed}
    raise ValueError("summary_json 解析后必须是 JSON 对象或数组")
def _split_input_features(raw: str) -> List[str]:
    """Split a comma-separated feature string into a clean list of strings.

    Accepts both the ASCII comma "," and the full-width Chinese comma "，" as
    separators (the full-width comma is normalized to "," before splitting;
    as previously rendered, the replace call mapped "," to itself, a no-op).
    Each item is whitespace-stripped and empty items are dropped.

    Returns an empty list for None / empty / whitespace-only input.
    """
    s = (raw or "").strip()
    if not s:
        return []
    # Normalize the full-width comma so Chinese-style queries split correctly.
    normalized = s.replace("，", ",")
    return [part.strip() for part in normalized.split(",") if part.strip()]
def _load_output_json(*, trace_id: str) -> Optional[Dict[str, Any]]:
    """Load {OUTPUT_DIR}/{trace_id}/output.json.

    Returns None when the file is missing, unreadable, or its top level is
    not a JSON object.
    """
    path = _output_dir_path() / trace_id / "output.json"
    try:
        with path.open("r", encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        return None
    except Exception:
        # Best-effort read: any other failure (bad JSON, permissions, ...)
        # is logged and treated the same as a missing file.
        logger.warning("读取 output.json 失败: %s", str(path), exc_info=True)
        return None
    if not isinstance(loaded, dict):
        return None
    return loaded
def _extract_contents(*, trace_id: str) -> List[Dict[str, Any]]:
    """Read the finally selected contents from output.json.

    Convention: process_trace rows may only be generated/written for the
    aweme_ids present in output.json.contents.
    """
    output_json = _load_output_json(trace_id=trace_id) or {}
    contents = output_json.get("contents")
    if not isinstance(contents, list):
        return []
    return [entry for entry in contents if isinstance(entry, dict)]
def _map_strategy_type(value: Any) -> str:
    """Normalize a strategy-type value to its Chinese label.

    Unknown values pass through unchanged; None/empty becomes "".
    """
    text = str(value or "").strip()
    aliases = {
        "case_based": "case出发",
        "case": "case出发",
        "case出发": "case出发",
        "feature_based": "特征出发",
        "feature": "特征出发",
        "特征出发": "特征出发",
    }
    return aliases.get(text, text)
def _map_channel(value: Any) -> str:
    """Normalize a sourcing-channel value to its Chinese label.

    Canonical Chinese labels and unknown values pass through unchanged via
    the dict-get fallback; None/empty becomes "".
    """
    text = str(value or "").strip()
    translations = {
        "search": "抖音搜索",
        "author": "订阅账号",
        "ranking": "榜单",
        "other": "其他",
    }
    return translations.get(text, text)
def _map_decision_basis(value: Any) -> str:
    """Normalize a decision-basis value to its Chinese label.

    Values already in canonical Chinese form (including "画像缺失") and
    unknown values pass through unchanged; None/empty becomes "".
    """
    text = str(value or "").strip()
    translations = {
        "content_portrait": "内容画像匹配",
        "author_portrait": "作者画像匹配",
        "demand_filtering": "需求筛选",
        "other": "其他",
    }
    return translations.get(text, text)
def _infer_decision_basis_from_output_content(content: Dict[str, Any]) -> str:
    """Infer the decision basis from a content's portrait_data.source field.

    Returns "" when the source is absent or unrecognized.
    """
    portrait = content.get("portrait_data") or {}
    source = str(portrait.get("source") or "").strip()
    by_source = {
        "content_like": "内容画像匹配",
        "account_fans": "作者画像匹配",
        "none": "画像缺失",
    }
    return by_source.get(source, "")
def _build_base_row(*, trace_id: str, content: Dict[str, Any], input_features: List[str], query: str) -> Dict[str, Any]:
    """Build the default process-trace row for one selected content.

    Identity fields come from *content* (output.json); strategy/source fields
    start empty and may later be overwritten by the LLM payload.
    """
    def _clean(value: Any) -> str:
        # Stringify defensively and trim whitespace; None becomes "".
        return str(value or "").strip()

    return {
        "trace_id": trace_id,
        "aweme_id": _clean(content.get("aweme_id")),
        "title": _clean(content.get("title")),
        "author_nickname": _clean(content.get("author_nickname")),
        "strategy_type": "",
        "from_case_aweme_id": "",
        "from_case_point": "",
        "from_feature": "",
        "search_keyword": _clean(query),
        "channel": "抖音搜索",
        "decision_basis": _infer_decision_basis_from_output_content(content),
        "decision_notes": _clean(content.get("reason")),
        "input_features": input_features,
    }
# Fixed column set (and order) for every process_trace row; any other keys
# in an incoming row are dropped so the written JSON schema stays stable.
_ROW_KEYS: Tuple[str, ...] = (
    "trace_id",
    "aweme_id",
    "title",
    "author_nickname",
    "strategy_type",
    "from_case_aweme_id",
    "from_case_point",
    "from_feature",
    "search_keyword",
    "channel",
    "decision_basis",
    "decision_notes",
    "input_features",
)
def _sanitize_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the fixed columns and normalize enum values to Chinese labels."""
    cleaned: Dict[str, Any] = {key: row.get(key, "") for key in _ROW_KEYS}
    cleaned["strategy_type"] = _map_strategy_type(cleaned.get("strategy_type"))
    cleaned["channel"] = _map_channel(cleaned.get("channel"))
    cleaned["decision_basis"] = _map_decision_basis(cleaned.get("decision_basis"))
    # Coerce input_features into list[str] with blanks removed; a string is
    # split on commas, anything else becomes an empty list.
    features = cleaned.get("input_features")
    if isinstance(features, str):
        cleaned["input_features"] = _split_input_features(features)
    elif isinstance(features, list):
        cleaned["input_features"] = [str(item).strip() for item in features if str(item).strip()]
    else:
        cleaned["input_features"] = []
    return cleaned
def _normalize_payload(*, trace_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Merge the LLM payload rows with output.json.contents into canonical rows.

    One row is produced per aweme_id found in output.json.contents; payload
    rows may only override strategy/source/explanation fields. Returns
    {"rows": [...]} sorted by rank (when present) then aweme_id.
    """
    # The tool keeps a minimal responsibility: filter / complete / normalize.
    # Any complex reasoning happens in the skill that produced summary_json.
    raw_rows = payload.get("rows")
    rows_in_payload: List[Dict[str, Any]] = []
    if isinstance(raw_rows, list):
        for item in raw_rows:
            if isinstance(item, dict):
                rows_in_payload.append(item)
    output_json = _load_output_json(trace_id=trace_id) or {}
    query = str(output_json.get("query") or "").strip()
    input_features = _split_input_features(query)
    contents = _extract_contents(trace_id=trace_id)
    contents_by_aweme_id: Dict[str, Dict[str, Any]] = {
        str(c.get("aweme_id") or "").strip(): c for c in contents if str(c.get("aweme_id") or "").strip()
    }
    # First, key the payload rows by aweme_id (rows without one are dropped).
    payload_by_aweme_id: Dict[str, Dict[str, Any]] = {}
    for r in rows_in_payload:
        aweme_id = str(r.get("aweme_id") or r.get("awemeId") or "").strip()
        if not aweme_id:
            continue
        payload_by_aweme_id[aweme_id] = dict(r)
    # The payload may only override strategy/source/explanation fields, so it
    # cannot clobber identity fields (title/author/...) from output.json.contents.
    allowed_payload_keys: set[str] = {
        "strategy_type",
        "from_case_aweme_id",
        "from_case_point",
        "from_feature",
        "search_keyword",
        "channel",
        "decision_basis",
        "decision_notes",
        "input_features",
    }
    # Accept common alias / camelCase keys from the payload (model output is
    # not always stable; try not to lose information).
    alias_map: Dict[str, Tuple[str, ...]] = {
        "strategy_type": ("strategy_type", "strategyType"),
        "from_case_aweme_id": ("from_case_aweme_id", "fromCaseAwemeId", "case_aweme_id", "caseAwemeId"),
        "from_case_point": ("from_case_point", "fromCasePoint", "case_point", "casePoint"),
        "from_feature": ("from_feature", "fromFeature", "feature", "from_feature_name"),
        "search_keyword": ("search_keyword", "searchKeyword", "keyword"),
        "channel": ("channel", "source_channel", "sourceChannel", "source"),
        "decision_basis": ("decision_basis", "decisionBasis"),
        "decision_notes": ("decision_notes", "decisionNotes", "notes"),
        "input_features": ("input_features", "inputFeatures"),
    }
    def _pick(provided: Dict[str, Any], key: str) -> Any:
        # Return the first alias of *key* present in *provided*, else None.
        for k in alias_map.get(key, (key,)):
            if k in provided:
                return provided.get(k)
        return None
    normalized: List[Dict[str, Any]] = []
    for aweme_id, content in contents_by_aweme_id.items():
        base = _build_base_row(trace_id=trace_id, content=content, input_features=input_features, query=query)
        provided = payload_by_aweme_id.get(aweme_id) or {}
        merged = dict(base)
        # Merge only the fields the payload is allowed to override.
        for k in allowed_payload_keys:
            v = _pick(provided, k)
            if v is not None:
                merged[k] = v
        # Identity fields always come from output.json.contents — payload
        # values are ignored even when supplied.
        merged["aweme_id"] = str(content.get("aweme_id") or "").strip()
        merged["title"] = str(content.get("title") or "").strip()
        merged["author_nickname"] = str(content.get("author_nickname") or "").strip()
        # If input_features is missing/empty, fall back to the split query.
        if "input_features" not in merged or not merged.get("input_features"):
            merged["input_features"] = input_features
        normalized.append(_sanitize_row(merged))
    # Stable ordering: by rank when present (non-positive/invalid ranks sort
    # last), then by aweme_id.
    def _sort_key(r: Dict[str, Any]) -> Tuple[int, str]:
        c = contents_by_aweme_id.get(str(r.get("aweme_id") or "").strip()) or {}
        try:
            rank = int(c.get("rank") or 0)
        except Exception:
            rank = 0
        return (rank if rank > 0 else 10**9, str(r.get("aweme_id") or ""))
    normalized.sort(key=_sort_key)
    return {"rows": normalized}
def _write_process_trace(*, trace_id: str, payload: Dict[str, Any]) -> Path:
    """Write {"rows": [...]} to {OUTPUT_DIR}/{trace_id}/process_trace.json.

    Creates the directory when needed and returns the written file path.
    """
    target_dir = _output_dir_path() / trace_id
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / "process_trace.json"
    # Converge the on-disk format: only {"rows": [...]} is ever written.
    document = {"rows": payload.get("rows") or []}
    with target.open("w", encoding="utf-8") as f:
        json.dump(document, f, ensure_ascii=False, indent=2)
    return target
@tool(
    description=(
        "在**全部流程执行完毕之后**调用:把每条最终入选内容的「选择策略」整理成表格 JSON,"
        "写入当前任务的 output 目录下的 process_trace.json,便于后续复盘。"
        "参数 summary_json 为 JSON 字符串,可以是数组或对象(对象需包含 rows)。"
        "可选参数 log_path/log_text 用于传入本次运行日志(便于复盘留档/未来扩展)。"
    ),
)
async def exec_summary(
    trace_id: str,
    summary_json: str,
    log_path: str = "",
    log_text: str = "",
) -> ToolResult:
    """Write the per-content strategy table to process_trace.json.

    Args:
        trace_id: Task id; the file is written under {OUTPUT_DIR}/{trace_id}/.
        summary_json: JSON string — an array of rows, or an object with "rows".
        log_path: Optional path of this run's log file (recorded in metadata only).
        log_text: Optional raw log text (only its stripped length is recorded).

    Returns:
        ToolResult whose metadata["ok"] indicates success; failures carry
        metadata["error"] and never raise.
    """
    call_params = {
        "trace_id": trace_id,
        "summary_json": "<json>",
        "log_path": (log_path or "").strip(),
        "log_text": "<text>",
    }

    def _fail(output: str, error: str) -> ToolResult:
        # Uniform failure path: build the error result, log the call, return it.
        err = ToolResult(
            title="过程摘要",
            output=output,
            metadata={"ok": False, "error": error},
        )
        log_tool_call(_LOG_LABEL, call_params, format_tool_result_for_log(err))
        return err

    tid = (trace_id or "").strip()
    if not tid:
        return _fail("trace_id 不能为空", "empty trace_id")
    try:
        payload = _parse_payload(summary_json)
    except json.JSONDecodeError as e:
        return _fail(f"summary_json 不是合法 JSON: {e}", str(e))
    except ValueError as e:
        return _fail(str(e), str(e))
    payload = _normalize_payload(trace_id=tid, payload=payload)
    try:
        path = _write_process_trace(trace_id=tid, payload=payload)
    except OSError as e:
        msg = f"写入 process_trace.json 失败: {e}"
        logger.error(msg, exc_info=True)
        return _fail(msg, str(e))
    out = ToolResult(
        title="过程摘要",
        output=f"已写入 {path}",
        metadata={
            "ok": True,
            "trace_id": tid,
            "path": str(path),
            "log_path": (log_path or "").strip(),
            "log_text_len": len((log_text or "").strip()),
        },
    )
    # Success logs only the trace_id (log_path etc. are already in metadata).
    log_tool_call(_LOG_LABEL, {"trace_id": tid}, format_tool_result_for_log(out))
    return out