search.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. """
  2. 搜索工具模块
  3. 提供帖子搜索和建议词搜索功能,支持多个渠道平台。
  4. 主要功能:
  5. 1. search_posts - 帖子搜索
  6. 2. get_search_suggestions - 获取平台的搜索补全建议词
  7. """
  8. import json
  9. from enum import Enum
  10. from typing import Any, Dict
  11. import httpx
  12. from agent.tools import tool, ToolResult
  13. # API 基础配置
  14. BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
  15. DEFAULT_TIMEOUT = 60.0
  16. class PostSearchChannel(str, Enum):
  17. """
  18. 帖子搜索支持的渠道类型
  19. """
  20. XHS = "xhs" # 小红书
  21. GZH = "gzh" # 公众号
  22. SPH = "sph" # 视频号
  23. GITHUB = "github" # GitHub
  24. TOUTIAO = "toutiao" # 头条
  25. DOUYIN = "douyin" # 抖音
  26. BILI = "bili" # B站
  27. ZHIHU = "zhihu" # 知乎
  28. WEIBO = "weibo" # 微博
  29. class SuggestSearchChannel(str, Enum):
  30. """
  31. 建议词搜索支持的渠道类型
  32. """
  33. XHS = "xhs" # 小红书
  34. WX = "wx" # 微信
  35. GITHUB = "github" # GitHub
  36. TOUTIAO = "toutiao" # 头条
  37. DOUYIN = "douyin" # 抖音
  38. BILI = "bili" # B站
  39. ZHIHU = "zhihu" # 知乎
  40. @tool(
  41. display={
  42. "zh": {
  43. "name": "帖子搜索",
  44. "params": {
  45. "keyword": "搜索关键词",
  46. "channel": "搜索渠道",
  47. "cursor": "分页游标",
  48. "max_count": "返回条数",
  49. "include_images": "是否包含图片",
  50. "content_type": "内容类型-视频/图文"
  51. }
  52. },
  53. "en": {
  54. "name": "Search Posts",
  55. "params": {
  56. "keyword": "Search keyword",
  57. "channel": "Search channel",
  58. "cursor": "Pagination cursor",
  59. "max_count": "Max results",
  60. "include_images": "Include images",
  61. "content_type": "content type-视频/图文"
  62. }
  63. }
  64. }
  65. )
  66. async def search_posts(
  67. keyword: str,
  68. channel: str = "xhs",
  69. cursor: str = "0",
  70. max_count: int = 5,
  71. include_images: bool = False,
  72. content_type: str = ""
  73. ) -> ToolResult:
  74. """
  75. 帖子搜索
  76. 根据关键词在指定渠道平台搜索帖子内容。
  77. Args:
  78. keyword: 搜索关键词
  79. channel: 搜索渠道,支持的渠道有:
  80. - xhs: 小红书
  81. - gzh: 公众号
  82. - sph: 视频号
  83. - github: GitHub
  84. - toutiao: 头条
  85. - douyin: 抖音
  86. - bili: B站
  87. - zhihu: 知乎
  88. - weibo: 微博
  89. cursor: 分页游标,默认为 "0"(第一页)
  90. max_count: 返回的最大条数,默认为 5
  91. include_images: 是否将帖子中的图片传给 LLM 查看,默认为 False
  92. content_type:内容类型-视频/图文,默认不传为不限制类型
  93. Returns:
  94. ToolResult 包含搜索结果:
  95. {
  96. "code": 0, # 状态码,0 表示成功
  97. "message": "success", # 状态消息
  98. "data": [ # 帖子列表
  99. {
  100. "channel_content_id": "68dd03db000000000303beb2", # 内容唯一ID
  101. "title": "", # 标题
  102. "content_type": "note", # 内容类型
  103. "body_text": "", # 正文内容
  104. "like_count": 127, # 点赞数
  105. "publish_timestamp": 1759314907000, # 发布时间戳(毫秒)
  106. "images": ["https://xxx.webp"], # 图片列表
  107. "videos": [], # 视频列表
  108. "channel": "xhs", # 来源渠道
  109. "link": "xxx" # 原文链接
  110. }
  111. ]
  112. }
  113. """
  114. try:
  115. # 处理 channel 参数,支持枚举和字符串
  116. channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel
  117. url = f"{BASE_URL}/data"
  118. payload = {
  119. "type": channel_value,
  120. "keyword": keyword,
  121. "cursor": cursor,
  122. "max_count": max_count,
  123. "content_type": content_type
  124. }
  125. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  126. response = await client.post(
  127. url,
  128. json=payload,
  129. headers={"Content-Type": "application/json"},
  130. )
  131. response.raise_for_status()
  132. data = response.json()
  133. # 计算结果数量
  134. result_count = len(data.get("data", []))
  135. # 提取图片 URL 并构建 images 列表供 LLM 查看(仅当 include_images=True 时)
  136. images = []
  137. if include_images:
  138. for post in data.get("data", []):
  139. for img_url in post.get("images", [])[:3]: # 每个帖子最多取前3张图
  140. if img_url:
  141. images.append({
  142. "type": "url",
  143. "url": img_url
  144. })
  145. # 限制总图片数量,避免过多
  146. if len(images) >= 10:
  147. break
  148. return ToolResult(
  149. title=f"搜索结果: {keyword} ({channel_value})",
  150. output=json.dumps(data, ensure_ascii=False, indent=2),
  151. long_term_memory=f"Searched '{keyword}' on {channel_value}, found {result_count} posts",
  152. images=images
  153. )
  154. except httpx.HTTPStatusError as e:
  155. return ToolResult(
  156. title="搜索失败",
  157. output="",
  158. error=f"HTTP error {e.response.status_code}: {e.response.text}"
  159. )
  160. except Exception as e:
  161. return ToolResult(
  162. title="搜索失败",
  163. output="",
  164. error=str(e)
  165. )
  166. @tool(
  167. display={
  168. "zh": {
  169. "name": "获取搜索关键词补全建议",
  170. "params": {
  171. "keyword": "搜索关键词",
  172. "channel": "搜索渠道"
  173. }
  174. },
  175. "en": {
  176. "name": "Get Search Suggestions",
  177. "params": {
  178. "keyword": "Search keyword",
  179. "channel": "Search channel"
  180. }
  181. }
  182. }
  183. )
  184. async def get_search_suggestions(
  185. keyword: str,
  186. channel: str = "xhs",
  187. ) -> ToolResult:
  188. """
  189. 获取搜索关键词补全建议
  190. 根据关键词在指定渠道平台获取搜索建议词。
  191. Args:
  192. keyword: 搜索关键词
  193. channel: 搜索渠道,支持的渠道有:
  194. - xhs: 小红书
  195. - wx: 微信
  196. - github: GitHub
  197. - toutiao: 头条
  198. - douyin: 抖音
  199. - bili: B站
  200. - zhihu: 知乎
  201. Returns:
  202. ToolResult 包含建议词数据:
  203. {
  204. "code": 0, # 状态码,0 表示成功
  205. "message": "success", # 状态消息
  206. "data": [ # 建议词数据
  207. {
  208. "type": "xhs", # 渠道类型
  209. "list": [ # 建议词列表
  210. {
  211. "name": "彩虹染发" # 建议词
  212. }
  213. ]
  214. }
  215. ]
  216. }
  217. """
  218. try:
  219. # 处理 channel 参数,支持枚举和字符串
  220. channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel
  221. url = f"{BASE_URL}/suggest"
  222. payload = {
  223. "type": channel_value,
  224. "keyword": keyword,
  225. }
  226. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  227. response = await client.post(
  228. url,
  229. json=payload,
  230. headers={"Content-Type": "application/json"},
  231. )
  232. response.raise_for_status()
  233. data = response.json()
  234. # 计算建议词数量
  235. suggestion_count = 0
  236. for item in data.get("data", []):
  237. suggestion_count += len(item.get("list", []))
  238. return ToolResult(
  239. title=f"建议词: {keyword} ({channel_value})",
  240. output=json.dumps(data, ensure_ascii=False, indent=2),
  241. long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel_value}"
  242. )
  243. except httpx.HTTPStatusError as e:
  244. return ToolResult(
  245. title="获取建议词失败",
  246. output="",
  247. error=f"HTTP error {e.response.status_code}: {e.response.text}"
  248. )
  249. except Exception as e:
  250. return ToolResult(
  251. title="获取建议词失败",
  252. output="",
  253. error=str(e)
  254. )