search.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. """
  2. 搜索工具模块
  3. 提供帖子搜索和建议词搜索功能,支持多个渠道平台。
  4. 主要功能:
  5. 1. search_posts - 帖子搜索
  6. 2. search_suggestions - 建议词搜索
  7. """
  8. from enum import Enum
  9. from typing import Any, Dict, Optional
  10. import httpx
  11. from agent import tool
# Base API configuration.
# BASE_URL is the common prefix for the endpoints this module calls
# (f"{BASE_URL}/data" and f"{BASE_URL}/suggest").
BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
# Timeout handed to httpx.AsyncClient for every request (seconds, per httpx).
DEFAULT_TIMEOUT = 60.0
  15. class PostSearchChannel(str, Enum):
  16. """
  17. 帖子搜索支持的渠道类型
  18. """
  19. XHS = "xhs" # 小红书
  20. GZH = "gzh" # 公众号
  21. SPH = "sph" # 视频号
  22. GITHUB = "github" # GitHub
  23. TOUTIAO = "toutiao" # 头条
  24. DOUYIN = "douyin" # 抖音
  25. BILI = "bili" # B站
  26. ZHIHU = "zhihu" # 知乎
  27. WEIBO = "weibo" # 微博
  28. class SuggestSearchChannel(str, Enum):
  29. """
  30. 建议词搜索支持的渠道类型
  31. """
  32. XHS = "xhs" # 小红书
  33. WX = "wx" # 微信
  34. GITHUB = "github" # GitHub
  35. TOUTIAO = "toutiao" # 头条
  36. DOUYIN = "douyin" # 抖音
  37. BILI = "bili" # B站
  38. ZHIHU = "zhihu" # 知乎
  39. @tool(
  40. display={
  41. "zh": {
  42. "name": "帖子搜索",
  43. "params": {
  44. "keyword": "搜索关键词",
  45. "channel": "搜索渠道",
  46. "cursor": "分页游标",
  47. "max_count": "返回条数"
  48. }
  49. },
  50. "en": {
  51. "name": "Search Posts",
  52. "params": {
  53. "keyword": "Search keyword",
  54. "channel": "Search channel",
  55. "cursor": "Pagination cursor",
  56. "max_count": "Max results"
  57. }
  58. }
  59. }
  60. )
  61. async def search_posts(
  62. keyword: str,
  63. channel: str = "xhs",
  64. cursor: str = "0",
  65. max_count: int = 5,
  66. uid: str = "",
  67. ) -> Dict[str, Any]:
  68. """
  69. 帖子搜索
  70. 根据关键词在指定渠道平台搜索帖子内容。
  71. Args:
  72. keyword: 搜索关键词
  73. channel: 搜索渠道,支持的渠道有:
  74. - xhs: 小红书
  75. - gzh: 公众号
  76. - sph: 视频号
  77. - github: GitHub
  78. - toutiao: 头条
  79. - douyin: 抖音
  80. - bili: B站
  81. - zhihu: 知乎
  82. - weibo: 微博
  83. cursor: 分页游标,默认为 "0"(第一页)
  84. max_count: 返回的最大条数,默认为 5
  85. uid: 用户ID(自动注入)
  86. Returns:
  87. API 返回的原始响应,结构如下:
  88. {
  89. "code": 0, # 状态码,0 表示成功
  90. "message": "success", # 状态消息
  91. "data": [ # 帖子列表
  92. {
  93. "channel_content_id": "68dd03db000000000303beb2", # 内容唯一ID
  94. "title": "", # 标题
  95. "content_type": "note", # 内容类型
  96. "body_text": "", # 正文内容
  97. "like_count": 127, # 点赞数
  98. "publish_timestamp": 1759314907000, # 发布时间戳(毫秒)
  99. "images": ["https://xxx.webp"], # 图片列表
  100. "videos": [], # 视频列表
  101. "channel": "xhs", # 来源渠道
  102. "link": "xxx" # 原文链接
  103. }
  104. ]
  105. }
  106. Raises:
  107. httpx.HTTPStatusError: HTTP 请求返回非 2xx 状态码
  108. httpx.TimeoutException: 请求超时
  109. httpx.RequestError: 其他请求错误
  110. """
  111. # 处理 channel 参数,支持枚举和字符串
  112. channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel
  113. url = f"{BASE_URL}/data"
  114. payload = {
  115. "type": channel_value,
  116. "keyword": keyword,
  117. "cursor": cursor,
  118. "max_count": max_count,
  119. }
  120. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  121. response = await client.post(
  122. url,
  123. json=payload,
  124. headers={"Content-Type": "application/json"},
  125. )
  126. response.raise_for_status()
  127. return response.json()
  128. @tool(
  129. display={
  130. "zh": {
  131. "name": "建议词搜索",
  132. "params": {
  133. "keyword": "搜索关键词",
  134. "channel": "搜索渠道"
  135. }
  136. },
  137. "en": {
  138. "name": "Search Suggestions",
  139. "params": {
  140. "keyword": "Search keyword",
  141. "channel": "Search channel"
  142. }
  143. }
  144. }
  145. )
  146. async def search_suggestions(
  147. keyword: str,
  148. channel: str = "xhs",
  149. uid: str = "",
  150. ) -> Dict[str, Any]:
  151. """
  152. 建议词搜索
  153. 根据关键词在指定渠道平台获取搜索建议词。
  154. Args:
  155. keyword: 搜索关键词
  156. channel: 搜索渠道,支持的渠道有:
  157. - xhs: 小红书
  158. - wx: 微信
  159. - github: GitHub
  160. - toutiao: 头条
  161. - douyin: 抖音
  162. - bili: B站
  163. - zhihu: 知乎
  164. uid: 用户ID(自动注入)
  165. Returns:
  166. API 返回的原始响应,结构如下:
  167. {
  168. "code": 0, # 状态码,0 表示成功
  169. "message": "success", # 状态消息
  170. "data": [ # 建议词数据
  171. {
  172. "type": "xhs", # 渠道类型
  173. "list": [ # 建议词列表
  174. {
  175. "name": "彩虹染发" # 建议词
  176. }
  177. ]
  178. }
  179. ]
  180. }
  181. Raises:
  182. httpx.HTTPStatusError: HTTP 请求返回非 2xx 状态码
  183. httpx.TimeoutException: 请求超时
  184. httpx.RequestError: 其他请求错误
  185. """
  186. # 处理 channel 参数,支持枚举和字符串
  187. channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel
  188. url = f"{BASE_URL}/suggest"
  189. payload = {
  190. "type": channel_value,
  191. "keyword": keyword,
  192. }
  193. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  194. response = await client.post(
  195. url,
  196. json=payload,
  197. headers={"Content-Type": "application/json"},
  198. )
  199. response.raise_for_status()
  200. return response.json()