search.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. """
  2. 搜索工具模块
  3. 提供帖子搜索、帖子详情查看和建议词搜索功能,支持多个渠道平台。
  4. 主要功能:
  5. 1. search_posts - 帖子搜索(浏览模式:封面图+标题+内容截断)
  6. 2. select_post - 帖子详情(从搜索结果中选取单个帖子的完整内容)
  7. 3. get_search_suggestions - 获取平台的搜索补全建议词
  8. """
  9. import asyncio
  10. import base64
  11. import io
  12. import json
  13. import math
  14. from enum import Enum
  15. from typing import Any, Dict, List, Optional
  16. import httpx
  17. from PIL import Image, ImageDraw, ImageFont
  18. from agent.tools import tool, ToolResult
  19. # API 基础配置
  20. BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
  21. DEFAULT_TIMEOUT = 60.0
  22. # 搜索结果缓存,以序号为 key
  23. _search_cache: Dict[int, Dict[str, Any]] = {}
  24. # 拼接图配置
  25. THUMB_WIDTH = 250
  26. THUMB_HEIGHT = 250
  27. TEXT_HEIGHT = 80
  28. GRID_COLS = 5
  29. PADDING = 12
  30. BG_COLOR = (255, 255, 255)
  31. TEXT_COLOR = (30, 30, 30)
  32. INDEX_COLOR = (220, 60, 60)
  33. def _truncate_text(text: str, max_len: int = 14) -> str:
  34. """截断文本,超出部分用省略号"""
  35. return text[:max_len] + "..." if len(text) > max_len else text
  36. async def _download_image(client: httpx.AsyncClient, url: str) -> Optional[Image.Image]:
  37. """下载单张图片,失败返回 None"""
  38. try:
  39. resp = await client.get(url, timeout=15.0)
  40. resp.raise_for_status()
  41. return Image.open(io.BytesIO(resp.content)).convert("RGB")
  42. except Exception:
  43. return None
  44. async def _build_collage(posts: List[Dict[str, Any]]) -> Optional[str]:
  45. """
  46. 将帖子封面图+序号+标题拼接成网格图,返回 base64 编码的 PNG。
  47. 每个格子:序号 + 封面图 + 标题
  48. """
  49. if not posts:
  50. return None
  51. # 收集有封面图的帖子,记录原始序号
  52. items = []
  53. for idx, post in enumerate(posts):
  54. imgs = post.get("images", [])
  55. cover_url = imgs[0] if imgs else None
  56. if cover_url:
  57. items.append({
  58. "url": cover_url,
  59. "title": post.get("title", "") or "",
  60. "index": idx + 1,
  61. })
  62. if not items:
  63. return None
  64. # 并发下载封面图
  65. async with httpx.AsyncClient() as client:
  66. tasks = [_download_image(client, item["url"]) for item in items]
  67. downloaded = await asyncio.gather(*tasks)
  68. # 过滤下载失败的
  69. valid = [(item, img) for item, img in zip(items, downloaded) if img is not None]
  70. if not valid:
  71. return None
  72. cols = min(GRID_COLS, len(valid))
  73. rows = math.ceil(len(valid) / cols)
  74. cell_w = THUMB_WIDTH + PADDING
  75. cell_h = THUMB_HEIGHT + TEXT_HEIGHT + PADDING
  76. canvas_w = cols * cell_w + PADDING
  77. canvas_h = rows * cell_h + PADDING
  78. canvas = Image.new("RGB", (canvas_w, canvas_h), BG_COLOR)
  79. draw = ImageDraw.Draw(canvas)
  80. # 尝试加载字体
  81. try:
  82. font_title = ImageFont.truetype("msyh.ttc", 16)
  83. font_index = ImageFont.truetype("msyh.ttc", 32)
  84. except Exception:
  85. try:
  86. font_title = ImageFont.truetype("arial.ttf", 16)
  87. font_index = ImageFont.truetype("arial.ttf", 32)
  88. except Exception:
  89. font_title = ImageFont.load_default()
  90. font_index = font_title
  91. for item, img in valid:
  92. idx = item["index"]
  93. col = (idx - 1) % cols
  94. row = (idx - 1) // cols
  95. x = PADDING + col * cell_w
  96. y = PADDING + row * cell_h
  97. # 缩放封面图
  98. thumb = img.resize((THUMB_WIDTH, THUMB_HEIGHT), Image.LANCZOS)
  99. canvas.paste(thumb, (x, y))
  100. # 左上角写序号(带背景)
  101. index_text = f" {idx} "
  102. bbox = draw.textbbox((0, 0), index_text, font=font_index)
  103. tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
  104. # 增加背景块的 padding,确保完全覆盖数字
  105. pad_x, pad_y = 8, 6
  106. draw.rectangle([x, y, x + tw + pad_x * 2, y + th + pad_y * 2], fill=INDEX_COLOR)
  107. draw.text((x + pad_x, y + pad_y), index_text, fill=(255, 255, 255), font=font_index)
  108. # 写标题
  109. title_text = _truncate_text(item["title"], max_len=16)
  110. draw.text((x, y + THUMB_HEIGHT + 6), title_text, fill=TEXT_COLOR, font=font_title)
  111. # 转 base64
  112. buf = io.BytesIO()
  113. canvas.save(buf, format="PNG")
  114. return base64.b64encode(buf.getvalue()).decode("utf-8")
  115. class PostSearchChannel(str, Enum):
  116. """
  117. 帖子搜索支持的渠道类型
  118. """
  119. XHS = "xhs" # 小红书
  120. GZH = "gzh" # 公众号
  121. SPH = "sph" # 视频号
  122. GITHUB = "github" # GitHub
  123. TOUTIAO = "toutiao" # 头条
  124. DOUYIN = "douyin" # 抖音
  125. BILI = "bili" # B站
  126. ZHIHU = "zhihu" # 知乎
  127. WEIBO = "weibo" # 微博
  128. class SuggestSearchChannel(str, Enum):
  129. """
  130. 建议词搜索支持的渠道类型
  131. """
  132. XHS = "xhs" # 小红书
  133. WX = "wx" # 微信
  134. GITHUB = "github" # GitHub
  135. TOUTIAO = "toutiao" # 头条
  136. DOUYIN = "douyin" # 抖音
  137. BILI = "bili" # B站
  138. ZHIHU = "zhihu" # 知乎
  139. @tool(
  140. display={
  141. "zh": {
  142. "name": "帖子搜索",
  143. "params": {
  144. "keyword": "搜索关键词",
  145. "channel": "搜索渠道",
  146. "cursor": "分页游标",
  147. "max_count": "返回条数",
  148. "content_type": "内容类型-视频/图文"
  149. }
  150. },
  151. "en": {
  152. "name": "Search Posts",
  153. "params": {
  154. "keyword": "Search keyword",
  155. "channel": "Search channel",
  156. "cursor": "Pagination cursor",
  157. "max_count": "Max results",
  158. "content_type": "content type-视频/图文"
  159. }
  160. }
  161. }
  162. )
  163. async def search_posts(
  164. keyword: str,
  165. channel: str = "xhs",
  166. cursor: str = "0",
  167. max_count: int = 20,
  168. content_type: str = ""
  169. ) -> ToolResult:
  170. """
  171. 帖子搜索(浏览模式)
  172. 根据关键词在指定渠道平台搜索帖子,返回封面图+标题+内容摘要,用于快速浏览。
  173. 如需查看某个帖子的完整内容,请使用 select_post 工具。
  174. Args:
  175. keyword: 搜索关键词
  176. channel: 搜索渠道,支持的渠道有:
  177. - xhs: 小红书
  178. - gzh: 公众号
  179. - sph: 视频号
  180. - github: GitHub
  181. - toutiao: 头条
  182. - douyin: 抖音
  183. - bili: B站
  184. - zhihu: 知乎
  185. - weibo: 微博
  186. cursor: 分页游标,默认为 "0"(第一页)
  187. max_count: 返回的最大条数,默认为 20
  188. content_type:内容类型-视频/图文,默认不传为不限制类型
  189. Returns:
  190. ToolResult 包含搜索结果摘要列表(封面图+标题+内容截断),
  191. 可通过 channel_content_id 调用 select_post 查看完整内容。
  192. """
  193. global _search_cache
  194. try:
  195. channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel
  196. url = f"{BASE_URL}/data"
  197. payload = {
  198. "type": channel_value,
  199. "keyword": keyword,
  200. "cursor": cursor,
  201. "max_count": max_count,
  202. "content_type": content_type
  203. }
  204. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  205. response = await client.post(
  206. url,
  207. json=payload,
  208. headers={"Content-Type": "application/json"},
  209. )
  210. response.raise_for_status()
  211. data = response.json()
  212. posts = data.get("data", [])
  213. # 缓存完整结果(以序号为 key)
  214. _search_cache.clear()
  215. for idx, post in enumerate(posts):
  216. _search_cache[idx + 1] = post
  217. # 构建摘要列表(带序号)
  218. summary_list = []
  219. for idx, post in enumerate(posts):
  220. body = post.get("body_text", "") or ""
  221. summary_list.append({
  222. "index": idx + 1,
  223. "channel_content_id": post.get("channel_content_id"),
  224. "title": post.get("title"),
  225. "body_text": body[:100] + ("..." if len(body) > 100 else ""),
  226. "like_count": post.get("like_count"),
  227. "channel": post.get("channel"),
  228. "link": post.get("link"),
  229. "content_type": post.get("content_type"),
  230. "publish_timestamp": post.get("publish_timestamp"),
  231. })
  232. # 拼接封面图网格
  233. images = []
  234. collage_b64 = await _build_collage(posts)
  235. if collage_b64:
  236. images.append({
  237. "type": "base64",
  238. "media_type": "image/png",
  239. "data": collage_b64
  240. })
  241. output_data = {
  242. "code": data.get("code"),
  243. "message": data.get("message"),
  244. "data": summary_list
  245. }
  246. return ToolResult(
  247. title=f"搜索结果: {keyword} ({channel_value})",
  248. output=json.dumps(output_data, ensure_ascii=False, indent=2),
  249. long_term_memory=f"Searched '{keyword}' on {channel_value}, found {len(posts)} posts. Use select_post(index) to view full details of a specific post.",
  250. images=images
  251. )
  252. except httpx.HTTPStatusError as e:
  253. return ToolResult(
  254. title="搜索失败",
  255. output="",
  256. error=f"HTTP error {e.response.status_code}: {e.response.text}"
  257. )
  258. except Exception as e:
  259. return ToolResult(
  260. title="搜索失败",
  261. output="",
  262. error=str(e)
  263. )
  264. @tool(
  265. display={
  266. "zh": {
  267. "name": "帖子详情",
  268. "params": {
  269. "index": "帖子序号"
  270. }
  271. },
  272. "en": {
  273. "name": "Select Post",
  274. "params": {
  275. "index": "Post index"
  276. }
  277. }
  278. }
  279. )
  280. async def select_post(
  281. index: int,
  282. ) -> ToolResult:
  283. """
  284. 查看帖子详情
  285. 从最近一次 search_posts 的搜索结果中,根据序号选取指定帖子并返回完整内容(全部正文、全部图片、视频等)。
  286. 需要先调用 search_posts 进行搜索。
  287. Args:
  288. index: 帖子序号,来自 search_posts 返回结果中的 index 字段(从 1 开始)
  289. Returns:
  290. ToolResult 包含该帖子的完整信息和所有图片。
  291. """
  292. post = _search_cache.get(index)
  293. if not post:
  294. return ToolResult(
  295. title="未找到帖子",
  296. output="",
  297. error=f"未找到序号 {index} 的帖子,请先调用 search_posts 搜索。"
  298. )
  299. # 返回所有图片
  300. images = []
  301. for img_url in post.get("images", []):
  302. if img_url:
  303. images.append({
  304. "type": "url",
  305. "url": img_url
  306. })
  307. return ToolResult(
  308. title=f"帖子详情 #{index}: {post.get('title', '')}",
  309. output=json.dumps(post, ensure_ascii=False, indent=2),
  310. long_term_memory=f"Viewed post detail #{index}: {post.get('title', '')}",
  311. images=images
  312. )
  313. @tool(
  314. display={
  315. "zh": {
  316. "name": "获取搜索关键词补全建议",
  317. "params": {
  318. "keyword": "搜索关键词",
  319. "channel": "搜索渠道"
  320. }
  321. },
  322. "en": {
  323. "name": "Get Search Suggestions",
  324. "params": {
  325. "keyword": "Search keyword",
  326. "channel": "Search channel"
  327. }
  328. }
  329. }
  330. )
  331. async def get_search_suggestions(
  332. keyword: str,
  333. channel: str = "xhs",
  334. ) -> ToolResult:
  335. """
  336. 获取搜索关键词补全建议
  337. 根据关键词在指定渠道平台获取搜索建议词。
  338. Args:
  339. keyword: 搜索关键词
  340. channel: 搜索渠道,支持的渠道有:
  341. - xhs: 小红书
  342. - wx: 微信
  343. - github: GitHub
  344. - toutiao: 头条
  345. - douyin: 抖音
  346. - bili: B站
  347. - zhihu: 知乎
  348. Returns:
  349. ToolResult 包含建议词数据:
  350. {
  351. "code": 0, # 状态码,0 表示成功
  352. "message": "success", # 状态消息
  353. "data": [ # 建议词数据
  354. {
  355. "type": "xhs", # 渠道类型
  356. "list": [ # 建议词列表
  357. {
  358. "name": "彩虹染发" # 建议词
  359. }
  360. ]
  361. }
  362. ]
  363. }
  364. """
  365. try:
  366. # 处理 channel 参数,支持枚举和字符串
  367. channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel
  368. url = f"{BASE_URL}/suggest"
  369. payload = {
  370. "type": channel_value,
  371. "keyword": keyword,
  372. }
  373. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  374. response = await client.post(
  375. url,
  376. json=payload,
  377. headers={"Content-Type": "application/json"},
  378. )
  379. response.raise_for_status()
  380. data = response.json()
  381. # 计算建议词数量
  382. suggestion_count = 0
  383. for item in data.get("data", []):
  384. suggestion_count += len(item.get("list", []))
  385. return ToolResult(
  386. title=f"建议词: {keyword} ({channel_value})",
  387. output=json.dumps(data, ensure_ascii=False, indent=2),
  388. long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel_value}"
  389. )
  390. except httpx.HTTPStatusError as e:
  391. return ToolResult(
  392. title="获取建议词失败",
  393. output="",
  394. error=f"HTTP error {e.response.status_code}: {e.response.text}"
  395. )
  396. except Exception as e:
  397. return ToolResult(
  398. title="获取建议词失败",
  399. output="",
  400. error=str(e)
  401. )