search_person_tree.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. """
  2. 人设树常量点检索工具 - 根据人设名称提取所有常量点
  3. 用于 Agent 执行时自主提取人设树中的常量点。
  4. """
  5. import json
  6. from pathlib import Path
  7. from typing import Any, Dict, List, Optional
  8. import os
  9. from agent.tools import tool, ToolResult
  10. # 人设数据基础路径
  11. PERSONA_DATA_BASE_PATH = Path(__file__).parent.parent / "data"
  12. GRAPH_FULL_DATA_PATH = os.getenv(
  13. "GRAPH_FULL_DATA_PATH",
  14. # str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_all_levels.json")
  15. str(Path(__file__).parent.parent / "data/家有大志/topic_point_data/point_classification_results_trans.json")
  16. )
  17. # 缓存图数据,避免重复加载
  18. _graph_full_cache: Optional[Dict[str, Any]] = None
  19. def _load_graph_full() -> Dict[str, Any]:
  20. """加载完整图数据(带缓存,包含 edges)"""
  21. global _graph_full_cache
  22. if _graph_full_cache is None:
  23. with open(GRAPH_FULL_DATA_PATH, 'r', encoding='utf-8') as f:
  24. _graph_full_cache = json.load(f)
  25. return _graph_full_cache
  26. def _extract_constant_points(tree_data: Dict[str, Any]) -> List[str]:
  27. """
  28. 递归提取树中所有常量点的名称
  29. Args:
  30. tree_data: 树节点数据
  31. Returns:
  32. 常量点名称列表
  33. """
  34. constant_points = []
  35. for node_name, node_data in tree_data.items():
  36. # 跳过元数据字段
  37. if node_name.startswith("_"):
  38. continue
  39. # 检查是否为常量点:_type 为 "ID" 且 _is_constant 为 true
  40. if isinstance(node_data, dict):
  41. node_type = node_data.get("_type")
  42. is_constant = node_data.get("_is_constant", False)
  43. if node_type == "ID" and is_constant:
  44. constant_points.append(node_name)
  45. # 递归处理子节点
  46. if "children" in node_data:
  47. child_points = _extract_constant_points(node_data["children"])
  48. constant_points.extend(child_points)
  49. return constant_points
  50. def _load_persona_tree(persona_name: str) -> Dict[str, Any]:
  51. """
  52. 加载人设树数据并提取所有常量点
  53. Args:
  54. persona_name: 人设名称,如 "家有大志"
  55. Returns:
  56. 包含所有常量点的列表
  57. """
  58. persona_dir = PERSONA_DATA_BASE_PATH / persona_name / "tree"
  59. if not persona_dir.exists():
  60. return {
  61. "found": False,
  62. "persona_name": persona_name,
  63. "error": f"人设目录不存在: {persona_dir}"
  64. }
  65. tree_files = {
  66. "形式": persona_dir / "形式_point_tree_how.json",
  67. "实质": persona_dir / "实质_point_tree_how.json",
  68. "意图": persona_dir / "意图_point_tree_how.json"
  69. }
  70. all_constant_points = []
  71. missing_files = []
  72. for dimension, tree_file in tree_files.items():
  73. if not tree_file.exists():
  74. missing_files.append(f"{dimension}_point_tree_how.json")
  75. continue
  76. try:
  77. with open(tree_file, 'r', encoding='utf-8') as f:
  78. tree_data = json.load(f)
  79. # 提取该维度的所有常量点
  80. constant_points = _extract_constant_points(tree_data)
  81. all_constant_points.extend(constant_points)
  82. except Exception as e:
  83. return {
  84. "found": False,
  85. "persona_name": persona_name,
  86. "error": f"读取文件 {tree_file.name} 失败: {str(e)}"
  87. }
  88. dict = _load_graph_full()
  89. data = []
  90. for point in all_constant_points:
  91. if point in dict:
  92. data.append(dict.get(point))
  93. return data
  94. @tool(
  95. description="根据人设名称检索该人设树中的所有常量点(_is_constant=true的点)。",
  96. display={
  97. "zh": {
  98. "name": "人设常量点检索",
  99. "params": {
  100. "persona_name": "人设名称(如:家有大志)",
  101. },
  102. },
  103. },
  104. )
  105. async def search_person_tree_constants(persona_name: str) -> ToolResult:
  106. """
  107. 根据人设名称检索该人设树中的所有常量点。
  108. 常量点是指人设树中 _type 为 "ID" 且 _is_constant 为 true 的节点。
  109. 这些点代表了该人设的核心特征和固定属性。
  110. Args:
  111. persona_name: 人设名称,如 "家有大志"
  112. Returns:
  113. ToolResult: 包含三个维度(形式、实质、意图)的常量点列表
  114. """
  115. if not persona_name:
  116. return ToolResult(
  117. title="人设常量点检索失败",
  118. output="",
  119. error="请提供人设名称",
  120. )
  121. try:
  122. constant_points = _load_persona_tree(persona_name)
  123. except Exception as e:
  124. return ToolResult(
  125. title="人设常量点检索失败",
  126. output="",
  127. error=f"检索异常: {str(e)}",
  128. )
  129. if not constant_points:
  130. return ToolResult(
  131. title="人设常量点检索失败",
  132. output="",
  133. error="未找到常量点数据",
  134. )
  135. output = json.dumps(constant_points, ensure_ascii=False, indent=2)
  136. return ToolResult(
  137. title=f"人设常量点检索 - {persona_name}",
  138. output=output,
  139. long_term_memory=f"检索到 {len(constant_points)} 个常量点",
  140. )
  141. if __name__ == "__main__":
  142. print(_load_persona_tree("家有大志"))