#!/usr/bin/env python3
"""
Account persona summary:
1. Read the persona-tree JSON files under input/{account_name}/tree and merge them.
2. Fill the merged JSON into the {topic_point_tree} slot of topic_summary_prompt.md.
3. Call the LLM to generate the account persona summary and write it to
   input/{account_name}/persona_data/persona_summary.json.
"""
import argparse
import asyncio
import json
import logging
import sys
from pathlib import Path
from typing import Any, Dict

logger = logging.getLogger(__name__)

# Make sure the agent package's LLM call wrapper is importable
# (the project root is three levels above this file).
_project_root = Path(__file__).resolve().parent.parent.parent
if str(_project_root) not in sys.path:
    sys.path.insert(0, str(_project_root))

try:
    from agent.llm.openrouter import openrouter_llm_call
except ImportError:  # pragma: no cover - graceful degradation when deps are missing locally
    openrouter_llm_call = None  # type: ignore[assignment]

# Reuse the same model as search_and_eval to keep behavior consistent
EVAL_LLM_MODEL = "google/gemini-3.1-pro-preview"

BASE_DIR = Path(__file__).resolve().parent
INPUT_BASE = BASE_DIR / "input"
  28. def _extract_json_object(content: str) -> Dict[str, Any]:
  29. """
  30. 从 LLM 回复中解析第一个 JSON 对象(允许被 ```json ... ``` 包裹)。
  31. 逻辑参考 tools/search_and_eval.py 中的实现。
  32. """
  33. content = content.strip()
  34. # 处理 ```json ... ``` 包裹的情况
  35. import re
  36. m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", content)
  37. if m:
  38. content = m.group(1).strip()
  39. # 截取最外层 { ... }
  40. start = content.find("{")
  41. end = content.rfind("}")
  42. if start != -1 and end != -1:
  43. content = content[start : end + 1]
  44. return json.loads(content)
  45. def _load_topic_point_tree(account_name: str) -> Dict[str, Any]:
  46. """
  47. 读取 input/{account_name}/tree 目录下的所有 JSON 文件,并合并成一个字典:
  48. {
  49. "<文件名去掉后缀>": <该文件对应的树 JSON>,
  50. ...
  51. }
  52. """
  53. tree_dir = INPUT_BASE / account_name / "tree"
  54. if not tree_dir.is_dir():
  55. raise FileNotFoundError(f"人设树目录不存在: {tree_dir}")
  56. merged: Dict[str, Any] = {}
  57. files = sorted(tree_dir.glob("*.json"))
  58. if not files:
  59. raise FileNotFoundError(f"人设树目录中未找到任何 JSON 文件: {tree_dir}")
  60. for path in files:
  61. with open(path, "r", encoding="utf-8") as f:
  62. try:
  63. data = json.load(f)
  64. except json.JSONDecodeError as e:
  65. raise ValueError(f"解析 JSON 文件失败: {path}") from e
  66. merged[path.stem] = data
  67. logger.info("已加载人设树文件: %s", path.name)
  68. return merged
  69. def _load_prompt_template() -> str:
  70. """读取 topic_summary_prompt.md 模板。"""
  71. prompt_path = BASE_DIR / "topic_summary_prompt.md"
  72. if not prompt_path.is_file():
  73. raise FileNotFoundError(f"找不到 prompt 模板文件: {prompt_path}")
  74. with open(prompt_path, "r", encoding="utf-8") as f:
  75. return f.read()
  76. async def generate_topic_summary(account_name: str) -> Dict[str, Any]:
  77. """
  78. 生成账号人设总结,并返回解析后的 JSON 结果。
  79. 同时将结果写入 persona_summary.json 文件。
  80. """
  81. if openrouter_llm_call is None:
  82. raise RuntimeError("未找到 openrouter_llm_call,请检查 agent.llm 依赖是否可用。")
  83. # 1. 加载并合并人设树
  84. topic_tree = _load_topic_point_tree(account_name)
  85. topic_tree_str = json.dumps(topic_tree, ensure_ascii=False, indent=2)
  86. logger.info("已合并人设树,共包含 %d 个子树", len(topic_tree))
  87. # 2. 读取并填充 prompt 模板
  88. prompt_template = _load_prompt_template()
  89. system_prompt = prompt_template.replace("{topic_point_tree}", topic_tree_str)
  90. # 3. 调用 LLM 生成总结
  91. messages = [
  92. {"role": "system", "content": system_prompt},
  93. {
  94. "role": "user",
  95. "content": "请根据以上说明,严格按照 JSON 模板输出账号人设总结,仅输出 JSON,不要包含其他解释性文字。",
  96. },
  97. ]
  98. logger.info("开始调用 LLM 生成账号人设总结,account_name=%s", account_name)
  99. llm_result = await openrouter_llm_call(messages, model=EVAL_LLM_MODEL)
  100. content = llm_result.get("content", "") if isinstance(llm_result, dict) else ""
  101. if not content:
  102. raise RuntimeError("LLM 未返回任何内容")
  103. try:
  104. summary_data = _extract_json_object(content)
  105. except Exception as e: # noqa: BLE001
  106. logger.exception("解析 LLM 返回的 JSON 失败")
  107. raise RuntimeError(f"解析 LLM 返回内容失败: {e}") from e
  108. # 4. 写入 persona_summary.json
  109. persona_dir = INPUT_BASE / account_name / "persona_data"
  110. persona_dir.mkdir(parents=True, exist_ok=True)
  111. persona_file = persona_dir / "persona_summary.json"
  112. with open(persona_file, "w", encoding="utf-8") as f:
  113. json.dump(summary_data, f, ensure_ascii=False, indent=2)
  114. logger.info("已写入账号人设总结到文件: %s", persona_file)
  115. return summary_data
  116. def main(account_name) -> None:
  117. # parser = argparse.ArgumentParser(description="根据人设树生成账号人设总结")
  118. # parser.add_argument("account_name", help="账号名称(对应 input/{account_name} 目录)")
  119. # args = parser.parse_args(argv)
  120. logging.basicConfig(
  121. level=logging.INFO,
  122. format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
  123. datefmt="%H:%M:%S",
  124. )
  125. logger.info("生成账号人设总结,account_name=%s", account_name)
  126. async def _run() -> None:
  127. summary = await generate_topic_summary(account_name)
  128. print(json.dumps(summary, ensure_ascii=False, indent=2))
  129. asyncio.run(_run())
  130. if __name__ == "__main__":
  131. main(account_name="家有大志")