search.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
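"""
Search tools module.

Provides post search, post detail lookup and search-suggestion queries
across multiple channel platforms.

Main functions:
1. search_posts - post search (browse mode: cover image + title + truncated body)
2. select_post - post detail (full content of a single post picked from the search results)
3. get_search_suggestions - fetch a platform's search auto-completion suggestions
"""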
  9. import asyncio
  10. import base64
  11. import io
  12. import json
  13. import math
  14. import textwrap
  15. from enum import Enum
  16. from typing import Any, Dict, List, Optional
  17. import httpx
  18. from PIL import Image, ImageDraw, ImageFont
  19. from agent.tools import tool, ToolResult
# Base configuration for the channel API.
BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
DEFAULT_TIMEOUT = 60.0  # passed as the httpx client timeout for API calls

# Cache of the most recent search results, keyed by 1-based result index.
# search_posts clears and refills it; select_post reads from it.
_search_cache: Dict[int, Dict[str, Any]] = {}

# Collage layout settings (used by _build_collage).
THUMB_WIDTH = 250  # width of the box each cover image is fitted into
THUMB_HEIGHT = 250  # height of the box each cover image is fitted into
TEXT_HEIGHT = 80  # vertical space reserved under each cover for the title
GRID_COLS = 5  # maximum number of cells per row
PADDING = 12  # gap between cells and around the canvas edge
BG_COLOR = (255, 255, 255)  # canvas background colour (RGB)
TEXT_COLOR = (30, 30, 30)  # title text colour (RGB)
INDEX_COLOR = (220, 60, 60)  # fill colour of the index badge (RGB)
  34. def _truncate_text(text: str, max_len: int = 14) -> str:
  35. """截断文本,超出部分用省略号"""
  36. return text[:max_len] + "..." if len(text) > max_len else text
  37. async def _download_image(client: httpx.AsyncClient, url: str) -> Optional[Image.Image]:
  38. """下载单张图片,失败返回 None"""
  39. try:
  40. resp = await client.get(url, timeout=15.0)
  41. resp.raise_for_status()
  42. return Image.open(io.BytesIO(resp.content)).convert("RGB")
  43. except Exception:
  44. return None
  45. async def _build_collage(posts: List[Dict[str, Any]]) -> Optional[str]:
  46. """
  47. 将帖子封面图+序号+标题拼接成网格图,返回 base64 编码的 PNG。
  48. 每个格子:序号 + 封面图 + 标题
  49. """
  50. if not posts:
  51. return None
  52. # 收集有封面图的帖子,记录原始序号
  53. items = []
  54. for idx, post in enumerate(posts):
  55. imgs = post.get("images", [])
  56. cover_url = imgs[0] if imgs else None
  57. if cover_url:
  58. items.append({
  59. "url": cover_url,
  60. "title": post.get("title", "") or "",
  61. "index": idx + 1,
  62. })
  63. if not items:
  64. return None
  65. # 并发下载封面图
  66. async with httpx.AsyncClient() as client:
  67. tasks = [_download_image(client, item["url"]) for item in items]
  68. downloaded = await asyncio.gather(*tasks)
  69. # 过滤下载失败的
  70. valid = [(item, img) for item, img in zip(items, downloaded) if img is not None]
  71. if not valid:
  72. return None
  73. cols = min(GRID_COLS, len(valid))
  74. rows = math.ceil(len(valid) / cols)
  75. cell_w = THUMB_WIDTH + PADDING
  76. cell_h = THUMB_HEIGHT + TEXT_HEIGHT + PADDING
  77. canvas_w = cols * cell_w + PADDING
  78. canvas_h = rows * cell_h + PADDING
  79. canvas = Image.new("RGB", (canvas_w, canvas_h), BG_COLOR)
  80. draw = ImageDraw.Draw(canvas)
  81. # 尝试加载字体(跨平台中文支持)
  82. font_title = None
  83. font_index = None
  84. # 按优先级尝试不同平台的中文字体
  85. font_candidates = [
  86. "msyh.ttc", # Windows 微软雅黑
  87. "simhei.ttf", # Windows 黑体
  88. "simsun.ttc", # Windows 宋体
  89. "/System/Library/Fonts/PingFang.ttc", # macOS 苹方
  90. "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf", # Linux
  91. "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", # Linux WenQuanYi
  92. "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", # Linux Noto
  93. ]
  94. for font_path in font_candidates:
  95. try:
  96. font_title = ImageFont.truetype(font_path, 16)
  97. font_index = ImageFont.truetype(font_path, 32)
  98. break
  99. except Exception:
  100. continue
  101. # 如果都失败,使用默认字体(可能不支持中文)
  102. if not font_title:
  103. font_title = ImageFont.load_default()
  104. font_index = font_title
  105. for item, img in valid:
  106. idx = item["index"]
  107. col = (idx - 1) % cols
  108. row = (idx - 1) // cols
  109. x = PADDING + col * cell_w
  110. y = PADDING + row * cell_h
  111. # 等比缩放封面图,保持原始比例,居中放置
  112. scale = min(THUMB_WIDTH / img.width, THUMB_HEIGHT / img.height)
  113. new_w = int(img.width * scale)
  114. new_h = int(img.height * scale)
  115. thumb = img.resize((new_w, new_h), Image.LANCZOS)
  116. offset_x = x + (THUMB_WIDTH - new_w) // 2
  117. offset_y = y + (THUMB_HEIGHT - new_h) // 2
  118. canvas.paste(thumb, (offset_x, offset_y))
  119. # 左上角写序号(带背景),固定大小,跟随图片位置
  120. index_text = str(idx)
  121. idx_x = offset_x
  122. idx_y = offset_y + 4
  123. box_size = 52
  124. draw.rectangle([idx_x, idx_y, idx_x + box_size, idx_y + box_size], fill=INDEX_COLOR)
  125. # 序号居中绘制
  126. bbox = draw.textbbox((0, 0), index_text, font=font_index)
  127. tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
  128. text_x = idx_x + (box_size - tw) // 2
  129. text_y = idx_y + (box_size - th) // 2
  130. draw.text((text_x, text_y), index_text, fill=(255, 255, 255), font=font_index)
  131. # 写标题(完整显示,按像素宽度自动换行)
  132. title = item["title"] or ""
  133. if title:
  134. words = list(title) # 逐字符拆分,兼容中英文
  135. lines = []
  136. current_line = ""
  137. for ch in words:
  138. test_line = current_line + ch
  139. bbox_line = draw.textbbox((0, 0), test_line, font=font_title)
  140. if bbox_line[2] - bbox_line[0] > THUMB_WIDTH:
  141. if current_line:
  142. lines.append(current_line)
  143. current_line = ch
  144. else:
  145. current_line = test_line
  146. if current_line:
  147. lines.append(current_line)
  148. for line_i, line in enumerate(lines):
  149. draw.text((x, y + THUMB_HEIGHT + 6 + line_i * 22), line, fill=TEXT_COLOR, font=font_title)
  150. # 转 base64
  151. buf = io.BytesIO()
  152. canvas.save(buf, format="PNG")
  153. return base64.b64encode(buf.getvalue()).decode("utf-8")
class PostSearchChannel(str, Enum):
    """
    Channel types supported by post search.

    Members are ``str`` subclasses, so each member compares equal to its
    plain string value.
    """
    XHS = "xhs"  # Xiaohongshu
    GZH = "gzh"  # WeChat Official Accounts
    SPH = "sph"  # WeChat Channels (video)
    GITHUB = "github"  # GitHub
    TOUTIAO = "toutiao"  # Toutiao
    DOUYIN = "douyin"  # Douyin
    BILI = "bili"  # Bilibili
    ZHIHU = "zhihu"  # Zhihu
    WEIBO = "weibo"  # Weibo
class SuggestSearchChannel(str, Enum):
    """
    Channel types supported by search-suggestion lookup.

    Members are ``str`` subclasses, so each member compares equal to its
    plain string value.
    """
    XHS = "xhs"  # Xiaohongshu
    WX = "wx"  # WeChat
    GITHUB = "github"  # GitHub
    TOUTIAO = "toutiao"  # Toutiao
    DOUYIN = "douyin"  # Douyin
    BILI = "bili"  # Bilibili
    ZHIHU = "zhihu"  # Zhihu
  178. @tool(
  179. display={
  180. "zh": {
  181. "name": "帖子搜索",
  182. "params": {
  183. "keyword": "搜索关键词",
  184. "channel": "搜索渠道",
  185. "cursor": "分页游标",
  186. "max_count": "返回条数",
  187. "content_type": "内容类型-视频/图文"
  188. }
  189. },
  190. "en": {
  191. "name": "Search Posts",
  192. "params": {
  193. "keyword": "Search keyword",
  194. "channel": "Search channel",
  195. "cursor": "Pagination cursor",
  196. "max_count": "Max results",
  197. "content_type": "content type-视频/图文"
  198. }
  199. }
  200. }
  201. )
  202. async def search_posts(
  203. keyword: str,
  204. channel: str = "xhs",
  205. cursor: str = "0",
  206. max_count: int = 20,
  207. content_type: str = ""
  208. ) -> ToolResult:
  209. """
  210. 帖子搜索(浏览模式)
  211. 根据关键词在指定渠道平台搜索帖子,返回封面图+标题+内容摘要,用于快速浏览。
  212. 如需查看某个帖子的完整内容,请使用 select_post 工具。
  213. Args:
  214. keyword: 搜索关键词
  215. channel: 搜索渠道,支持的渠道有:
  216. - xhs: 小红书
  217. - gzh: 公众号
  218. - sph: 视频号
  219. - github: GitHub
  220. - toutiao: 头条
  221. - douyin: 抖音
  222. - bili: B站
  223. - zhihu: 知乎
  224. - weibo: 微博
  225. cursor: 分页游标,默认为 "0"(第一页)
  226. max_count: 返回的最大条数,默认为 20
  227. content_type:内容类型-视频/图文,默认不传为不限制类型
  228. Returns:
  229. ToolResult 包含搜索结果摘要列表(封面图+标题+内容截断),
  230. 可通过 channel_content_id 调用 select_post 查看完整内容。
  231. """
  232. global _search_cache
  233. try:
  234. channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel
  235. url = f"{BASE_URL}/data"
  236. payload = {
  237. "type": channel_value,
  238. "keyword": keyword,
  239. "cursor": cursor,
  240. "max_count": max_count,
  241. "content_type": content_type
  242. }
  243. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  244. response = await client.post(
  245. url,
  246. json=payload,
  247. headers={"Content-Type": "application/json"},
  248. )
  249. response.raise_for_status()
  250. data = response.json()
  251. posts = data.get("data", [])
  252. # 缓存完整结果(以序号为 key)
  253. _search_cache.clear()
  254. for idx, post in enumerate(posts):
  255. _search_cache[idx + 1] = post
  256. # 构建摘要列表(带序号)
  257. summary_list = []
  258. for idx, post in enumerate(posts):
  259. body = post.get("body_text", "") or ""
  260. summary_list.append({
  261. "index": idx + 1,
  262. "channel_content_id": post.get("channel_content_id"),
  263. "title": post.get("title"),
  264. "body_text": body[:100] + ("..." if len(body) > 100 else ""),
  265. "like_count": post.get("like_count"),
  266. "collect_count": post.get("collect_count"),
  267. "comment_count": post.get("comment_count"),
  268. "channel": post.get("channel"),
  269. "link": post.get("link"),
  270. "content_type": post.get("content_type"),
  271. "publish_timestamp": post.get("publish_timestamp"),
  272. })
  273. # 拼接封面图网格
  274. images = []
  275. collage_b64 = await _build_collage(posts)
  276. if collage_b64:
  277. images.append({
  278. "type": "base64",
  279. "media_type": "image/png",
  280. "data": collage_b64
  281. })
  282. output_data = {
  283. "code": data.get("code"),
  284. "message": data.get("message"),
  285. "data": summary_list
  286. }
  287. return ToolResult(
  288. title=f"搜索结果: {keyword} ({channel_value})",
  289. output=json.dumps(output_data, ensure_ascii=False, indent=2),
  290. long_term_memory=f"Searched '{keyword}' on {channel_value}, found {len(posts)} posts. Use select_post(index) to view full details of a specific post.",
  291. images=images
  292. )
  293. except httpx.HTTPStatusError as e:
  294. return ToolResult(
  295. title="搜索失败",
  296. output="",
  297. error=f"HTTP error {e.response.status_code}: {e.response.text}"
  298. )
  299. except Exception as e:
  300. return ToolResult(
  301. title="搜索失败",
  302. output="",
  303. error=str(e)
  304. )
  305. @tool(
  306. display={
  307. "zh": {
  308. "name": "帖子详情",
  309. "params": {
  310. "index": "帖子序号"
  311. }
  312. },
  313. "en": {
  314. "name": "Select Post",
  315. "params": {
  316. "index": "Post index"
  317. }
  318. }
  319. }
  320. )
  321. async def select_post(
  322. index: int,
  323. ) -> ToolResult:
  324. """
  325. 查看帖子详情
  326. 从最近一次 search_posts 的搜索结果中,根据序号选取指定帖子并返回完整内容(全部正文、全部图片、视频等)。
  327. 需要先调用 search_posts 进行搜索。
  328. Args:
  329. index: 帖子序号,来自 search_posts 返回结果中的 index 字段(从 1 开始)
  330. Returns:
  331. ToolResult 包含该帖子的完整信息和所有图片。
  332. """
  333. post = _search_cache.get(index)
  334. if not post:
  335. return ToolResult(
  336. title="未找到帖子",
  337. output="",
  338. error=f"未找到序号 {index} 的帖子,请先调用 search_posts 搜索。"
  339. )
  340. # 返回所有图片
  341. images = []
  342. for img_url in post.get("images", []):
  343. if img_url:
  344. images.append({
  345. "type": "url",
  346. "url": img_url
  347. })
  348. return ToolResult(
  349. title=f"帖子详情 #{index}: {post.get('title', '')}",
  350. output=json.dumps(post, ensure_ascii=False, indent=2),
  351. long_term_memory=f"Viewed post detail #{index}: {post.get('title', '')}",
  352. images=images
  353. )
  354. @tool(
  355. display={
  356. "zh": {
  357. "name": "获取搜索关键词补全建议",
  358. "params": {
  359. "keyword": "搜索关键词",
  360. "channel": "搜索渠道"
  361. }
  362. },
  363. "en": {
  364. "name": "Get Search Suggestions",
  365. "params": {
  366. "keyword": "Search keyword",
  367. "channel": "Search channel"
  368. }
  369. }
  370. }
  371. )
  372. async def get_search_suggestions(
  373. keyword: str,
  374. channel: str = "xhs",
  375. ) -> ToolResult:
  376. """
  377. 获取搜索关键词补全建议
  378. 根据关键词在指定渠道平台获取搜索建议词。
  379. Args:
  380. keyword: 搜索关键词
  381. channel: 搜索渠道,支持的渠道有:
  382. - xhs: 小红书
  383. - wx: 微信
  384. - github: GitHub
  385. - toutiao: 头条
  386. - douyin: 抖音
  387. - bili: B站
  388. - zhihu: 知乎
  389. Returns:
  390. ToolResult 包含建议词数据:
  391. {
  392. "code": 0, # 状态码,0 表示成功
  393. "message": "success", # 状态消息
  394. "data": [ # 建议词数据
  395. {
  396. "type": "xhs", # 渠道类型
  397. "list": [ # 建议词列表
  398. {
  399. "name": "彩虹染发" # 建议词
  400. }
  401. ]
  402. }
  403. ]
  404. }
  405. """
  406. try:
  407. # 处理 channel 参数,支持枚举和字符串
  408. channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel
  409. url = f"{BASE_URL}/suggest"
  410. payload = {
  411. "type": channel_value,
  412. "keyword": keyword,
  413. }
  414. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  415. response = await client.post(
  416. url,
  417. json=payload,
  418. headers={"Content-Type": "application/json"},
  419. )
  420. response.raise_for_status()
  421. data = response.json()
  422. # 计算建议词数量
  423. suggestion_count = 0
  424. for item in data.get("data", []):
  425. suggestion_count += len(item.get("list", []))
  426. return ToolResult(
  427. title=f"建议词: {keyword} ({channel_value})",
  428. output=json.dumps(data, ensure_ascii=False, indent=2),
  429. long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel_value}"
  430. )
  431. except httpx.HTTPStatusError as e:
  432. return ToolResult(
  433. title="获取建议词失败",
  434. output="",
  435. error=f"HTTP error {e.response.status_code}: {e.response.text}"
  436. )
  437. except Exception as e:
  438. return ToolResult(
  439. title="获取建议词失败",
  440. output="",
  441. error=str(e)
  442. )