tanjingyu 1 месяц назад
Родитель
Сommit
159588c99d
2 измененных файлов с 227 добавлено и 227 удалено
  1. 227 0
      agent/tools/builtin/search.py
  2. 0 227
      tools/search.py

+ 227 - 0
agent/tools/builtin/search.py

@@ -0,0 +1,227 @@
+"""
+搜索工具模块
+
+提供帖子搜索和建议词搜索功能,支持多个渠道平台。
+
+主要功能:
+1. search_posts - 帖子搜索
+2. search_suggestions - 建议词搜索
+"""
+
+from enum import Enum
+from typing import Any, Dict, Optional
+
+import httpx
+
+from agent import tool
+
+
+# API 基础配置
+BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
+DEFAULT_TIMEOUT = 60.0
+
+
+class PostSearchChannel(str, Enum):
+    """
+    帖子搜索支持的渠道类型
+    """
+    XHS = "xhs"           # 小红书
+    GZH = "gzh"           # 公众号
+    SPH = "sph"           # 视频号
+    GITHUB = "github"     # GitHub
+    TOUTIAO = "toutiao"   # 头条
+    DOUYIN = "douyin"     # 抖音
+    BILI = "bili"         # B站
+    ZHIHU = "zhihu"       # 知乎
+    WEIBO = "weibo"       # 微博
+
+
+class SuggestSearchChannel(str, Enum):
+    """
+    建议词搜索支持的渠道类型
+    """
+    XHS = "xhs"           # 小红书
+    WX = "wx"             # 微信
+    GITHUB = "github"     # GitHub
+    TOUTIAO = "toutiao"   # 头条
+    DOUYIN = "douyin"     # 抖音
+    BILI = "bili"         # B站
+    ZHIHU = "zhihu"       # 知乎
+
+
+@tool(
+    display={
+        "zh": {
+            "name": "帖子搜索",
+            "params": {
+                "keyword": "搜索关键词",
+                "channel": "搜索渠道",
+                "cursor": "分页游标",
+                "max_count": "返回条数"
+            }
+        },
+        "en": {
+            "name": "Search Posts",
+            "params": {
+                "keyword": "Search keyword",
+                "channel": "Search channel",
+                "cursor": "Pagination cursor",
+                "max_count": "Max results"
+            }
+        }
+    }
+)
+async def search_posts(
+    keyword: str,
+    channel: str = "xhs",
+    cursor: str = "0",
+    max_count: int = 5,
+    uid: str = "",
+) -> Dict[str, Any]:
+    """
+    帖子搜索
+
+    根据关键词在指定渠道平台搜索帖子内容。
+
+    Args:
+        keyword: 搜索关键词
+        channel: 搜索渠道,支持的渠道有:
+            - xhs: 小红书
+            - gzh: 公众号
+            - sph: 视频号
+            - github: GitHub
+            - toutiao: 头条
+            - douyin: 抖音
+            - bili: B站
+            - zhihu: 知乎
+            - weibo: 微博
+        cursor: 分页游标,默认为 "0"(第一页)
+        max_count: 返回的最大条数,默认为 5
+        uid: 用户ID(自动注入)
+
+    Returns:
+        API 返回的原始响应,结构如下:
+        {
+            "code": 0,                    # 状态码,0 表示成功
+            "message": "success",         # 状态消息
+            "data": [                     # 帖子列表
+                {
+                    "channel_content_id": "68dd03db000000000303beb2",  # 内容唯一ID
+                    "title": "",                                       # 标题
+                    "content_type": "note",                            # 内容类型
+                    "body_text": "",                                   # 正文内容
+                    "like_count": 127,                                 # 点赞数
+                    "publish_timestamp": 1759314907000,                # 发布时间戳(毫秒)
+                    "images": ["https://xxx.webp"],                    # 图片列表
+                    "videos": [],                                      # 视频列表
+                    "channel": "xhs",                                  # 来源渠道
+                    "link": "xxx"                                      # 原文链接
+                }
+            ]
+        }
+
+    Raises:
+        httpx.HTTPStatusError: HTTP 请求返回非 2xx 状态码
+        httpx.TimeoutException: 请求超时
+        httpx.RequestError: 其他请求错误
+    """
+    # 处理 channel 参数,支持枚举和字符串
+    channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel
+
+    url = f"{BASE_URL}/data"
+    payload = {
+        "type": channel_value,
+        "keyword": keyword,
+        "cursor": cursor,
+        "max_count": max_count,
+    }
+
+    async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
+        response = await client.post(
+            url,
+            json=payload,
+            headers={"Content-Type": "application/json"},
+        )
+        response.raise_for_status()
+        return response.json()
+
+
+@tool(
+    display={
+        "zh": {
+            "name": "建议词搜索",
+            "params": {
+                "keyword": "搜索关键词",
+                "channel": "搜索渠道"
+            }
+        },
+        "en": {
+            "name": "Search Suggestions",
+            "params": {
+                "keyword": "Search keyword",
+                "channel": "Search channel"
+            }
+        }
+    }
+)
+async def search_suggestions(
+    keyword: str,
+    channel: str = "xhs",
+    uid: str = "",
+) -> Dict[str, Any]:
+    """
+    建议词搜索
+
+    根据关键词在指定渠道平台获取搜索建议词。
+
+    Args:
+        keyword: 搜索关键词
+        channel: 搜索渠道,支持的渠道有:
+            - xhs: 小红书
+            - wx: 微信
+            - github: GitHub
+            - toutiao: 头条
+            - douyin: 抖音
+            - bili: B站
+            - zhihu: 知乎
+        uid: 用户ID(自动注入)
+
+    Returns:
+        API 返回的原始响应,结构如下:
+        {
+            "code": 0,                    # 状态码,0 表示成功
+            "message": "success",         # 状态消息
+            "data": [                     # 建议词数据
+                {
+                    "type": "xhs",        # 渠道类型
+                    "list": [             # 建议词列表
+                        {
+                            "name": "彩虹染发"  # 建议词
+                        }
+                    ]
+                }
+            ]
+        }
+
+    Raises:
+        httpx.HTTPStatusError: HTTP 请求返回非 2xx 状态码
+        httpx.TimeoutException: 请求超时
+        httpx.RequestError: 其他请求错误
+    """
+    # 处理 channel 参数,支持枚举和字符串
+    channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel
+
+    url = f"{BASE_URL}/suggest"
+    payload = {
+        "type": channel_value,
+        "keyword": keyword,
+    }
+
+    async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
+        response = await client.post(
+            url,
+            json=payload,
+            headers={"Content-Type": "application/json"},
+        )
+        response.raise_for_status()
+        return response.json()

+ 0 - 227
tools/search.py

@@ -1,227 +0,0 @@
-"""
-搜索工具模块
-
-提供帖子搜索和建议词搜索功能,支持多个渠道平台。
-
-主要功能:
-1. search_posts - 帖子搜索
-2. search_suggestions - 建议词搜索
-"""
-
-from enum import Enum
-from typing import Any, Dict, Optional
-
-import httpx
-
-from agent import tool
-
-
-# API 基础配置
-BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
-DEFAULT_TIMEOUT = 60.0
-
-
-class PostSearchChannel(str, Enum):
-    """
-    帖子搜索支持的渠道类型
-    """
-    XHS = "xhs"           # 小红书
-    GZH = "gzh"           # 公众号
-    SPH = "sph"           # 视频号
-    GITHUB = "github"     # GitHub
-    TOUTIAO = "toutiao"   # 头条
-    DOUYIN = "douyin"     # 抖音
-    BILI = "bili"         # B站
-    ZHIHU = "zhihu"       # 知乎
-    WEIBO = "weibo"       # 微博
-
-
-class SuggestSearchChannel(str, Enum):
-    """
-    建议词搜索支持的渠道类型
-    """
-    XHS = "xhs"           # 小红书
-    WX = "wx"             # 微信
-    GITHUB = "github"     # GitHub
-    TOUTIAO = "toutiao"   # 头条
-    DOUYIN = "douyin"     # 抖音
-    BILI = "bili"         # B站
-    ZHIHU = "zhihu"       # 知乎
-
-
-@tool(
-    display={
-        "zh": {
-            "name": "帖子搜索",
-            "params": {
-                "keyword": "搜索关键词",
-                "channel": "搜索渠道",
-                "cursor": "分页游标",
-                "max_count": "返回条数"
-            }
-        },
-        "en": {
-            "name": "Search Posts",
-            "params": {
-                "keyword": "Search keyword",
-                "channel": "Search channel",
-                "cursor": "Pagination cursor",
-                "max_count": "Max results"
-            }
-        }
-    }
-)
-async def search_posts(
-    keyword: str,
-    channel: str = "xhs",
-    cursor: str = "0",
-    max_count: int = 5,
-    uid: str = "",
-) -> Dict[str, Any]:
-    """
-    帖子搜索
-
-    根据关键词在指定渠道平台搜索帖子内容。
-
-    Args:
-        keyword: 搜索关键词
-        channel: 搜索渠道,支持的渠道有:
-            - xhs: 小红书
-            - gzh: 公众号
-            - sph: 视频号
-            - github: GitHub
-            - toutiao: 头条
-            - douyin: 抖音
-            - bili: B站
-            - zhihu: 知乎
-            - weibo: 微博
-        cursor: 分页游标,默认为 "0"(第一页)
-        max_count: 返回的最大条数,默认为 5
-        uid: 用户ID(自动注入)
-
-    Returns:
-        API 返回的原始响应,结构如下:
-        {
-            "code": 0,                    # 状态码,0 表示成功
-            "message": "success",         # 状态消息
-            "data": [                     # 帖子列表
-                {
-                    "channel_content_id": "68dd03db000000000303beb2",  # 内容唯一ID
-                    "title": "",                                       # 标题
-                    "content_type": "note",                            # 内容类型
-                    "body_text": "",                                   # 正文内容
-                    "like_count": 127,                                 # 点赞数
-                    "publish_timestamp": 1759314907000,                # 发布时间戳(毫秒)
-                    "images": ["https://xxx.webp"],                    # 图片列表
-                    "videos": [],                                      # 视频列表
-                    "channel": "xhs",                                  # 来源渠道
-                    "link": "xxx"                                      # 原文链接
-                }
-            ]
-        }
-
-    Raises:
-        httpx.HTTPStatusError: HTTP 请求返回非 2xx 状态码
-        httpx.TimeoutException: 请求超时
-        httpx.RequestError: 其他请求错误
-    """
-    # 处理 channel 参数,支持枚举和字符串
-    channel_value = channel.value if isinstance(channel, PostSearchChannel) else channel
-
-    url = f"{BASE_URL}/data"
-    payload = {
-        "type": channel_value,
-        "keyword": keyword,
-        "cursor": cursor,
-        "max_count": max_count,
-    }
-
-    async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
-        response = await client.post(
-            url,
-            json=payload,
-            headers={"Content-Type": "application/json"},
-        )
-        response.raise_for_status()
-        return response.json()
-
-
-@tool(
-    display={
-        "zh": {
-            "name": "建议词搜索",
-            "params": {
-                "keyword": "搜索关键词",
-                "channel": "搜索渠道"
-            }
-        },
-        "en": {
-            "name": "Search Suggestions",
-            "params": {
-                "keyword": "Search keyword",
-                "channel": "Search channel"
-            }
-        }
-    }
-)
-async def search_suggestions(
-    keyword: str,
-    channel: str = "xhs",
-    uid: str = "",
-) -> Dict[str, Any]:
-    """
-    建议词搜索
-
-    根据关键词在指定渠道平台获取搜索建议词。
-
-    Args:
-        keyword: 搜索关键词
-        channel: 搜索渠道,支持的渠道有:
-            - xhs: 小红书
-            - wx: 微信
-            - github: GitHub
-            - toutiao: 头条
-            - douyin: 抖音
-            - bili: B站
-            - zhihu: 知乎
-        uid: 用户ID(自动注入)
-
-    Returns:
-        API 返回的原始响应,结构如下:
-        {
-            "code": 0,                    # 状态码,0 表示成功
-            "message": "success",         # 状态消息
-            "data": [                     # 建议词数据
-                {
-                    "type": "xhs",        # 渠道类型
-                    "list": [             # 建议词列表
-                        {
-                            "name": "彩虹染发"  # 建议词
-                        }
-                    ]
-                }
-            ]
-        }
-
-    Raises:
-        httpx.HTTPStatusError: HTTP 请求返回非 2xx 状态码
-        httpx.TimeoutException: 请求超时
-        httpx.RequestError: 其他请求错误
-    """
-    # 处理 channel 参数,支持枚举和字符串
-    channel_value = channel.value if isinstance(channel, SuggestSearchChannel) else channel
-
-    url = f"{BASE_URL}/suggest"
-    payload = {
-        "type": channel_value,
-        "keyword": keyword,
-    }
-
-    async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
-        response = await client.post(
-            url,
-            json=payload,
-            headers={"Content-Type": "application/json"},
-        )
-        response.raise_for_status()
-        return response.json()