|
|
@@ -1,18 +1,24 @@
|
|
|
"""
|
|
|
搜索工具模块
|
|
|
|
|
|
-提供帖子搜索和建议词搜索功能,支持多个渠道平台。
|
|
|
+提供帖子搜索、帖子详情查看和建议词搜索功能,支持多个渠道平台。
|
|
|
|
|
|
主要功能:
|
|
|
-1. search_posts - 帖子搜索
|
|
|
-2. get_search_suggestions - 获取平台的搜索补全建议词
|
|
|
+1. search_posts - 帖子搜索(浏览模式:封面图+标题+内容截断)
|
|
|
+2. select_post - 帖子详情(从搜索结果中选取单个帖子的完整内容)
|
|
|
+3. get_search_suggestions - 获取平台的搜索补全建议词
|
|
|
"""
|
|
|
|
|
|
+import asyncio
|
|
|
+import base64
|
|
|
+import io
|
|
|
import json
|
|
|
+import math
|
|
|
from enum import Enum
|
|
|
-from typing import Any, Dict
|
|
|
+from typing import Any, Dict, List, Optional
|
|
|
|
|
|
import httpx
|
|
|
+from PIL import Image, ImageDraw, ImageFont
|
|
|
|
|
|
from agent.tools import tool, ToolResult
|
|
|
|
|
|
@@ -21,6 +27,116 @@ from agent.tools import tool, ToolResult
|
|
|
BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
|
|
|
DEFAULT_TIMEOUT = 60.0
|
|
|
|
|
|
# Cache of the most recent search results, keyed by 1-based result index.
# search_posts clears and refills it on every call; select_post reads from it.
_search_cache: Dict[int, Dict[str, Any]] = {}

# Collage layout configuration: sizes are in pixels, colours are RGB tuples.
THUMB_WIDTH = 250            # width of each cover thumbnail
THUMB_HEIGHT = 250           # height of each cover thumbnail
TEXT_HEIGHT = 80             # vertical space reserved under a thumbnail for its title
GRID_COLS = 5                # maximum number of thumbnails per row
PADDING = 12                 # gap between cells and around the canvas edge
BG_COLOR = (255, 255, 255)   # canvas background
TEXT_COLOR = (30, 30, 30)    # title text
INDEX_COLOR = (220, 60, 60)  # background of the index badge
|
|
|
+
|
|
|
+
|
|
|
def _truncate_text(text: str, max_len: int = 14) -> str:
    """Shorten ``text`` to ``max_len`` characters, appending "..." when it is cut.

    Text whose length does not exceed ``max_len`` is returned unchanged.
    """
    if len(text) <= max_len:
        return text
    clipped = text[:max_len]
    return clipped + "..."
|
|
|
+
|
|
|
+
|
|
|
async def _download_image(client: httpx.AsyncClient, url: str) -> Optional[Image.Image]:
    """Fetch a single image over HTTP and decode it into an RGB image.

    Args:
        client: Async HTTP client used to issue the GET request.
        url: Address of the image to fetch.

    Returns:
        The decoded image converted to RGB mode, or None when the request,
        the HTTP status check, or the decoding raised any exception.
    """
    try:
        response = await client.get(url, timeout=15.0)
        response.raise_for_status()
        raw_bytes = io.BytesIO(response.content)
        picture = Image.open(raw_bytes)
        return picture.convert("RGB")
    except Exception:
        # Best-effort by design: one broken cover must not abort the caller,
        # which filters out None results.
        return None
|
|
|
+
|
|
|
+
|
|
|
async def _build_collage(posts: List[Dict[str, Any]]) -> Optional[str]:
    """Render post covers into a labelled grid and return it as a base64 PNG.

    Each grid cell shows the cover thumbnail, a badge with the post's
    1-based position in ``posts`` (the same number search_posts exposes as
    ``index``), and the truncated title below the thumbnail.

    Posts without a cover URL, and covers that fail to download, are left
    out. The remaining cells are packed contiguously so that every
    thumbnail lands inside the canvas; the badge still carries the post's
    original position.

    Args:
        posts: Post dicts; the first entry of ``"images"`` is used as the
            cover and ``"title"`` as the caption.

    Returns:
        The base64-encoded PNG bytes as a UTF-8 string, or None when
        ``posts`` is empty, no post has a cover URL, or every download
        failed.
    """
    if not posts:
        return None

    # Remember each post's original 1-based position for the badge label.
    items = []
    for position, post in enumerate(posts, start=1):
        covers = post.get("images") or []
        cover_url = covers[0] if covers else None
        if cover_url:
            items.append({
                "url": cover_url,
                "title": post.get("title", "") or "",
                "index": position,
            })
    if not items:
        return None

    # Download all covers concurrently over one shared client.
    async with httpx.AsyncClient() as client:
        downloaded = await asyncio.gather(
            *(_download_image(client, item["url"]) for item in items)
        )

    # Drop the covers whose download failed (_download_image returned None).
    valid = [(item, img) for item, img in zip(items, downloaded) if img is not None]
    if not valid:
        return None

    cols = min(GRID_COLS, len(valid))
    rows = math.ceil(len(valid) / cols)
    cell_w = THUMB_WIDTH + PADDING
    cell_h = THUMB_HEIGHT + TEXT_HEIGHT + PADDING
    canvas_w = cols * cell_w + PADDING
    canvas_h = rows * cell_h + PADDING

    canvas = Image.new("RGB", (canvas_w, canvas_h), BG_COLOR)
    draw = ImageDraw.Draw(canvas)

    # Try the preferred font files in order; fall back to Pillow's built-in
    # font when none of them can be loaded.
    font_title = font_index = None
    for font_name in ("msyh.ttc", "arial.ttf"):
        try:
            font_title = ImageFont.truetype(font_name, 16)
            font_index = ImageFont.truetype(font_name, 22)
            break
        except Exception:
            # Best-effort: a missing or unreadable font just moves on to the
            # next candidate.
            continue
    else:
        font_title = font_index = ImageFont.load_default()

    for slot, (item, img) in enumerate(valid):
        # Place by position among the *valid* covers. The canvas is sized
        # from len(valid), so deriving row/col from the original post index
        # would push thumbnails off-canvas whenever an earlier post was
        # skipped.
        row, col = divmod(slot, cols)
        x = PADDING + col * cell_w
        y = PADDING + row * cell_h

        # Scale the cover to the fixed thumbnail size (aspect ratio is not
        # preserved).
        thumb = img.resize((THUMB_WIDTH, THUMB_HEIGHT), Image.LANCZOS)
        canvas.paste(thumb, (x, y))

        # Badge in the top-left corner showing the original post index.
        index_text = f" {item['index']} "
        bbox = draw.textbbox((0, 0), index_text, font=font_index)
        tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
        draw.rectangle([x, y, x + tw + 4, y + th + 4], fill=INDEX_COLOR)
        draw.text((x + 2, y + 2), index_text, fill=(255, 255, 255), font=font_index)

        # Caption below the thumbnail.
        title_text = _truncate_text(item["title"], max_len=16)
        try:
            draw.text((x, y + THUMB_HEIGHT + 6), title_text, fill=TEXT_COLOR, font=font_title)
        except UnicodeEncodeError:
            # The built-in fallback font may be unable to encode non-Latin
            # titles; skip the caption rather than lose the whole collage.
            pass

    # Encode the finished canvas as PNG, then base64.
    buf = io.BytesIO()
    canvas.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")
|
|
|
+
|
|
|
|
|
|
class PostSearchChannel(str, Enum):
|
|
|
"""
|
|
|
@@ -59,7 +175,6 @@ class SuggestSearchChannel(str, Enum):
|
|
|
"channel": "搜索渠道",
|
|
|
"cursor": "分页游标",
|
|
|
"max_count": "返回条数",
|
|
|
- "include_images": "是否包含图片",
|
|
|
"content_type": "内容类型-视频/图文"
|
|
|
}
|
|
|
},
|
|
|
@@ -70,7 +185,6 @@ class SuggestSearchChannel(str, Enum):
|
|
|
"channel": "Search channel",
|
|
|
"cursor": "Pagination cursor",
|
|
|
"max_count": "Max results",
|
|
|
- "include_images": "Include images",
|
|
|
"content_type": "content type-视频/图文"
|
|
|
}
|
|
|
}
|
|
|
@@ -80,14 +194,14 @@ async def search_posts(
|
|
|
keyword: str,
|
|
|
channel: str = "xhs",
|
|
|
cursor: str = "0",
|
|
|
- max_count: int = 5,
|
|
|
- include_images: bool = False,
|
|
|
+ max_count: int = 20,
|
|
|
content_type: str = ""
|
|
|
) -> ToolResult:
|
|
|
"""
|
|
|
- 帖子搜索
|
|
|
+ 帖子搜索(浏览模式)
|
|
|
|
|
|
- 根据关键词在指定渠道平台搜索帖子内容。
|
|
|
+ 根据关键词在指定渠道平台搜索帖子,返回封面图+标题+内容摘要,用于快速浏览。
|
|
|
+ 如需查看某个帖子的完整内容,请使用 select_post 工具。
|
|
|
|
|
|
Args:
|
|
|
keyword: 搜索关键词
|
|
|
@@ -102,33 +216,15 @@ async def search_posts(
|
|
|
- zhihu: 知乎
|
|
|
- weibo: 微博
|
|
|
cursor: 分页游标,默认为 "0"(第一页)
|
|
|
- max_count: 返回的最大条数,默认为 5
|
|
|
- include_images: 是否将帖子中的图片传给 LLM 查看,默认为 False
|
|
|
+ max_count: 返回的最大条数,默认为 20
|
|
|
content_type:内容类型-视频/图文,默认不传为不限制类型
|
|
|
|
|
|
Returns:
|
|
|
- ToolResult 包含搜索结果:
|
|
|
- {
|
|
|
- "code": 0, # 状态码,0 表示成功
|
|
|
- "message": "success", # 状态消息
|
|
|
- "data": [ # 帖子列表
|
|
|
- {
|
|
|
- "channel_content_id": "68dd03db000000000303beb2", # 内容唯一ID
|
|
|
- "title": "", # 标题
|
|
|
- "content_type": "note", # 内容类型
|
|
|
- "body_text": "", # 正文内容
|
|
|
- "like_count": 127, # 点赞数
|
|
|
- "publish_timestamp": 1759314907000, # 发布时间戳(毫秒)
|
|
|
- "images": ["https://xxx.webp"], # 图片列表
|
|
|
- "videos": [], # 视频列表
|
|
|
- "channel": "xhs", # 来源渠道
|
|
|
- "link": "xxx" # 原文链接
|
|
|
- }
|
|
|
- ]
|
|
|
- }
|
|
|
+ ToolResult 包含搜索结果摘要列表(封面图+标题+内容截断),
|
|
|
+ 可通过返回结果中的 index(序号)调用 select_post 查看完整内容。
|
|
|
"""
|
|
|
+ global _search_cache
|
|
|
try:
|
|
|
- # 处理 channel 参数,支持枚举和字符串
|
|
|
channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel
|
|
|
|
|
|
url = f"{BASE_URL}/data"
|
|
|
@@ -149,27 +245,49 @@ async def search_posts(
|
|
|
response.raise_for_status()
|
|
|
data = response.json()
|
|
|
|
|
|
- # 计算结果数量
|
|
|
- result_count = len(data.get("data", []))
|
|
|
+ posts = data.get("data", [])
|
|
|
+
|
|
|
+ # 缓存完整结果(以序号为 key)
|
|
|
+ _search_cache.clear()
|
|
|
+ for idx, post in enumerate(posts):
|
|
|
+ _search_cache[idx + 1] = post
|
|
|
|
|
|
- # 提取图片 URL 并构建 images 列表供 LLM 查看(仅当 include_images=True 时)
|
|
|
+ # 构建摘要列表(带序号)
|
|
|
+ summary_list = []
|
|
|
+ for idx, post in enumerate(posts):
|
|
|
+ body = post.get("body_text", "") or ""
|
|
|
+ summary_list.append({
|
|
|
+ "index": idx + 1,
|
|
|
+ "channel_content_id": post.get("channel_content_id"),
|
|
|
+ "title": post.get("title"),
|
|
|
+ "body_text": body[:100] + ("..." if len(body) > 100 else ""),
|
|
|
+ "like_count": post.get("like_count"),
|
|
|
+ "channel": post.get("channel"),
|
|
|
+ "link": post.get("link"),
|
|
|
+ "content_type": post.get("content_type"),
|
|
|
+ "publish_timestamp": post.get("publish_timestamp"),
|
|
|
+ })
|
|
|
+
|
|
|
+ # 拼接封面图网格
|
|
|
images = []
|
|
|
- if include_images:
|
|
|
- for post in data.get("data", []):
|
|
|
- for img_url in post.get("images", [])[:3]: # 每个帖子最多取前3张图
|
|
|
- if img_url:
|
|
|
- images.append({
|
|
|
- "type": "url",
|
|
|
- "url": img_url
|
|
|
- })
|
|
|
- # 限制总图片数量,避免过多
|
|
|
- if len(images) >= 10:
|
|
|
- break
|
|
|
+ collage_b64 = await _build_collage(posts)
|
|
|
+ if collage_b64:
|
|
|
+ images.append({
|
|
|
+ "type": "base64",
|
|
|
+ "media_type": "image/png",
|
|
|
+ "data": collage_b64
|
|
|
+ })
|
|
|
+
|
|
|
+ output_data = {
|
|
|
+ "code": data.get("code"),
|
|
|
+ "message": data.get("message"),
|
|
|
+ "data": summary_list
|
|
|
+ }
|
|
|
|
|
|
return ToolResult(
|
|
|
title=f"搜索结果: {keyword} ({channel_value})",
|
|
|
- output=json.dumps(data, ensure_ascii=False, indent=2),
|
|
|
- long_term_memory=f"Searched '{keyword}' on {channel_value}, found {result_count} posts",
|
|
|
+ output=json.dumps(output_data, ensure_ascii=False, indent=2),
|
|
|
+ long_term_memory=f"Searched '{keyword}' on {channel_value}, found {len(posts)} posts. Use select_post(index) to view full details of a specific post.",
|
|
|
images=images
|
|
|
)
|
|
|
except httpx.HTTPStatusError as e:
|
|
|
@@ -186,6 +304,62 @@ async def search_posts(
|
|
|
)
|
|
|
|
|
|
|
|
|
@tool(
    display={
        "zh": {
            "name": "帖子详情",
            "params": {
                "index": "帖子序号"
            }
        },
        "en": {
            "name": "Select Post",
            "params": {
                "index": "Post index"
            }
        }
    }
)
async def select_post(
    index: int,
) -> ToolResult:
    """
    View post details

    Pick one post, by its index, from the results of the most recent
    search_posts call and return its complete content (the full cached post
    record plus all of its image URLs).
    search_posts must be called first.

    Args:
        index: Post index, taken from the "index" field of the search_posts result (starts at 1)

    Returns:
        ToolResult with the full post record as JSON and every image of the
        post; a ToolResult with ``error`` set when no post is cached under
        ``index``.
    """
    post = _search_cache.get(index)
    if not post:
        return ToolResult(
            title="未找到帖子",
            output="",
            error=f"未找到序号 {index} 的帖子,请先调用 search_posts 搜索。"
        )

    # The API may return "images": null or a null title; normalise both so
    # iteration does not raise and the title does not render as "None".
    image_urls = post.get("images") or []
    title = post.get("title") or ""

    # Expose every non-empty image URL of the post.
    images = [{"type": "url", "url": img_url} for img_url in image_urls if img_url]

    return ToolResult(
        title=f"帖子详情 #{index}: {title}",
        output=json.dumps(post, ensure_ascii=False, indent=2),
        long_term_memory=f"Viewed post detail #{index}: {title}",
        images=images
    )
|
|
|
+
|
|
|
+
|
|
|
@tool(
|
|
|
display={
|
|
|
"zh": {
|