tanjingyu 1 месяц назад
Родитель
Сommit
8de7ddc82e
3 измененных файлов с 242 добавлено и 0 удалено
  1. 3 0
      agent/tools/__init__.py
  2. 12 0
      tools/__init__.py
  3. 227 0
      tools/search.py

+ 3 - 0
agent/tools/__init__.py

@@ -6,6 +6,9 @@ from agent.tools.registry import ToolRegistry, tool, get_tool_registry
 from agent.tools.schema import SchemaGenerator
 from agent.tools.models import ToolResult, ToolContext, ToolContextImpl
 
+# 导入工具模块,触发 @tool 装饰器执行,完成工具注册
+import tools  # noqa: F401
+
 __all__ = [
 	"ToolRegistry",
 	"tool",

+ 12 - 0
tools/__init__.py

@@ -0,0 +1,12 @@
+"""
+工具模块
+
+导入此模块会自动注册所有工具到全局 ToolRegistry。
+"""
+
+from tools.search import search_posts, search_suggestions
+
+__all__ = [
+    "search_posts",
+    "search_suggestions",
+]

+ 227 - 0
tools/search.py

@@ -0,0 +1,227 @@
+"""
+搜索工具模块
+
+提供帖子搜索和建议词搜索功能,支持多个渠道平台。
+
+主要功能:
+1. search_posts - 帖子搜索
+2. search_suggestions - 建议词搜索
+"""
+
+from enum import Enum
+from typing import Any, Dict, Optional
+
+import httpx
+
+from agent import tool
+
+
+# API 基础配置
+BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
+DEFAULT_TIMEOUT = 60.0
+
+
+class PostSearchChannel(str, Enum):
+    """
+    帖子搜索支持的渠道类型
+    """
+    XHS = "xhs"           # 小红书
+    GZH = "gzh"           # 公众号
+    SPH = "sph"           # 视频号
+    GITHUB = "github"     # GitHub
+    TOUTIAO = "toutiao"   # 头条
+    DOUYIN = "douyin"     # 抖音
+    BILI = "bili"         # B站
+    ZHIHU = "zhihu"       # 知乎
+    WEIBO = "weibo"       # 微博
+
+
+class SuggestSearchChannel(str, Enum):
+    """
+    建议词搜索支持的渠道类型
+    """
+    XHS = "xhs"           # 小红书
+    WX = "wx"             # 微信
+    GITHUB = "github"     # GitHub
+    TOUTIAO = "toutiao"   # 头条
+    DOUYIN = "douyin"     # 抖音
+    BILI = "bili"         # B站
+    ZHIHU = "zhihu"       # 知乎
+
+
+@tool(
+    display={
+        "zh": {
+            "name": "帖子搜索",
+            "params": {
+                "keyword": "搜索关键词",
+                "channel": "搜索渠道",
+                "cursor": "分页游标",
+                "max_count": "返回条数"
+            }
+        },
+        "en": {
+            "name": "Search Posts",
+            "params": {
+                "keyword": "Search keyword",
+                "channel": "Search channel",
+                "cursor": "Pagination cursor",
+                "max_count": "Max results"
+            }
+        }
+    }
+)
+async def search_posts(
+    keyword: str,
+    channel: str = "xhs",
+    cursor: str = "0",
+    max_count: int = 5,
+    uid: str = "",
+) -> Dict[str, Any]:
+    """
+    帖子搜索
+
+    根据关键词在指定渠道平台搜索帖子内容。
+
+    Args:
+        keyword: 搜索关键词
+        channel: 搜索渠道,支持的渠道有:
+            - xhs: 小红书
+            - gzh: 公众号
+            - sph: 视频号
+            - github: GitHub
+            - toutiao: 头条
+            - douyin: 抖音
+            - bili: B站
+            - zhihu: 知乎
+            - weibo: 微博
+        cursor: 分页游标,默认为 "0"(第一页)
+        max_count: 返回的最大条数,默认为 5
+        uid: 用户ID(自动注入)
+
+    Returns:
+        API 返回的原始响应,结构如下:
+        {
+            "code": 0,                    # 状态码,0 表示成功
+            "message": "success",         # 状态消息
+            "data": [                     # 帖子列表
+                {
+                    "channel_content_id": "68dd03db000000000303beb2",  # 内容唯一ID
+                    "title": "",                                       # 标题
+                    "content_type": "note",                            # 内容类型
+                    "body_text": "",                                   # 正文内容
+                    "like_count": 127,                                 # 点赞数
+                    "publish_timestamp": 1759314907000,                # 发布时间戳(毫秒)
+                    "images": ["https://xxx.webp"],                    # 图片列表
+                    "videos": [],                                      # 视频列表
+                    "channel": "xhs",                                  # 来源渠道
+                    "link": "xxx"                                      # 原文链接
+                }
+            ]
+        }
+
+    Raises:
+        httpx.HTTPStatusError: HTTP 请求返回非 2xx 状态码
+        httpx.TimeoutException: 请求超时
+        httpx.RequestError: 其他请求错误
+    """
+    # 处理 channel 参数,支持枚举和字符串
+    channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel
+
+    url = f"{BASE_URL}/data"
+    payload = {
+        "type": channel_value,
+        "keyword": keyword,
+        "cursor": cursor,
+        "max_count": max_count,
+    }
+
+    async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
+        response = await client.post(
+            url,
+            json=payload,
+            headers={"Content-Type": "application/json"},
+        )
+        response.raise_for_status()
+        return response.json()
+
+
+@tool(
+    display={
+        "zh": {
+            "name": "建议词搜索",
+            "params": {
+                "keyword": "搜索关键词",
+                "channel": "搜索渠道"
+            }
+        },
+        "en": {
+            "name": "Search Suggestions",
+            "params": {
+                "keyword": "Search keyword",
+                "channel": "Search channel"
+            }
+        }
+    }
+)
+async def search_suggestions(
+    keyword: str,
+    channel: str = "xhs",
+    uid: str = "",
+) -> Dict[str, Any]:
+    """
+    建议词搜索
+
+    根据关键词在指定渠道平台获取搜索建议词。
+
+    Args:
+        keyword: 搜索关键词
+        channel: 搜索渠道,支持的渠道有:
+            - xhs: 小红书
+            - wx: 微信
+            - github: GitHub
+            - toutiao: 头条
+            - douyin: 抖音
+            - bili: B站
+            - zhihu: 知乎
+        uid: 用户ID(自动注入)
+
+    Returns:
+        API 返回的原始响应,结构如下:
+        {
+            "code": 0,                    # 状态码,0 表示成功
+            "message": "success",         # 状态消息
+            "data": [                     # 建议词数据
+                {
+                    "type": "xhs",        # 渠道类型
+                    "list": [             # 建议词列表
+                        {
+                            "name": "彩虹染发"  # 建议词
+                        }
+                    ]
+                }
+            ]
+        }
+
+    Raises:
+        httpx.HTTPStatusError: HTTP 请求返回非 2xx 状态码
+        httpx.TimeoutException: 请求超时
+        httpx.RequestError: 其他请求错误
+    """
+    # 处理 channel 参数,支持枚举和字符串
+    channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel
+
+    url = f"{BASE_URL}/suggest"
+    payload = {
+        "type": channel_value,
+        "keyword": keyword,
+    }
+
+    async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
+        response = await client.post(
+            url,
+            json=payload,
+            headers={"Content-Type": "application/json"},
+        )
+        response.raise_for_status()
+        return response.json()