""" 搜索工具模块 提供帖子搜索、帖子详情查看和建议词搜索功能,支持多个渠道平台。 主要功能: 1. search_posts - 帖子搜索(浏览模式:封面图+标题+内容截断) 2. select_post - 帖子详情(从搜索结果中选取单个帖子的完整内容) 3. get_search_suggestions - 获取平台的搜索补全建议词 """ import asyncio import base64 import io import json import math import textwrap from enum import Enum from typing import Any, Dict, List, Optional import httpx from PIL import Image, ImageDraw, ImageFont from agent.tools import tool, ToolResult # API 基础配置 BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel" DEFAULT_TIMEOUT = 60.0 # 搜索结果缓存,以序号为 key _search_cache: Dict[int, Dict[str, Any]] = {} # 拼接图配置 THUMB_WIDTH = 250 THUMB_HEIGHT = 250 TEXT_HEIGHT = 80 GRID_COLS = 5 PADDING = 12 BG_COLOR = (255, 255, 255) TEXT_COLOR = (30, 30, 30) INDEX_COLOR = (220, 60, 60) def _truncate_text(text: str, max_len: int = 14) -> str: """截断文本,超出部分用省略号""" return text[:max_len] + "..." if len(text) > max_len else text async def _download_image(client: httpx.AsyncClient, url: str) -> Optional[Image.Image]: """下载单张图片,失败返回 None""" try: resp = await client.get(url, timeout=15.0) resp.raise_for_status() return Image.open(io.BytesIO(resp.content)).convert("RGB") except Exception: return None async def _build_collage(posts: List[Dict[str, Any]]) -> Optional[str]: """ 将帖子封面图+序号+标题拼接成网格图,返回 base64 编码的 PNG。 每个格子:序号 + 封面图 + 标题 """ if not posts: return None # 收集有封面图的帖子,记录原始序号 items = [] for idx, post in enumerate(posts): imgs = post.get("images", []) cover_url = imgs[0] if imgs else None if cover_url: items.append({ "url": cover_url, "title": post.get("title", "") or "", "index": idx + 1, }) if not items: return None # 并发下载封面图 async with httpx.AsyncClient() as client: tasks = [_download_image(client, item["url"]) for item in items] downloaded = await asyncio.gather(*tasks) # 过滤下载失败的 valid = [(item, img) for item, img in zip(items, downloaded) if img is not None] if not valid: return None cols = min(GRID_COLS, len(valid)) rows = math.ceil(len(valid) / cols) cell_w = THUMB_WIDTH + PADDING cell_h = THUMB_HEIGHT + TEXT_HEIGHT + PADDING canvas_w = cols * cell_w + PADDING canvas_h = rows * cell_h + PADDING canvas = Image.new("RGB", (canvas_w, canvas_h), BG_COLOR) draw = ImageDraw.Draw(canvas) # 尝试加载字体(跨平台中文支持) font_title = None font_index = None # 按优先级尝试不同平台的中文字体 font_candidates = [ "msyh.ttc", # Windows 微软雅黑 "simhei.ttf", # Windows 黑体 "simsun.ttc", # Windows 宋体 "/System/Library/Fonts/PingFang.ttc", # macOS 苹方 "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf", # Linux "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", # Linux WenQuanYi "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", # Linux Noto ] for font_path in font_candidates: try: font_title = ImageFont.truetype(font_path, 16) font_index = ImageFont.truetype(font_path, 32) break except Exception: continue # 如果都失败,使用默认字体(可能不支持中文) if not font_title: font_title = ImageFont.load_default() font_index = font_title for item, img in valid: idx = item["index"] col = (idx - 1) % cols row = (idx - 1) // cols x = PADDING + col * cell_w y = PADDING + row * cell_h # 等比缩放封面图,保持原始比例,居中放置 scale = min(THUMB_WIDTH / img.width, THUMB_HEIGHT / img.height) new_w = int(img.width * scale) new_h = int(img.height * scale) thumb = img.resize((new_w, new_h), Image.LANCZOS) offset_x = x + (THUMB_WIDTH - new_w) // 2 offset_y = y + (THUMB_HEIGHT - new_h) // 2 canvas.paste(thumb, (offset_x, offset_y)) # 左上角写序号(带背景),固定大小,跟随图片位置 index_text = str(idx) idx_x = offset_x idx_y = offset_y + 4 box_size = 52 draw.rectangle([idx_x, idx_y, idx_x + box_size, idx_y + box_size], fill=INDEX_COLOR) # 序号居中绘制 bbox = draw.textbbox((0, 0), index_text, font=font_index) tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1] text_x = idx_x + (box_size - tw) // 2 text_y = idx_y + (box_size - th) // 2 draw.text((text_x, text_y), index_text, fill=(255, 255, 255), font=font_index) # 写标题(完整显示,按像素宽度自动换行) title = item["title"] or "" if title: words = list(title) # 逐字符拆分,兼容中英文 lines = [] current_line = "" for ch in words: test_line = current_line + ch bbox_line = draw.textbbox((0, 0), test_line, font=font_title) if bbox_line[2] - bbox_line[0] > THUMB_WIDTH: if current_line: lines.append(current_line) current_line = ch else: current_line = test_line if current_line: lines.append(current_line) for line_i, line in enumerate(lines): draw.text((x, y + THUMB_HEIGHT + 6 + line_i * 22), line, fill=TEXT_COLOR, font=font_title) # 转 base64 buf = io.BytesIO() canvas.save(buf, format="PNG") return base64.b64encode(buf.getvalue()).decode("utf-8") class PostSearchChannel(str, Enum): """ 帖子搜索支持的渠道类型 """ XHS = "xhs" # 小红书 GZH = "gzh" # 公众号 SPH = "sph" # 视频号 GITHUB = "github" # GitHub TOUTIAO = "toutiao" # 头条 DOUYIN = "douyin" # 抖音 BILI = "bili" # B站 ZHIHU = "zhihu" # 知乎 WEIBO = "weibo" # 微博 class SuggestSearchChannel(str, Enum): """ 建议词搜索支持的渠道类型 """ XHS = "xhs" # 小红书 WX = "wx" # 微信 GITHUB = "github" # GitHub TOUTIAO = "toutiao" # 头条 DOUYIN = "douyin" # 抖音 BILI = "bili" # B站 ZHIHU = "zhihu" # 知乎 @tool( display={ "zh": { "name": "帖子搜索", "params": { "keyword": "搜索关键词", "channel": "搜索渠道", "cursor": "分页游标", "max_count": "返回条数", "content_type": "内容类型-视频/图文" } }, "en": { "name": "Search Posts", "params": { "keyword": "Search keyword", "channel": "Search channel", "cursor": "Pagination cursor", "max_count": "Max results", "content_type": "content type-视频/图文" } } } ) async def search_posts( keyword: str, channel: str = "xhs", cursor: str = "0", max_count: int = 20, content_type: str = "" ) -> ToolResult: """ 帖子搜索(浏览模式) 根据关键词在指定渠道平台搜索帖子,返回封面图+标题+内容摘要,用于快速浏览。 如需查看某个帖子的完整内容,请使用 select_post 工具。 Args: keyword: 搜索关键词 channel: 搜索渠道,支持的渠道有: - xhs: 小红书 - gzh: 公众号 - sph: 视频号 - github: GitHub - toutiao: 头条 - douyin: 抖音 - bili: B站 - zhihu: 知乎 - weibo: 微博 cursor: 分页游标,默认为 "0"(第一页) max_count: 返回的最大条数,默认为 20 content_type:内容类型-视频/图文,默认不传为不限制类型 Returns: ToolResult 包含搜索结果摘要列表(封面图+标题+内容截断), 可通过 channel_content_id 调用 select_post 查看完整内容。 """ global _search_cache try: channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel url = f"{BASE_URL}/data" payload = { "type": channel_value, "keyword": keyword, "cursor": cursor, "max_count": max_count, "content_type": content_type } async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client: response = await client.post( url, json=payload, headers={"Content-Type": "application/json"}, ) response.raise_for_status() data = response.json() posts = data.get("data", []) # 缓存完整结果(以序号为 key) _search_cache.clear() for idx, post in enumerate(posts): _search_cache[idx + 1] = post # 构建摘要列表(带序号) summary_list = [] for idx, post in enumerate(posts): body = post.get("body_text", "") or "" summary_list.append({ "index": idx + 1, "channel_content_id": post.get("channel_content_id"), "title": post.get("title"), "body_text": body[:100] + ("..." if len(body) > 100 else ""), "like_count": post.get("like_count"), "collect_count": post.get("collect_count"), "comment_count": post.get("comment_count"), "channel": post.get("channel"), "link": post.get("link"), "content_type": post.get("content_type"), "publish_timestamp": post.get("publish_timestamp"), }) # 拼接封面图网格 images = [] collage_b64 = await _build_collage(posts) if collage_b64: images.append({ "type": "base64", "media_type": "image/png", "data": collage_b64 }) output_data = { "code": data.get("code"), "message": data.get("message"), "data": summary_list } return ToolResult( title=f"搜索结果: {keyword} ({channel_value})", output=json.dumps(output_data, ensure_ascii=False, indent=2), long_term_memory=f"Searched '{keyword}' on {channel_value}, found {len(posts)} posts. Use select_post(index) to view full details of a specific post.", images=images ) except httpx.HTTPStatusError as e: return ToolResult( title="搜索失败", output="", error=f"HTTP error {e.response.status_code}: {e.response.text}" ) except Exception as e: return ToolResult( title="搜索失败", output="", error=str(e) ) @tool( display={ "zh": { "name": "帖子详情", "params": { "index": "帖子序号" } }, "en": { "name": "Select Post", "params": { "index": "Post index" } } } ) async def select_post( index: int, ) -> ToolResult: """ 查看帖子详情 从最近一次 search_posts 的搜索结果中,根据序号选取指定帖子并返回完整内容(全部正文、全部图片、视频等)。 需要先调用 search_posts 进行搜索。 Args: index: 帖子序号,来自 search_posts 返回结果中的 index 字段(从 1 开始) Returns: ToolResult 包含该帖子的完整信息和所有图片。 """ post = _search_cache.get(index) if not post: return ToolResult( title="未找到帖子", output="", error=f"未找到序号 {index} 的帖子,请先调用 search_posts 搜索。" ) # 返回所有图片 images = [] for img_url in post.get("images", []): if img_url: images.append({ "type": "url", "url": img_url }) return ToolResult( title=f"帖子详情 #{index}: {post.get('title', '')}", output=json.dumps(post, ensure_ascii=False, indent=2), long_term_memory=f"Viewed post detail #{index}: {post.get('title', '')}", images=images ) @tool( display={ "zh": { "name": "获取搜索关键词补全建议", "params": { "keyword": "搜索关键词", "channel": "搜索渠道" } }, "en": { "name": "Get Search Suggestions", "params": { "keyword": "Search keyword", "channel": "Search channel" } } } ) async def get_search_suggestions( keyword: str, channel: str = "xhs", ) -> ToolResult: """ 获取搜索关键词补全建议 根据关键词在指定渠道平台获取搜索建议词。 Args: keyword: 搜索关键词 channel: 搜索渠道,支持的渠道有: - xhs: 小红书 - wx: 微信 - github: GitHub - toutiao: 头条 - douyin: 抖音 - bili: B站 - zhihu: 知乎 Returns: ToolResult 包含建议词数据: { "code": 0, # 状态码,0 表示成功 "message": "success", # 状态消息 "data": [ # 建议词数据 { "type": "xhs", # 渠道类型 "list": [ # 建议词列表 { "name": "彩虹染发" # 建议词 } ] } ] } """ try: # 处理 channel 参数,支持枚举和字符串 channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel url = f"{BASE_URL}/suggest" payload = { "type": channel_value, "keyword": keyword, } async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client: response = await client.post( url, json=payload, headers={"Content-Type": "application/json"}, ) response.raise_for_status() data = response.json() # 计算建议词数量 suggestion_count = 0 for item in data.get("data", []): suggestion_count += len(item.get("list", [])) return ToolResult( title=f"建议词: {keyword} ({channel_value})", output=json.dumps(data, ensure_ascii=False, indent=2), long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel_value}" ) except httpx.HTTPStatusError as e: return ToolResult( title="获取建议词失败", output="", error=f"HTTP error {e.response.status_code}: {e.response.text}" ) except Exception as e: return ToolResult( title="获取建议词失败", output="", error=str(e) )