search.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. """
  2. 搜索工具模块
  3. 提供帖子搜索、帖子详情查看和建议词搜索功能,支持多个渠道平台。
  4. 主要功能:
  5. 1. search_posts - 帖子搜索(浏览模式:封面图+标题+内容截断)
  6. 2. select_post - 帖子详情(从搜索结果中选取单个帖子的完整内容)
  7. 3. get_search_suggestions - 获取平台的搜索补全建议词
  8. """
  9. import json
  10. from enum import Enum
  11. from typing import Any, Dict, List, Optional
  12. import httpx
  13. from agent.tools import tool, ToolResult
  14. from agent.tools.utils.image import build_image_grid, encode_base64, load_images
  15. # API 基础配置
  16. BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
  17. DEFAULT_TIMEOUT = 60.0
  18. # 搜索结果缓存,以序号为 key
  19. _search_cache: Dict[int, Dict[str, Any]] = {}
  20. async def _build_collage(posts: List[Dict[str, Any]]) -> Optional[str]:
  21. """
  22. 将帖子封面图+序号+标题拼接成网格图,返回 base64 编码的 PNG。
  23. 复用 agent.tools.utils.image 中的共享拼图逻辑。
  24. """
  25. if not posts:
  26. return None
  27. # 收集有封面图的帖子
  28. urls: List[str] = []
  29. titles: List[str] = []
  30. for post in posts:
  31. imgs = post.get("images", [])
  32. cover_url = imgs[0] if imgs else None
  33. if cover_url:
  34. urls.append(cover_url)
  35. titles.append(post.get("title", "") or "")
  36. if not urls:
  37. return None
  38. # 并发加载图片
  39. loaded = await load_images(urls)
  40. # 过滤加载失败的(保持 url 和 title 对齐)
  41. valid_images = []
  42. valid_labels = []
  43. for (_, img), title in zip(loaded, titles):
  44. if img is not None:
  45. valid_images.append(img)
  46. valid_labels.append(title)
  47. if not valid_images:
  48. return None
  49. grid = build_image_grid(images=valid_images, labels=valid_labels)
  50. b64, _ = encode_base64(grid, format="PNG")
  51. return b64
  52. class PostSearchChannel(str, Enum):
  53. """
  54. 帖子搜索支持的渠道类型
  55. """
  56. XHS = "xhs" # 小红书
  57. GZH = "gzh" # 公众号
  58. SPH = "sph" # 视频号
  59. GITHUB = "github" # GitHub
  60. TOUTIAO = "toutiao" # 头条
  61. DOUYIN = "douyin" # 抖音
  62. BILI = "bili" # B站
  63. ZHIHU = "zhihu" # 知乎
  64. WEIBO = "weibo" # 微博
  65. class SuggestSearchChannel(str, Enum):
  66. """
  67. 建议词搜索支持的渠道类型
  68. """
  69. XHS = "xhs" # 小红书
  70. WX = "wx" # 微信
  71. GITHUB = "github" # GitHub
  72. TOUTIAO = "toutiao" # 头条
  73. DOUYIN = "douyin" # 抖音
  74. BILI = "bili" # B站
  75. ZHIHU = "zhihu" # 知乎
  76. @tool(
  77. display={
  78. "zh": {
  79. "name": "帖子搜索",
  80. "params": {
  81. "keyword": "搜索关键词",
  82. "channel": "搜索渠道(xhs=小红书, gzh=公众号, sph=视频号, github, toutiao=头条, douyin=抖音, bili=B站, zhihu=知乎, weibo=微博)",
  83. "cursor": "分页游标",
  84. "max_count": "返回条数",
  85. "content_type": "内容类型-视频/图文",
  86. "sort_type": "排序方式(xhs专用)",
  87. "publish_time": "发布时间筛选(xhs专用)",
  88. "filter_note_range": "笔记时长筛选(xhs专用)"
  89. }
  90. },
  91. "en": {
  92. "name": "Search Posts",
  93. "params": {
  94. "keyword": "Search keyword",
  95. "channel": "Search channel (xhs=XiaoHongShu, gzh=WeChat Official Account, sph=WeChat Channels, github, toutiao, douyin, bili, zhihu, weibo)",
  96. "cursor": "Pagination cursor",
  97. "max_count": "Max results",
  98. "content_type": "content type-视频/图文",
  99. "sort_type": "Sort type (xhs only)",
  100. "publish_time": "Publish time filter (xhs only)",
  101. "filter_note_range": "Note duration filter (xhs only)"
  102. }
  103. }
  104. }
  105. )
  106. async def search_posts(
  107. keyword: str,
  108. channel: str = "xhs",
  109. cursor: str = "",
  110. max_count: int = 20,
  111. content_type: str = "",
  112. sort_type: str = "综合排序",
  113. publish_time: str = "不限",
  114. filter_note_range: str = "不限"
  115. ) -> ToolResult:
  116. """
  117. 帖子搜索(浏览模式)
  118. 根据关键词在指定渠道平台搜索帖子,返回封面图+标题+内容摘要,用于快速浏览。
  119. 如需查看某个帖子的完整内容,请使用 select_post 工具。
  120. Args:
  121. keyword: 搜索关键词
  122. channel: 搜索渠道,支持的渠道有:
  123. - xhs: 小红书
  124. - gzh: 公众号
  125. - sph: 视频号
  126. - github: GitHub
  127. - toutiao: 头条
  128. - douyin: 抖音
  129. - bili: B站
  130. - zhihu: 知乎
  131. - weibo: 微博
  132. cursor: 分页游标,首次请求为空字符串,后续使用上次返回的 cursor
  133. max_count: 返回的最大条数,默认为 20
  134. content_type: 内容类型筛选,默认不限;
  135. xhs 可选值:'不限' | '图文' | '视频' | '文章';
  136. 其他渠道可选值:'视频' | '图文'
  137. sort_type: 排序方式(仅 xhs 有效),可选值:'综合排序' | '最新发布' | '最多点赞',默认'综合排序'
  138. publish_time: 发布时间筛选(仅 xhs 有效),可选值:'不限' | '近30天' | '近7天' | '近1天',默认'不限'
  139. filter_note_range: 笔记时长筛选,视频内容有效(仅 xhs 有效),可选值:'不限' | '1分钟以内' | '1-5分钟' | '5分钟以上',默认'不限'
  140. Returns:
  141. ToolResult 包含搜索结果摘要列表(封面图+标题+内容截断),
  142. 可通过 channel_content_id 调用 select_post 查看完整内容。
  143. """
  144. global _search_cache
  145. try:
  146. channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel
  147. url = f"{BASE_URL}/data"
  148. if channel_value == "xhs":
  149. payload = {
  150. "type": channel_value,
  151. "keyword": keyword,
  152. "cursor": cursor,
  153. "content_type": content_type if content_type else "不限",
  154. "sort_type": sort_type,
  155. "publish_time": publish_time,
  156. "filter_note_range": filter_note_range,
  157. }
  158. else:
  159. payload = {
  160. "type": channel_value,
  161. "keyword": keyword,
  162. "cursor": cursor if cursor else "0",
  163. "max_count": max_count,
  164. "content_type": content_type,
  165. }
  166. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  167. response = await client.post(
  168. url,
  169. json=payload,
  170. headers={"Content-Type": "application/json"},
  171. )
  172. response.raise_for_status()
  173. data = response.json()
  174. posts = data.get("data", [])
  175. # 缓存完整结果(以序号为 key)
  176. _search_cache.clear()
  177. for idx, post in enumerate(posts):
  178. _search_cache[idx + 1] = post
  179. # 构建摘要列表(带序号)
  180. summary_list = []
  181. for idx, post in enumerate(posts):
  182. body = post.get("body_text", "") or ""
  183. title = post.get("title") or body[:20] or ""
  184. summary_list.append({
  185. "index": idx + 1,
  186. "channel_content_id": post.get("channel_content_id"),
  187. "title": title,
  188. "body_text": body[:100] + ("..." if len(body) > 100 else ""),
  189. "like_count": post.get("like_count"),
  190. "collect_count": post.get("collect_count"),
  191. "comment_count": post.get("comment_count"),
  192. "channel": post.get("channel"),
  193. "link": post.get("link"),
  194. "content_type": post.get("content_type"),
  195. "publish_timestamp": post.get("publish_timestamp"),
  196. })
  197. # 拼接封面图网格
  198. images = []
  199. try:
  200. collage_b64 = await _build_collage(posts)
  201. if collage_b64:
  202. images.append({
  203. "type": "base64",
  204. "media_type": "image/png",
  205. "data": collage_b64
  206. })
  207. except Exception as collage_error:
  208. # 图片拼接失败不影响主流程,记录错误但继续返回结果
  209. import logging
  210. logging.warning(f"Failed to build collage for {channel_value}: {collage_error}")
  211. output_data = {
  212. "code": data.get("code"),
  213. "message": data.get("message"),
  214. "data": summary_list
  215. }
  216. return ToolResult(
  217. title=f"搜索结果: {keyword} ({channel_value})",
  218. output=json.dumps(output_data, ensure_ascii=False, indent=2),
  219. long_term_memory=f"Searched '{keyword}' on {channel_value}, found {len(posts)} posts. Use select_post(index) to view full details of a specific post.",
  220. images=images
  221. )
  222. except httpx.HTTPStatusError as e:
  223. return ToolResult(
  224. title="搜索失败",
  225. output="",
  226. error=f"HTTP error {e.response.status_code}: {e.response.text}"
  227. )
  228. except Exception as e:
  229. return ToolResult(
  230. title="搜索失败",
  231. output="",
  232. error=str(e)
  233. )
  234. @tool(
  235. display={
  236. "zh": {
  237. "name": "帖子详情",
  238. "params": {
  239. "index": "帖子序号"
  240. }
  241. },
  242. "en": {
  243. "name": "Select Post",
  244. "params": {
  245. "index": "Post index"
  246. }
  247. }
  248. }
  249. )
  250. async def select_post(
  251. index: int,
  252. ) -> ToolResult:
  253. """
  254. 查看帖子详情
  255. 从最近一次 search_posts 的搜索结果中,根据序号选取指定帖子并返回完整内容(全部正文、全部图片、视频等)。
  256. 需要先调用 search_posts 进行搜索。
  257. Args:
  258. index: 帖子序号,来自 search_posts 返回结果中的 index 字段(从 1 开始)
  259. Returns:
  260. ToolResult 包含该帖子的完整信息和所有图片。
  261. """
  262. post = _search_cache.get(index)
  263. if not post:
  264. return ToolResult(
  265. title="未找到帖子",
  266. output="",
  267. error=f"未找到序号 {index} 的帖子,请先调用 search_posts 搜索。"
  268. )
  269. # 返回所有图片
  270. images = []
  271. for img_url in post.get("images", []):
  272. if img_url:
  273. images.append({
  274. "type": "url",
  275. "url": img_url
  276. })
  277. return ToolResult(
  278. title=f"帖子详情 #{index}: {post.get('title', '')}",
  279. output=json.dumps(post, ensure_ascii=False, indent=2),
  280. long_term_memory=f"Viewed post detail #{index}: {post.get('title', '')}",
  281. images=images
  282. )
  283. @tool(
  284. display={
  285. "zh": {
  286. "name": "获取搜索关键词补全建议",
  287. "params": {
  288. "keyword": "搜索关键词",
  289. "channel": "搜索渠道"
  290. }
  291. },
  292. "en": {
  293. "name": "Get Search Suggestions",
  294. "params": {
  295. "keyword": "Search keyword",
  296. "channel": "Search channel"
  297. }
  298. }
  299. }
  300. )
  301. async def get_search_suggestions(
  302. keyword: str,
  303. channel: str = "xhs",
  304. ) -> ToolResult:
  305. """
  306. 获取搜索关键词补全建议
  307. 根据关键词在指定渠道平台获取搜索建议词。
  308. Args:
  309. keyword: 搜索关键词
  310. channel: 搜索渠道,支持的渠道有:
  311. - xhs: 小红书
  312. - wx: 微信
  313. - github: GitHub
  314. - toutiao: 头条
  315. - douyin: 抖音
  316. - bili: B站
  317. - zhihu: 知乎
  318. Returns:
  319. ToolResult 包含建议词数据:
  320. {
  321. "code": 0, # 状态码,0 表示成功
  322. "message": "success", # 状态消息
  323. "data": [ # 建议词数据
  324. {
  325. "type": "xhs", # 渠道类型
  326. "list": [ # 建议词列表
  327. {
  328. "name": "彩虹染发" # 建议词
  329. }
  330. ]
  331. }
  332. ]
  333. }
  334. """
  335. try:
  336. # 处理 channel 参数,支持枚举和字符串
  337. channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel
  338. url = f"{BASE_URL}/suggest"
  339. payload = {
  340. "type": channel_value,
  341. "keyword": keyword,
  342. }
  343. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  344. response = await client.post(
  345. url,
  346. json=payload,
  347. headers={"Content-Type": "application/json"},
  348. )
  349. response.raise_for_status()
  350. data = response.json()
  351. # 计算建议词数量
  352. suggestion_count = 0
  353. for item in data.get("data", []):
  354. suggestion_count += len(item.get("list", []))
  355. return ToolResult(
  356. title=f"建议词: {keyword} ({channel_value})",
  357. output=json.dumps(data, ensure_ascii=False, indent=2),
  358. long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel_value}"
  359. )
  360. except httpx.HTTPStatusError as e:
  361. return ToolResult(
  362. title="获取建议词失败",
  363. output="",
  364. error=f"HTTP error {e.response.status_code}: {e.response.text}"
  365. )
  366. except Exception as e:
  367. return ToolResult(
  368. title="获取建议词失败",
  369. output="",
  370. error=str(e)
  371. )