| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265 |
- """
- 搜索工具模块
- 提供帖子搜索和建议词搜索功能,支持多个渠道平台。
- 主要功能:
- 1. search_posts - 帖子搜索
- 2. get_search_suggestions - 获取平台的搜索补全建议词
- """
- import json
- from enum import Enum
- from typing import Any, Dict
- import httpx
- from agent import tool
- from agent.tools.models import ToolResult
# --- Service configuration -------------------------------------------------
# Root endpoint of the channel service. The tools below append a path such
# as "/data" (post search) or "/suggest" (suggestion words) to it.
# NOTE(review): plain HTTP — confirm whether the service offers HTTPS.
BASE_URL: str = "http://aigc-channel.aiddit.com/aigc/channel"

# Timeout in seconds handed to every httpx.AsyncClient created by the tools.
DEFAULT_TIMEOUT: float = 60.0
class PostSearchChannel(str, Enum):
    """Channel codes understood by the post-search endpoint.

    Because the enum mixes in ``str``, every member compares equal to its
    raw channel code (e.g. ``PostSearchChannel.XHS == "xhs"``).
    """

    # Xiaohongshu
    XHS = "xhs"
    # WeChat Official Accounts
    GZH = "gzh"
    # WeChat Channels
    SPH = "sph"
    # GitHub
    GITHUB = "github"
    # Toutiao
    TOUTIAO = "toutiao"
    # Douyin
    DOUYIN = "douyin"
    # Bilibili
    BILI = "bili"
    # Zhihu
    ZHIHU = "zhihu"
    # Weibo
    WEIBO = "weibo"
class SuggestSearchChannel(str, Enum):
    """Channel codes understood by the search-suggestion endpoint.

    As a ``str`` mix-in enum, each member compares equal to its raw
    channel code (e.g. ``SuggestSearchChannel.WX == "wx"``).
    """

    # Xiaohongshu
    XHS = "xhs"
    # WeChat
    WX = "wx"
    # GitHub
    GITHUB = "github"
    # Toutiao
    TOUTIAO = "toutiao"
    # Douyin
    DOUYIN = "douyin"
    # Bilibili
    BILI = "bili"
    # Zhihu
    ZHIHU = "zhihu"
@tool(
    display={
        "zh": {
            "name": "帖子搜索",
            "params": {
                "keyword": "搜索关键词",
                "channel": "搜索渠道",
                "cursor": "分页游标",
                "max_count": "返回条数"
            }
        },
        "en": {
            "name": "Search Posts",
            "params": {
                "keyword": "Search keyword",
                "channel": "Search channel",
                "cursor": "Pagination cursor",
                "max_count": "Max results"
            }
        }
    }
)
async def search_posts(
    keyword: str,
    channel: str = "xhs",
    cursor: str = "0",
    max_count: int = 5,
    uid: str = "",
) -> ToolResult:
    """
    Search posts.

    Searches the given channel platform for posts matching a keyword.

    Args:
        keyword: Search keyword.
        channel: Channel to search. Supported channels:
            - xhs: Xiaohongshu
            - gzh: WeChat Official Accounts
            - sph: WeChat Channels
            - github: GitHub
            - toutiao: Toutiao
            - douyin: Douyin
            - bili: Bilibili
            - zhihu: Zhihu
            - weibo: Weibo
            A PostSearchChannel member is accepted as well.
        cursor: Pagination cursor; defaults to "0" (the first page).
        max_count: Maximum number of results to return; defaults to 5.
        uid: User ID (injected automatically). It is not sent to the service.

    Returns:
        ToolResult whose output is the service's JSON response, pretty-printed:
        {
            "code": 0,                 # status code, 0 means success
            "message": "success",      # status message
            "data": [                  # list of posts
                {
                    "channel_content_id": "68dd03db000000000303beb2",  # unique content ID
                    "title": "",                         # title
                    "content_type": "note",              # content type
                    "body_text": "",                     # body text
                    "like_count": 127,                   # number of likes
                    "publish_timestamp": 1759314907000,  # publish time (ms)
                    "images": ["https://xxx.webp"],      # image URLs
                    "videos": [],                        # video list
                    "channel": "xhs",                    # source channel
                    "link": "xxx"                        # link to the original post
                }
            ]
        }
        On failure, output is empty and the reason is in ``error``.
    """
    try:
        # Accept either an enum member or a raw channel code string.
        channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel

        url = f"{BASE_URL}/data"
        payload = {
            "type": channel_value,
            "keyword": keyword,
            "cursor": cursor,
            "max_count": max_count,
        }

        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
            response = await client.post(
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            data = response.json()

        # Count defensively. An unexpected payload shape must not make a
        # successful HTTP search be reported as a failure. Examples are
        # `"data": null` or a non-object JSON body; the old
        # `len(data.get("data", []))` raised on both.
        posts = data.get("data") if isinstance(data, dict) else None
        result_count = len(posts) if isinstance(posts, list) else 0

        return ToolResult(
            title=f"搜索结果: {keyword} ({channel_value})",
            output=json.dumps(data, ensure_ascii=False, indent=2),
            long_term_memory=f"Searched '{keyword}' on {channel_value}, found {result_count} posts"
        )
    except httpx.HTTPStatusError as e:
        return ToolResult(
            title="搜索失败",
            output="",
            error=f"HTTP error {e.response.status_code}: {e.response.text}"
        )
    except Exception as e:
        return ToolResult(
            title="搜索失败",
            output="",
            # Some exceptions stringify to ""; fall back to the exception
            # type name so the reported error is never blank.
            error=str(e) or type(e).__name__
        )
@tool(
    display={
        "zh": {
            "name": "获取搜索关键词补全建议",
            "params": {
                "keyword": "搜索关键词",
                "channel": "搜索渠道"
            }
        },
        "en": {
            "name": "Get Search Suggestions",
            "params": {
                "keyword": "Search keyword",
                "channel": "Search channel"
            }
        }
    }
)
async def get_search_suggestions(
    keyword: str,
    channel: str = "xhs",
    uid: str = "",
) -> ToolResult:
    """
    Get search keyword completion suggestions.

    Fetches the platform's search-suggestion words for a keyword on the
    given channel.

    Args:
        keyword: Search keyword.
        channel: Channel to query. Supported channels:
            - xhs: Xiaohongshu
            - wx: WeChat
            - github: GitHub
            - toutiao: Toutiao
            - douyin: Douyin
            - bili: Bilibili
            - zhihu: Zhihu
            A SuggestSearchChannel member is accepted as well.
        uid: User ID (injected automatically). It is not sent to the service.

    Returns:
        ToolResult whose output is the service's JSON response, pretty-printed:
        {
            "code": 0,                 # status code, 0 means success
            "message": "success",      # status message
            "data": [                  # suggestion data
                {
                    "type": "xhs",     # channel type
                    "list": [          # suggestion words
                        {
                            "name": "彩虹染发"  # a suggested word
                        }
                    ]
                }
            ]
        }
        On failure, output is empty and the reason is in ``error``.
    """
    try:
        # Accept either an enum member or a raw channel code string.
        channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel

        url = f"{BASE_URL}/suggest"
        payload = {
            "type": channel_value,
            "keyword": keyword,
        }

        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
            response = await client.post(
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            data = response.json()

        # Count suggestions across all groups defensively. Null or
        # non-list "data"/"list" fields, or non-object groups, are skipped
        # rather than raising. Raising would report a successful request
        # as a failure.
        groups = data.get("data") if isinstance(data, dict) else None
        suggestion_count = 0
        if isinstance(groups, list):
            for group in groups:
                suggestions = group.get("list") if isinstance(group, dict) else None
                if isinstance(suggestions, list):
                    suggestion_count += len(suggestions)

        return ToolResult(
            title=f"建议词: {keyword} ({channel_value})",
            output=json.dumps(data, ensure_ascii=False, indent=2),
            long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel_value}"
        )
    except httpx.HTTPStatusError as e:
        return ToolResult(
            title="获取建议词失败",
            output="",
            error=f"HTTP error {e.response.status_code}: {e.response.text}"
        )
    except Exception as e:
        return ToolResult(
            title="获取建议词失败",
            output="",
            # Some exceptions stringify to ""; fall back to the exception
            # type name so the reported error is never blank.
            error=str(e) or type(e).__name__
        )
|