search.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. """
  2. 搜索工具模块
  3. 提供帖子搜索和建议词搜索功能,支持多个渠道平台。
  4. 主要功能:
  5. 1. search_posts - 帖子搜索
  6. 2. get_search_suggestions - 获取平台的搜索补全建议词
  7. """
  8. import json
  9. from enum import Enum
  10. from typing import Any, Dict
  11. import httpx
  12. from agent import tool
  13. from agent.tools.models import ToolResult
# Base configuration for the channel search API.
# NOTE(review): this endpoint uses plain HTTP, so requests/responses are not
# encrypted in transit — confirm whether an https:// endpoint is available.
BASE_URL: str = "http://aigc-channel.aiddit.com/aigc/channel"
# Passed to httpx.AsyncClient(timeout=...) by the tools below, i.e. seconds.
DEFAULT_TIMEOUT: float = 60.0
class PostSearchChannel(str, Enum):
    """
    Channel types supported by post search (``search_posts``).

    Mixes in ``str``, so each member is also a string equal to its value
    (e.g. ``PostSearchChannel.XHS == "xhs"``). Member values are the raw
    ``type`` strings sent to the API.
    """
    XHS = "xhs"  # Xiaohongshu (RED)
    GZH = "gzh"  # WeChat Official Accounts
    SPH = "sph"  # WeChat Channels
    GITHUB = "github"  # GitHub
    TOUTIAO = "toutiao"  # Toutiao
    DOUYIN = "douyin"  # Douyin
    BILI = "bili"  # Bilibili
    ZHIHU = "zhihu"  # Zhihu
    WEIBO = "weibo"  # Weibo
class SuggestSearchChannel(str, Enum):
    """
    Channel types supported by search-suggestion lookup
    (``get_search_suggestions``).

    Mixes in ``str``, so each member is also a string equal to its value.
    Note the set differs from ``PostSearchChannel``: it has ``wx`` (WeChat)
    instead of ``gzh``/``sph``, and no ``weibo``.
    """
    XHS = "xhs"  # Xiaohongshu (RED)
    WX = "wx"  # WeChat
    GITHUB = "github"  # GitHub
    TOUTIAO = "toutiao"  # Toutiao
    DOUYIN = "douyin"  # Douyin
    BILI = "bili"  # Bilibili
    ZHIHU = "zhihu"  # Zhihu
  41. @tool(
  42. display={
  43. "zh": {
  44. "name": "帖子搜索",
  45. "params": {
  46. "keyword": "搜索关键词",
  47. "channel": "搜索渠道",
  48. "cursor": "分页游标",
  49. "max_count": "返回条数"
  50. }
  51. },
  52. "en": {
  53. "name": "Search Posts",
  54. "params": {
  55. "keyword": "Search keyword",
  56. "channel": "Search channel",
  57. "cursor": "Pagination cursor",
  58. "max_count": "Max results"
  59. }
  60. }
  61. }
  62. )
  63. async def search_posts(
  64. keyword: str,
  65. channel: str = "xhs",
  66. cursor: str = "0",
  67. max_count: int = 5,
  68. uid: str = "",
  69. ) -> ToolResult:
  70. """
  71. 帖子搜索
  72. 根据关键词在指定渠道平台搜索帖子内容。
  73. Args:
  74. keyword: 搜索关键词
  75. channel: 搜索渠道,支持的渠道有:
  76. - xhs: 小红书
  77. - gzh: 公众号
  78. - sph: 视频号
  79. - github: GitHub
  80. - toutiao: 头条
  81. - douyin: 抖音
  82. - bili: B站
  83. - zhihu: 知乎
  84. - weibo: 微博
  85. cursor: 分页游标,默认为 "0"(第一页)
  86. max_count: 返回的最大条数,默认为 5
  87. uid: 用户ID(自动注入)
  88. Returns:
  89. ToolResult 包含搜索结果:
  90. {
  91. "code": 0, # 状态码,0 表示成功
  92. "message": "success", # 状态消息
  93. "data": [ # 帖子列表
  94. {
  95. "channel_content_id": "68dd03db000000000303beb2", # 内容唯一ID
  96. "title": "", # 标题
  97. "content_type": "note", # 内容类型
  98. "body_text": "", # 正文内容
  99. "like_count": 127, # 点赞数
  100. "publish_timestamp": 1759314907000, # 发布时间戳(毫秒)
  101. "images": ["https://xxx.webp"], # 图片列表
  102. "videos": [], # 视频列表
  103. "channel": "xhs", # 来源渠道
  104. "link": "xxx" # 原文链接
  105. }
  106. ]
  107. }
  108. """
  109. try:
  110. # 处理 channel 参数,支持枚举和字符串
  111. channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel
  112. url = f"{BASE_URL}/data"
  113. payload = {
  114. "type": channel_value,
  115. "keyword": keyword,
  116. "cursor": cursor,
  117. "max_count": max_count,
  118. }
  119. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  120. response = await client.post(
  121. url,
  122. json=payload,
  123. headers={"Content-Type": "application/json"},
  124. )
  125. response.raise_for_status()
  126. data = response.json()
  127. # 计算结果数量
  128. result_count = len(data.get("data", []))
  129. return ToolResult(
  130. title=f"搜索结果: {keyword} ({channel_value})",
  131. output=json.dumps(data, ensure_ascii=False, indent=2),
  132. long_term_memory=f"Searched '{keyword}' on {channel_value}, found {result_count} posts"
  133. )
  134. except httpx.HTTPStatusError as e:
  135. return ToolResult(
  136. title="搜索失败",
  137. output="",
  138. error=f"HTTP error {e.response.status_code}: {e.response.text}"
  139. )
  140. except Exception as e:
  141. return ToolResult(
  142. title="搜索失败",
  143. output="",
  144. error=str(e)
  145. )
  146. @tool(
  147. display={
  148. "zh": {
  149. "name": "获取搜索关键词补全建议",
  150. "params": {
  151. "keyword": "搜索关键词",
  152. "channel": "搜索渠道"
  153. }
  154. },
  155. "en": {
  156. "name": "Get Search Suggestions",
  157. "params": {
  158. "keyword": "Search keyword",
  159. "channel": "Search channel"
  160. }
  161. }
  162. }
  163. )
  164. async def get_search_suggestions(
  165. keyword: str,
  166. channel: str = "xhs",
  167. uid: str = "",
  168. ) -> ToolResult:
  169. """
  170. 获取搜索关键词补全建议
  171. 根据关键词在指定渠道平台获取搜索建议词。
  172. Args:
  173. keyword: 搜索关键词
  174. channel: 搜索渠道,支持的渠道有:
  175. - xhs: 小红书
  176. - wx: 微信
  177. - github: GitHub
  178. - toutiao: 头条
  179. - douyin: 抖音
  180. - bili: B站
  181. - zhihu: 知乎
  182. uid: 用户ID(自动注入)
  183. Returns:
  184. ToolResult 包含建议词数据:
  185. {
  186. "code": 0, # 状态码,0 表示成功
  187. "message": "success", # 状态消息
  188. "data": [ # 建议词数据
  189. {
  190. "type": "xhs", # 渠道类型
  191. "list": [ # 建议词列表
  192. {
  193. "name": "彩虹染发" # 建议词
  194. }
  195. ]
  196. }
  197. ]
  198. }
  199. """
  200. try:
  201. # 处理 channel 参数,支持枚举和字符串
  202. channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel
  203. url = f"{BASE_URL}/suggest"
  204. payload = {
  205. "type": channel_value,
  206. "keyword": keyword,
  207. }
  208. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  209. response = await client.post(
  210. url,
  211. json=payload,
  212. headers={"Content-Type": "application/json"},
  213. )
  214. response.raise_for_status()
  215. data = response.json()
  216. # 计算建议词数量
  217. suggestion_count = 0
  218. for item in data.get("data", []):
  219. suggestion_count += len(item.get("list", []))
  220. return ToolResult(
  221. title=f"建议词: {keyword} ({channel_value})",
  222. output=json.dumps(data, ensure_ascii=False, indent=2),
  223. long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel_value}"
  224. )
  225. except httpx.HTTPStatusError as e:
  226. return ToolResult(
  227. title="获取建议词失败",
  228. output="",
  229. error=f"HTTP error {e.response.status_code}: {e.response.text}"
  230. )
  231. except Exception as e:
  232. return ToolResult(
  233. title="获取建议词失败",
  234. output="",
  235. error=str(e)
  236. )