| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460 |
- """
- 搜索工具模块
- 提供帖子搜索、帖子详情查看和建议词搜索功能,支持多个渠道平台。
- 主要功能:
- 1. search_posts - 帖子搜索(浏览模式:封面图+标题+内容截断)
- 2. select_post - 帖子详情(从搜索结果中选取单个帖子的完整内容)
- 3. get_search_suggestions - 获取平台的搜索补全建议词
- """
- import asyncio
- import base64
- import io
- import json
- import math
- from enum import Enum
- from typing import Any, Dict, List, Optional
- import httpx
- from PIL import Image, ImageDraw, ImageFont
- from agent.tools import tool, ToolResult
# Base configuration for the channel API
BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
# Passed as `timeout` to the httpx.AsyncClient used for the API requests
DEFAULT_TIMEOUT = 60.0

# Cache of the most recent search results, keyed by their 1-based index
_search_cache: Dict[int, Dict[str, Any]] = {}

# Layout settings for the cover-image collage
THUMB_WIDTH = 250   # width each cover image is resized to
THUMB_HEIGHT = 250  # height each cover image is resized to
TEXT_HEIGHT = 80    # vertical space reserved under each thumbnail for the title
GRID_COLS = 5       # maximum number of cells per row
PADDING = 12        # gap between cells and around the canvas edge
BG_COLOR = (255, 255, 255)    # canvas background (RGB)
TEXT_COLOR = (30, 30, 30)     # title text color (RGB)
INDEX_COLOR = (220, 60, 60)   # fill color of the index badge (RGB)
def _truncate_text(text: str, max_len: int = 14) -> str:
    """Shorten ``text`` to at most ``max_len`` characters plus an ellipsis.

    Text that is not longer than ``max_len`` is returned unchanged;
    otherwise the first ``max_len`` characters are kept and "..." is
    appended.
    """
    if len(text) <= max_len:
        return text
    head = text[:max_len]
    return head + "..."
async def _download_image(client: httpx.AsyncClient, url: str) -> Optional[Image.Image]:
    """Fetch a single image and decode it into an RGB ``PIL.Image``.

    This is a best-effort helper: any exception raised while requesting,
    checking the HTTP status, or decoding the payload is swallowed and
    reported as ``None`` so one bad URL does not abort the caller.

    Args:
        client: Async HTTP client used to issue the GET request.
        url: Address of the image to download.

    Returns:
        The decoded image converted to RGB mode, or ``None`` on any error.
    """
    try:
        response = await client.get(url, timeout=15.0)
        response.raise_for_status()
        raw = io.BytesIO(response.content)
        picture = Image.open(raw)
        return picture.convert("RGB")
    except Exception:
        return None
async def _build_collage(posts: List[Dict[str, Any]]) -> Optional[str]:
    """Render the posts' cover images into a single grid image.

    Each cell shows the cover image (the first entry of the post's
    ``images`` list), a badge in the top-left corner carrying the post's
    1-based position in ``posts``, and the truncated title underneath.
    Posts without a cover, or whose cover cannot be downloaded, are left
    out and the remaining cells are packed without gaps, so the badge
    number (not the cell position) identifies the post.

    Args:
        posts: Search results; ``images`` and ``title`` are read from each item.

    Returns:
        The collage as a base64-encoded PNG string, or ``None`` when there
        are no posts or no cover image could be obtained.
    """
    if not posts:
        return None

    # Keep only posts that have a cover, remembering their original 1-based index.
    items = []
    for idx, post in enumerate(posts, start=1):
        imgs = post.get("images") or []
        cover_url = imgs[0] if imgs else None
        if cover_url:
            items.append({
                "url": cover_url,
                "title": post.get("title", "") or "",
                "index": idx,
            })
    if not items:
        return None

    # Download all covers concurrently over one shared client.
    async with httpx.AsyncClient() as client:
        downloaded = await asyncio.gather(
            *(_download_image(client, item["url"]) for item in items)
        )

    # Drop the covers that failed to download.
    valid = [(item, img) for item, img in zip(items, downloaded) if img is not None]
    if not valid:
        return None

    cols = min(GRID_COLS, len(valid))
    rows = math.ceil(len(valid) / cols)
    cell_w = THUMB_WIDTH + PADDING
    cell_h = THUMB_HEIGHT + TEXT_HEIGHT + PADDING
    canvas_w = cols * cell_w + PADDING
    canvas_h = rows * cell_h + PADDING

    canvas = Image.new("RGB", (canvas_w, canvas_h), BG_COLOR)
    draw = ImageDraw.Draw(canvas)

    # Try the preferred fonts in order; fall back to Pillow's built-in font.
    for font_name in ("msyh.ttc", "arial.ttf"):
        try:
            font_title = ImageFont.truetype(font_name, 16)
            font_index = ImageFont.truetype(font_name, 32)
            break
        except Exception:
            continue
    else:
        font_title = ImageFont.load_default()
        font_index = font_title

    # Extra room around the badge text so the background fully covers the digits.
    pad_x, pad_y = 8, 6

    # Place each cell by its slot among the *valid* covers. The canvas is
    # sized from len(valid), so positioning by the original post index would
    # push thumbnails outside the canvas whenever an earlier post has no
    # usable cover.
    for slot, (item, img) in enumerate(valid):
        idx = item["index"]
        row, col = divmod(slot, cols)
        x = PADDING + col * cell_w
        y = PADDING + row * cell_h

        # Scale the cover to the fixed thumbnail size.
        thumb = img.resize((THUMB_WIDTH, THUMB_HEIGHT), Image.LANCZOS)
        canvas.paste(thumb, (x, y))

        # Index badge in the top-left corner (text on a filled rectangle).
        index_text = f" {idx} "
        bbox = draw.textbbox((0, 0), index_text, font=font_index)
        tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
        draw.rectangle([x, y, x + tw + pad_x * 2, y + th + pad_y * 2], fill=INDEX_COLOR)
        draw.text((x + pad_x, y + pad_y), index_text, fill=(255, 255, 255), font=font_index)

        # Title below the thumbnail.
        title_text = _truncate_text(item["title"], max_len=16)
        draw.text((x, y + THUMB_HEIGHT + 6), title_text, fill=TEXT_COLOR, font=font_title)

    # Encode the canvas as PNG, then as base64 text.
    buf = io.BytesIO()
    canvas.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")
class PostSearchChannel(str, Enum):
    """Channel types supported by post search.

    Members also subclass ``str``, so each member compares equal to its
    raw string value.
    """

    XHS = "xhs"  # Xiaohongshu
    GZH = "gzh"  # WeChat Official Accounts
    SPH = "sph"  # WeChat Channels (video accounts)
    GITHUB = "github"  # GitHub
    TOUTIAO = "toutiao"  # Toutiao
    DOUYIN = "douyin"  # Douyin
    BILI = "bili"  # Bilibili
    ZHIHU = "zhihu"  # Zhihu
    WEIBO = "weibo"  # Weibo
class SuggestSearchChannel(str, Enum):
    """Channel types supported by search-suggestion lookup.

    Members also subclass ``str``, so each member compares equal to its
    raw string value.
    """

    XHS = "xhs"  # Xiaohongshu
    WX = "wx"  # WeChat
    GITHUB = "github"  # GitHub
    TOUTIAO = "toutiao"  # Toutiao
    DOUYIN = "douyin"  # Douyin
    BILI = "bili"  # Bilibili
    ZHIHU = "zhihu"  # Zhihu
@tool(
    display={
        "zh": {
            "name": "帖子搜索",
            "params": {
                "keyword": "搜索关键词",
                "channel": "搜索渠道",
                "cursor": "分页游标",
                "max_count": "返回条数",
                "content_type": "内容类型-视频/图文"
            }
        },
        "en": {
            "name": "Search Posts",
            "params": {
                "keyword": "Search keyword",
                "channel": "Search channel",
                "cursor": "Pagination cursor",
                "max_count": "Max results",
                "content_type": "content type-视频/图文"
            }
        }
    }
)
async def search_posts(
    keyword: str,
    channel: str = "xhs",
    cursor: str = "0",
    max_count: int = 20,
    content_type: str = ""
) -> ToolResult:
    """
    Search posts (browse mode).

    Searches the given channel for posts matching the keyword and returns a
    summary per post (title, body truncated to 100 characters, counters,
    link) plus one collage image of the cover pictures, for quick browsing.
    The full results replace the module-level cache so that ``select_post``
    can return the complete content of a single post afterwards.

    Args:
        keyword: Search keyword.
        channel: Search channel. Supported channels:
            - xhs: Xiaohongshu
            - gzh: WeChat Official Accounts
            - sph: WeChat Channels
            - github: GitHub
            - toutiao: Toutiao
            - douyin: Douyin
            - bili: Bilibili
            - zhihu: Zhihu
            - weibo: Weibo
        cursor: Pagination cursor, "0" (the default) for the first page.
        max_count: Maximum number of results to return, 20 by default.
        content_type: Content type (video / image-text); the default empty
            string is meant to leave the type unrestricted.

    Returns:
        ToolResult whose output is a JSON document with ``code``,
        ``message`` and ``data`` (the summary list). Pass an entry's
        ``index`` field to ``select_post`` to view the full post. On
        failure a ToolResult with ``error`` set is returned instead of
        raising.
    """
    try:
        # Accept both the enum member and its plain string value.
        channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel

        payload = {
            "type": channel_value,
            "keyword": keyword,
            "cursor": cursor,
            "max_count": max_count,
            "content_type": content_type
        }

        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
            response = await client.post(
                f"{BASE_URL}/data",
                json=payload,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            data = response.json()

        # A response may carry `"data": null`; treat that as "no results"
        # instead of failing the whole search with a TypeError.
        posts = data.get("data") or []

        # Replace the cache with the full results, keyed by 1-based index.
        # The dict is mutated in place, so no `global` statement is needed.
        _search_cache.clear()
        summary_list = []
        for idx, post in enumerate(posts, start=1):
            _search_cache[idx] = post

            body = post.get("body_text", "") or ""
            summary_list.append({
                "index": idx,
                "channel_content_id": post.get("channel_content_id"),
                "title": post.get("title"),
                "body_text": body[:100] + ("..." if len(body) > 100 else ""),
                "like_count": post.get("like_count"),
                "channel": post.get("channel"),
                "link": post.get("link"),
                "content_type": post.get("content_type"),
                "publish_timestamp": post.get("publish_timestamp"),
            })

        # Attach the cover-image collage when one could be built.
        images = []
        collage_b64 = await _build_collage(posts)
        if collage_b64:
            images.append({
                "type": "base64",
                "media_type": "image/png",
                "data": collage_b64
            })

        output_data = {
            "code": data.get("code"),
            "message": data.get("message"),
            "data": summary_list
        }

        return ToolResult(
            title=f"搜索结果: {keyword} ({channel_value})",
            output=json.dumps(output_data, ensure_ascii=False, indent=2),
            long_term_memory=f"Searched '{keyword}' on {channel_value}, found {len(posts)} posts. Use select_post(index) to view full details of a specific post.",
            images=images
        )
    except httpx.HTTPStatusError as e:
        return ToolResult(
            title="搜索失败",
            output="",
            error=f"HTTP error {e.response.status_code}: {e.response.text}"
        )
    except Exception as e:
        return ToolResult(
            title="搜索失败",
            output="",
            error=str(e)
        )
@tool(
    display={
        "zh": {
            "name": "帖子详情",
            "params": {
                "index": "帖子序号"
            }
        },
        "en": {
            "name": "Select Post",
            "params": {
                "index": "Post index"
            }
        }
    }
)
async def select_post(
    index: int,
) -> ToolResult:
    """
    View the details of one post.

    Picks a post by index from the results cached by the most recent
    ``search_posts`` call and returns its complete record (serialized as
    JSON) together with all of its images as URL entries.
    ``search_posts`` must be called first.

    Args:
        index: Post index, taken from the ``index`` field of the
            ``search_posts`` results (starting at 1).

    Returns:
        ToolResult with the full post and all of its images, or a
        ToolResult with ``error`` set when no cached post has that index.
    """
    post = _search_cache.get(index)
    if not post:
        return ToolResult(
            title="未找到帖子",
            output="",
            error=f"未找到序号 {index} 的帖子,请先调用 search_posts 搜索。"
        )

    # `images` may be missing or null in a cached post; both mean "no
    # images" (iterating over None would otherwise raise an uncaught
    # TypeError). Empty URL entries are skipped.
    images = [
        {"type": "url", "url": img_url}
        for img_url in (post.get("images") or [])
        if img_url
    ]

    return ToolResult(
        title=f"帖子详情 #{index}: {post.get('title', '')}",
        output=json.dumps(post, ensure_ascii=False, indent=2),
        long_term_memory=f"Viewed post detail #{index}: {post.get('title', '')}",
        images=images
    )
@tool(
    display={
        "zh": {
            "name": "获取搜索关键词补全建议",
            "params": {
                "keyword": "搜索关键词",
                "channel": "搜索渠道"
            }
        },
        "en": {
            "name": "Get Search Suggestions",
            "params": {
                "keyword": "Search keyword",
                "channel": "Search channel"
            }
        }
    }
)
async def get_search_suggestions(
    keyword: str,
    channel: str = "xhs",
) -> ToolResult:
    """
    Get search keyword completion suggestions.

    Fetches the suggested search terms for a keyword on the given channel.

    Args:
        keyword: Search keyword.
        channel: Search channel. Supported channels:
            - xhs: Xiaohongshu
            - wx: WeChat
            - github: GitHub
            - toutiao: Toutiao
            - douyin: Douyin
            - bili: Bilibili
            - zhihu: Zhihu

    Returns:
        ToolResult whose output is the API response serialized as JSON:
        {
            "code": 0,                    # status code, 0 means success
            "message": "success",         # status message
            "data": [                     # suggestion data
                {
                    "type": "xhs",        # channel type
                    "list": [             # list of suggested terms
                        {
                            "name": "rainbow hair dye"  # a suggested term
                        }
                    ]
                }
            ]
        }
        On failure a ToolResult with ``error`` set is returned instead of
        raising.
    """
    try:
        # Accept both the enum member and its plain string value.
        channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel

        payload = {
            "type": channel_value,
            "keyword": keyword,
        }

        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
            response = await client.post(
                f"{BASE_URL}/suggest",
                json=payload,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            data = response.json()

        # Count the suggested terms. `data` / `list` may be missing or
        # null; both count as empty rather than turning a successful
        # response into an error.
        suggestion_count = sum(
            len(item.get("list") or []) for item in (data.get("data") or [])
        )

        return ToolResult(
            title=f"建议词: {keyword} ({channel_value})",
            output=json.dumps(data, ensure_ascii=False, indent=2),
            long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel_value}"
        )
    except httpx.HTTPStatusError as e:
        return ToolResult(
            title="获取建议词失败",
            output="",
            error=f"HTTP error {e.response.status_code}: {e.response.text}"
        )
    except Exception as e:
        return ToolResult(
            title="获取建议词失败",
            output="",
            error=str(e)
        )
|