tools.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. """
  2. 内容工具族 —— 统一入口
  3. 4 个 @tool 注册给 LLM:
  4. - content_platforms: 列出/查询平台及其参数
  5. - content_search: 跨平台搜索
  6. - content_detail: 查看详情
  7. - content_suggest: 搜索建议词
  8. 所有平台的具体实现在 platforms/ 子目录,按模块自注册到 registry。
  9. """
  10. import json
  11. import os
  12. import uuid
  13. from typing import Any, Dict, Optional
  14. from agent.tools import tool, ToolResult, ToolContext
  15. from agent.tools.builtin.content.registry import (
  16. all_platforms, get_platform, match_platforms,
  17. )
  18. from agent.tools.builtin.content import cache as _cache
  19. # 导入平台模块以触发自注册(副作用导入)
  20. import agent.tools.builtin.content.platforms.aigc_channel # noqa: F401
  21. import agent.tools.builtin.content.platforms.youtube # noqa: F401
  22. import agent.tools.builtin.content.platforms.x # noqa: F401
  23. def _get_trace_id(context: Optional[ToolContext]) -> str:
  24. """从 context 取 trace_id,回退到环境变量或自动生成"""
  25. if context and hasattr(context, "trace_id") and context.trace_id:
  26. return context.trace_id
  27. return os.getenv("TRACE_ID") or f"anon-{uuid.uuid4().hex[:8]}"
  28. # ── content_platforms ──
  29. @tool(hidden_params=["context"], groups=["content"])
  30. async def content_platforms(
  31. platform: str = "",
  32. context: Optional[ToolContext] = None,
  33. ) -> ToolResult:
  34. """
  35. 列出支持的内容平台及其搜索参数。
  36. 不传 platform 时返回所有平台的概要列表(仅名称和 ID)。
  37. 传入 platform 时模糊匹配并返回匹配平台的详细参数说明(支持 ID、中文名、别名)。
  38. 建议在不熟悉平台参数时先调用此工具查看,再构造 content_search / content_detail 的参数。
  39. Args:
  40. platform: 可选,平台名称或关键词。支持模糊匹配(如 "xhs"、"小红书"、"youtube")。
  41. 留空返回全部平台概要。
  42. context: 工具上下文(自动注入)
  43. """
  44. hits = match_platforms(platform)
  45. if not hits:
  46. all_ids = [p.id for p in all_platforms()]
  47. return ToolResult(
  48. title="未找到匹配平台",
  49. output=f"没有匹配 '{platform}' 的平台。可用平台: {', '.join(all_ids)}",
  50. )
  51. if platform:
  52. # 有 query:返回匹配平台的详细参数
  53. result = [p.detail() for p in hits]
  54. else:
  55. # 无 query:返回概要列表
  56. result = [p.summary() for p in hits]
  57. return ToolResult(
  58. title=f"内容平台" + (f" ({platform})" if platform else ""),
  59. output=json.dumps(result, ensure_ascii=False, indent=2),
  60. )
  61. # ── content_search ──
  62. @tool(hidden_params=["context"], groups=["content"])
  63. async def content_search(
  64. platform: str,
  65. keyword: str,
  66. max_count: int = 20,
  67. cursor: str = "",
  68. extras: Optional[Dict[str, Any]] = None,
  69. context: Optional[ToolContext] = None,
  70. ) -> ToolResult:
  71. """
  72. 跨平台内容搜索,返回带索引编号的封面拼图 + 概览列表。
  73. 返回的是摘要信息(标题 + 正文截断 + 互动数据),不含完整正文和所有图片。
  74. 如需查看某条内容的完整信息,请使用 content_detail。
  75. Args:
  76. platform: 平台标识,如 'xhs'、'youtube'、'x'。完整列表见 content_platforms。
  77. keyword: 搜索关键词。
  78. max_count: 返回条数上限,默认 20。
  79. cursor: 分页游标,首次搜索留空,翻页时传入上次返回值。
  80. extras: 平台专用参数(dict)。不同平台支持不同参数,
  81. 如 xhs 支持 sort_type / publish_time / content_type / filter_note_range。
  82. 不清楚可先调 content_platforms(platform) 查看。
  83. context: 工具上下文(自动注入)
  84. """
  85. pdef = get_platform(platform)
  86. if not pdef:
  87. # 尝试模糊匹配
  88. hits = match_platforms(platform)
  89. if hits:
  90. suggestions = ", ".join(f"{p.id}({p.name})" for p in hits[:3])
  91. return ToolResult(title="平台不存在", output=f"未找到平台 '{platform}'。你是否想要: {suggestions}")
  92. all_ids = [p.id for p in all_platforms()]
  93. return ToolResult(title="平台不存在", output=f"未找到平台 '{platform}'。可用: {', '.join(all_ids)}")
  94. if not pdef.search_impl:
  95. return ToolResult(title="不支持搜索", output=f"平台 {pdef.name} 暂不支持搜索")
  96. result = await pdef.search_impl(
  97. platform_id=pdef.id,
  98. keyword=keyword,
  99. max_count=max_count,
  100. cursor=cursor,
  101. extras=extras,
  102. )
  103. # 持久化搜索结果到磁盘缓存
  104. if not result.error:
  105. posts = result.metadata.pop("posts", [])
  106. trace_id = _get_trace_id(context)
  107. _cache.save_search_results(trace_id, pdef.id, keyword, posts)
  108. return result
  109. # ── content_detail ──
  110. @tool(hidden_params=["context"], groups=["content"])
  111. async def content_detail(
  112. platform: str,
  113. index: int,
  114. extras: Optional[Dict[str, Any]] = None,
  115. context: Optional[ToolContext] = None,
  116. ) -> ToolResult:
  117. """
  118. 查看内容详情。从最近一次 content_search 的结果中按索引取完整记录。
  119. Args:
  120. platform: 平台标识(必须和之前 content_search 用的一致)。
  121. index: 内容序号(1-based),来自 content_search 返回的 index 字段。
  122. extras: 平台专用详情参数。YouTube 支持 include_captions / download_video。
  123. 其他平台通常不需要。
  124. context: 工具上下文(自动注入)
  125. """
  126. pdef = get_platform(platform)
  127. if not pdef:
  128. return ToolResult(title="平台不存在", output=f"未找到平台 '{platform}'")
  129. trace_id = _get_trace_id(context)
  130. post = _cache.get_cached_post(trace_id, pdef.id, index)
  131. if not post:
  132. info = _cache.get_cached_search_info(trace_id, pdef.id)
  133. if info:
  134. return ToolResult(
  135. title="索引无效",
  136. output=f"平台 {pdef.name} 上次搜索 '{info['keyword']}' 共 {info['total']} 条,"
  137. f"有效索引 1-{info['total']},你传入了 {index}。",
  138. error="Invalid index",
  139. )
  140. return ToolResult(
  141. title="缓存未命中",
  142. output=f"没有 {pdef.name} 的搜索缓存。请先调用 content_search(platform='{pdef.id}', keyword=...) 搜索。",
  143. error="No cache",
  144. )
  145. if pdef.detail_impl:
  146. return await pdef.detail_impl(post, extras)
  147. # fallback:直接返回缓存的完整数据
  148. return ToolResult(
  149. title=f"详情 #{index}",
  150. output=json.dumps(post, ensure_ascii=False, indent=2),
  151. )
  152. # ── content_suggest ──
  153. @tool(hidden_params=["context"], groups=["content"])
  154. async def content_suggest(
  155. platform: str,
  156. keyword: str,
  157. context: Optional[ToolContext] = None,
  158. ) -> ToolResult:
  159. """
  160. 获取搜索关键词补全建议。
  161. 仅部分平台支持(xhs、toutiao、douyin、bili、zhihu)。
  162. 用于辅助用户发现更精准的搜索词。
  163. Args:
  164. platform: 平台标识。
  165. keyword: 搜索关键词(输入中的部分词即可)。
  166. context: 工具上下文(自动注入)
  167. """
  168. pdef = get_platform(platform)
  169. if not pdef:
  170. return ToolResult(title="平台不存在", output=f"未找到平台 '{platform}'")
  171. if not pdef.suggest_impl:
  172. supported = [p.id for p in all_platforms() if p.supports_suggest]
  173. return ToolResult(
  174. title="不支持建议词",
  175. output=f"平台 {pdef.name} 不支持建议词。支持的平台: {', '.join(supported)}",
  176. )
  177. channel = (pdef.suggest_channels or [pdef.id])[0]
  178. return await pdef.suggest_impl(channel, keyword)
  179. # ── CLI 入口 ──
  180. def _parse_args(argv: list) -> dict:
  181. """解析 --key=value 格式的 CLI 参数"""
  182. kwargs = {}
  183. for arg in argv:
  184. if arg.startswith("--") and "=" in arg:
  185. key, val = arg[2:].split("=", 1)
  186. # 尝试 JSON 解析(dict / int / bool)
  187. try:
  188. val = json.loads(val)
  189. except (json.JSONDecodeError, ValueError):
  190. pass
  191. kwargs[key] = val
  192. return kwargs
  193. if __name__ == "__main__":
  194. import sys
  195. import asyncio
  196. COMMANDS = {
  197. "platforms": content_platforms,
  198. "search": content_search,
  199. "detail": content_detail,
  200. "suggest": content_suggest,
  201. }
  202. if len(sys.argv) < 2 or sys.argv[1] not in COMMANDS:
  203. print(f"Usage: python {sys.argv[0]} <{'|'.join(COMMANDS)}> [--key=value ...]")
  204. sys.exit(1)
  205. cmd = sys.argv[1]
  206. kwargs = _parse_args(sys.argv[2:])
  207. # trace_id:CLI 参数 > 环境变量 > 自动生成
  208. trace_id = kwargs.pop("trace_id", None) or os.getenv("TRACE_ID") or f"cli-{uuid.uuid4().hex[:8]}"
  209. os.environ["TRACE_ID"] = trace_id
  210. result = asyncio.run(COMMANDS[cmd](**kwargs))
  211. # 输出 JSON(与 toolhub CLI 格式一致)
  212. out = {"trace_id": trace_id, "output": result.output, "error": result.error}
  213. if result.metadata:
  214. out["metadata"] = result.metadata
  215. print(json.dumps(out, ensure_ascii=False, indent=2))