# topic_summary.py
#!/usr/bin/env python3
"""
Account persona summary:
1. Read the persona-tree JSON files under input/{account_name}/处理后数据/tree and merge them.
2. Fill the merged JSON into the {topic_point_tree} placeholder of topic_summary_prompt.md.
3. Call an LLM to generate the account persona summary and write it to
   input/{account_name}/处理后数据/persona_data/persona_summary.json.
"""
import asyncio
import json
import logging
import re
import sys
from pathlib import Path
from typing import Any, Dict
  14. logger = logging.getLogger(__name__)
  15. # 确保可以导入 agent 内的 LLM 调用封装(本文件在 data_process 下,多一层目录)
  16. _project_root = Path(__file__).resolve().parent.parent.parent.parent
  17. if str(_project_root) not in sys.path:
  18. sys.path.insert(0, str(_project_root))
  19. try:
  20. from agent.llm.openrouter import openrouter_llm_call
  21. except ImportError: # pragma: no cover - 仅用于本地缺少依赖时的降级提示
  22. openrouter_llm_call = None # type: ignore[assignment]
  23. # 复用与 search_and_eval 相同的模型,保证行为一致
  24. EVAL_LLM_MODEL = "google/gemini-3.1-pro-preview"
  25. # 脚本与 topic_summary_prompt.md 在 data_process;数据在 overall_derivation/input
  26. BASE_DIR = Path(__file__).resolve().parent
  27. OVERALL_DERIVATION_DIR = BASE_DIR.parent
  28. INPUT_BASE = OVERALL_DERIVATION_DIR / "input"
  29. # 人设树中不送入 LLM 的字段(递归删除)
  30. _TREE_STRIP_KEYS = frozenset(
  31. {
  32. "_post_ids",
  33. "_child_categories_relation",
  34. "_child_categories_relation_detail",
  35. }
  36. )
  37. def _strip_tree_fields(obj: Any) -> Any:
  38. """递归从树结构中移除 _TREE_STRIP_KEYS 中的键。"""
  39. if isinstance(obj, dict):
  40. return {
  41. k: _strip_tree_fields(v)
  42. for k, v in obj.items()
  43. if k not in _TREE_STRIP_KEYS
  44. }
  45. if isinstance(obj, list):
  46. return [_strip_tree_fields(x) for x in obj]
  47. return obj
  48. def _extract_json_object(content: str) -> Dict[str, Any]:
  49. """
  50. 从 LLM 回复中解析第一个 JSON 对象(允许被 ```json ... ``` 包裹)。
  51. 逻辑参考 tools/search_and_eval.py 中的实现。
  52. """
  53. content = content.strip()
  54. # 处理 ```json ... ``` 包裹的情况
  55. import re
  56. m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", content)
  57. if m:
  58. content = m.group(1).strip()
  59. # 截取最外层 { ... }
  60. start = content.find("{")
  61. end = content.rfind("}")
  62. if start != -1 and end != -1:
  63. content = content[start : end + 1]
  64. return json.loads(content)
  65. def _load_topic_point_tree(account_name: str) -> Dict[str, Any]:
  66. """
  67. 读取 input/{account_name}/处理后数据/tree 目录下的所有 JSON 文件,并合并成一个字典:
  68. {
  69. "<文件名去掉后缀>": <该文件对应的树 JSON>,
  70. ...
  71. }
  72. 每棵树加载后会去掉 _post_ids、_child_categories_relation、_child_categories_relation_detail。
  73. """
  74. tree_dir = INPUT_BASE / account_name / "处理后数据" / "tree"
  75. if not tree_dir.is_dir():
  76. raise FileNotFoundError(f"人设树目录不存在: {tree_dir}")
  77. merged: Dict[str, Any] = {}
  78. files = sorted(tree_dir.glob("*.json"))
  79. if not files:
  80. raise FileNotFoundError(f"人设树目录中未找到任何 JSON 文件: {tree_dir}")
  81. for path in files:
  82. with open(path, "r", encoding="utf-8") as f:
  83. try:
  84. data = json.load(f)
  85. except json.JSONDecodeError as e:
  86. raise ValueError(f"解析 JSON 文件失败: {path}") from e
  87. merged[path.stem] = _strip_tree_fields(data)
  88. logger.info("已加载人设树文件: %s", path.name)
  89. return merged
  90. def _load_prompt_template() -> str:
  91. """读取 topic_summary_prompt.md 模板。"""
  92. prompt_path = BASE_DIR / "topic_summary_prompt.md"
  93. if not prompt_path.is_file():
  94. raise FileNotFoundError(f"找不到 prompt 模板文件: {prompt_path}")
  95. with open(prompt_path, "r", encoding="utf-8") as f:
  96. return f.read()
  97. async def generate_topic_summary(account_name: str) -> Dict[str, Any]:
  98. """
  99. 生成账号人设总结,并返回解析后的 JSON 结果。
  100. 同时将结果写入 persona_summary.json 文件。
  101. """
  102. if openrouter_llm_call is None:
  103. raise RuntimeError("未找到 openrouter_llm_call,请检查 agent.llm 依赖是否可用。")
  104. # 1. 加载并合并人设树
  105. topic_tree = _load_topic_point_tree(account_name)
  106. topic_tree_str = json.dumps(topic_tree, ensure_ascii=False, indent=2)
  107. logger.info("已合并人设树,共包含 %d 个子树", len(topic_tree))
  108. # 2. 读取并填充 prompt 模板
  109. prompt_template = _load_prompt_template()
  110. system_prompt = prompt_template.replace("{topic_point_tree}", topic_tree_str)
  111. # 3. 调用 LLM 生成总结
  112. messages = [
  113. {"role": "system", "content": system_prompt},
  114. {
  115. "role": "user",
  116. "content": "请根据以上说明,严格按照 JSON 模板输出账号人设总结,仅输出 JSON,不要包含其他解释性文字。",
  117. },
  118. ]
  119. logger.info("开始调用 LLM 生成账号人设总结,account_name=%s", account_name)
  120. llm_result = await openrouter_llm_call(messages, model=EVAL_LLM_MODEL)
  121. content = llm_result.get("content", "") if isinstance(llm_result, dict) else ""
  122. if not content:
  123. raise RuntimeError("LLM 未返回任何内容")
  124. try:
  125. summary_data = _extract_json_object(content)
  126. except Exception as e: # noqa: BLE001
  127. logger.exception("解析 LLM 返回的 JSON 失败")
  128. raise RuntimeError(f"解析 LLM 返回内容失败: {e}") from e
  129. # 4. 写入 persona_summary.json
  130. persona_dir = INPUT_BASE / account_name / "处理后数据" / "persona_data"
  131. persona_dir.mkdir(parents=True, exist_ok=True)
  132. persona_file = persona_dir / "persona_summary.json"
  133. with open(persona_file, "w", encoding="utf-8") as f:
  134. json.dump(summary_data, f, ensure_ascii=False, indent=2)
  135. logger.info("已写入账号人设总结到文件: %s", persona_file)
  136. return summary_data
  137. def main(account_name) -> None:
  138. # parser = argparse.ArgumentParser(description="根据人设树生成账号人设总结")
  139. # parser.add_argument("account_name", help="账号名称(对应 input/{account_name} 目录)")
  140. # args = parser.parse_args(argv)
  141. logging.basicConfig(
  142. level=logging.INFO,
  143. format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
  144. datefmt="%H:%M:%S",
  145. )
  146. logger.info("生成账号人设总结,account_name=%s", account_name)
  147. async def _run() -> None:
  148. summary = await generate_topic_summary(account_name)
  149. print(json.dumps(summary, ensure_ascii=False, indent=2))
  150. asyncio.run(_run())
  151. if __name__ == "__main__":
  152. main(account_name="空间点阵设计研究室")