content_tree.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. """
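Content tree API tools.
Wraps the content tree search endpoints:
1. search_content_tree - keyword search over categories and elements
2. get_category_tree - fetch the full path and subtree of a given category
3. get_category_elements - list the elements that belong to a given category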
  6. """
import logging
import os
from typing import Optional

import httpx

from agent.tools import tool
from agent.tools.models import ToolResult
  12. logger = logging.getLogger(__name__)
  13. BASE_URL = "http://8.147.104.190:8001"
  14. @tool(description="在内容树中搜索分类(category)和元素(element),支持获取祖先路径和子孙节点")
  15. async def search_content_tree(
  16. q: str,
  17. source_type: str,
  18. entity_type: str = "all",
  19. top_k: int = 20,
  20. use_description: bool = False,
  21. include_ancestors: bool = False,
  22. descendant_depth: int = 0,
  23. ) -> ToolResult:
  24. """
  25. 关键词搜索内容树中的分类和元素。
  26. Args:
  27. q: 搜索关键词
  28. source_type: 维度,必须是 "实质" / "形式" / "意图" 之一
  29. entity_type: 搜索对象类型,"category" / "element" / "all"(默认)
  30. top_k: 返回结果数量,1-100(默认20)
  31. use_description: 是否同时搜索描述字段(默认仅搜索名称)
  32. include_ancestors: 是否返回祖先路径
  33. descendant_depth: 返回子孙节点深度,0=不返回,1=直接子节点,2=子+孙...
  34. """
  35. params = {
  36. "q": q,
  37. "source_type": source_type,
  38. "entity_type": entity_type,
  39. "top_k": top_k,
  40. "use_description": str(use_description).lower(),
  41. "include_ancestors": str(include_ancestors).lower(),
  42. "descendant_depth": descendant_depth,
  43. }
  44. try:
  45. async with httpx.AsyncClient(timeout=30.0) as client:
  46. resp = await client.get(f"{BASE_URL}/api/agent/search", params=params)
  47. resp.raise_for_status()
  48. data = resp.json()
  49. count = data.get("count", 0)
  50. results = data.get("results", [])
  51. # 格式化输出
  52. lines = [f"搜索「{q}」({source_type}维度)共找到 {count} 条结果:\n"]
  53. for r in results:
  54. etype = r.get("entity_type", "")
  55. name = r.get("name", "")
  56. score = r.get("score", 0)
  57. if etype == "category":
  58. sid = r.get("entity_id", "")
  59. path = r.get("path", "")
  60. desc = r.get("description", "")
  61. lines.append(f"[分类] entity_id={sid} | {path} | score={score:.2f}")
  62. if desc:
  63. lines.append(f" 描述: {desc}")
  64. ancestors = r.get("ancestors", [])
  65. if ancestors:
  66. anc_names = " > ".join(a["name"] for a in ancestors)
  67. lines.append(f" 祖先: {anc_names}")
  68. descendants = r.get("descendants", [])
  69. if descendants:
  70. desc_names = ", ".join(d["name"] for d in descendants[:10])
  71. lines.append(f" 子孙({len(descendants)}): {desc_names}")
  72. else:
  73. eid = r.get("entity_id", "")
  74. belong = r.get("belong_category_entity_id", "")
  75. occ = r.get("occurrence_count", 0)
  76. lines.append(f"[元素] entity_id={eid} | {name} | belong_category={belong} | 出现次数={occ} | score={score:.2f}")
  77. edesc = r.get("description", "")
  78. if edesc:
  79. lines.append(f" 描述: {edesc}")
  80. lines.append("")
  81. return ToolResult(
  82. title=f"内容树搜索: {q} ({source_type}) → {count} 条",
  83. output="\n".join(lines),
  84. )
  85. except httpx.HTTPError as e:
  86. return ToolResult(title="内容树搜索失败", output=f"HTTP 错误: {e}")
  87. except Exception as e:
  88. logger.exception("search_content_tree error")
  89. return ToolResult(title="内容树搜索失败", output=f"错误: {e}")
  90. @tool(description="获取指定分类节点的完整路径、祖先和子孙结构(通过 entity_id 精确查询)")
  91. async def get_category_tree(
  92. entity_id: int,
  93. source_type: str,
  94. include_ancestors: bool = True,
  95. descendant_depth: int = -1,
  96. ) -> ToolResult:
  97. """
  98. 获取指定分类的完整路径和子树结构。
  99. Args:
  100. entity_id: 分类的 entity_id
  101. source_type: 维度,"实质" / "形式" / "意图"
  102. include_ancestors: 是否返回祖先路径(默认 True)
  103. descendant_depth: 子孙深度,-1=全部,0=仅当前,1=子节点,2=子+孙...
  104. """
  105. params = {
  106. "source_type": source_type,
  107. "include_ancestors": str(include_ancestors).lower(),
  108. "descendant_depth": descendant_depth,
  109. }
  110. try:
  111. async with httpx.AsyncClient(timeout=30.0) as client:
  112. resp = await client.get(f"{BASE_URL}/api/agent/search/category/{entity_id}", params=params)
  113. resp.raise_for_status()
  114. data = resp.json()
  115. current = data.get("current", {})
  116. ancestors = data.get("ancestors", [])
  117. descendants = data.get("descendants", [])
  118. lines = []
  119. lines.append(f"分类节点: {current.get('name', '')} (entity_id={entity_id})")
  120. lines.append(f"路径: {current.get('path', '')}")
  121. if current.get("description"):
  122. lines.append(f"描述: {current['description']}")
  123. lines.append("")
  124. if ancestors:
  125. lines.append("祖先路径:")
  126. for a in ancestors:
  127. lines.append(f" L{a.get('level', '?')} {a.get('name', '')} (entity_id={a.get('entity_id', '')})")
  128. lines.append("")
  129. if descendants:
  130. lines.append(f"子孙节点 ({len(descendants)} 个):")
  131. for d in descendants:
  132. indent = " " * d.get("depth_from_parent", 1)
  133. leaf_mark = " [叶]" if d.get("is_leaf") else ""
  134. lines.append(f"{indent}L{d.get('level', '?')} {d.get('name', '')} (entity_id={d.get('entity_id', '')}){leaf_mark}")
  135. return ToolResult(
  136. title=f"分类树: {current.get('name', entity_id)} (entity_id={entity_id})",
  137. output="\n".join(lines),
  138. )
  139. except httpx.HTTPError as e:
  140. return ToolResult(title="获取分类树失败", output=f"HTTP 错误: {e}")
  141. except Exception as e:
  142. logger.exception("get_category_tree error")
  143. return ToolResult(title="获取分类树失败", output=f"错误: {e}")
  144. @tool(description="获取指定分类下的所有元素,支持分页、排序和筛选")
  145. async def get_category_elements(
  146. category_id: int,
  147. source_type: str,
  148. page_size: int = 50,
  149. sort_by: str = "occurrence_count",
  150. order: str = "desc",
  151. min_occurrence: Optional[int] = None,
  152. ) -> ToolResult:
  153. """
  154. 获取指定分类下的所有元素。
  155. Args:
  156. category_id: 分类的 entity_id
  157. source_type: 维度,"实质" / "形式" / "意图"
  158. page_size: 每页数量,1-200(默认50)
  159. sort_by: 排序字段,"name" / "id" / "occurrence_count"(默认)
  160. order: 排序方向,"asc" / "desc"(默认)
  161. min_occurrence: 最小出现次数,用于过滤低频元素(可选)
  162. """
  163. params = {
  164. "source_type": source_type,
  165. "category_entity_id": category_id,
  166. "page_size": page_size,
  167. "sort_by": sort_by,
  168. "order": order,
  169. }
  170. if min_occurrence is not None:
  171. params["min_occurrence"] = min_occurrence
  172. try:
  173. async with httpx.AsyncClient(timeout=30.0) as client:
  174. resp = await client.get(f"{BASE_URL}/api/agent/search/elements", params=params)
  175. resp.raise_for_status()
  176. data = resp.json()
  177. total = data.get("total", 0)
  178. results = data.get("results", [])
  179. lines = [f"分类 {category_id} 下的元素({source_type}维度)共 {total} 个,返回 {len(results)} 个:\n"]
  180. for r in results:
  181. eid = r.get("id", "")
  182. name = r.get("name", "")
  183. occ = r.get("occurrence_count", 0)
  184. desc = r.get("description", "")
  185. category = r.get("category", {})
  186. cat_path = category.get("path", "")
  187. lines.append(f"[元素] entity_id={eid} | {name} | 出现次数={occ}")
  188. if desc:
  189. lines.append(f" 描述: {desc}")
  190. if cat_path:
  191. lines.append(f" 所属分类: {cat_path}")
  192. lines.append("")
  193. return ToolResult(
  194. title=f"分类元素: category_id={category_id} → {total} 个",
  195. output="\n".join(lines),
  196. )
  197. except httpx.HTTPError as e:
  198. return ToolResult(title="获取分类元素失败", output=f"HTTP 错误: {e}")
  199. except Exception as e:
  200. logger.exception("get_category_elements error")
  201. return ToolResult(title="获取分类元素失败", output=f"错误: {e}")