jihuaqiang 2 недель назад
Родитель
Сommit
4a9738cfe4
3 измененных файлов с 61 добавлено и 24 удалено
  1. 3 1
      examples/create/run.py
  2. 1 13
      examples/create/tool/__init__.py
  3. 57 10
      examples/create/tool/topic_search.py

+ 3 - 1
examples/create/run.py

@@ -22,7 +22,7 @@ from pathlib import Path
 # Clash Verge TUN 模式兼容:禁止 httpx/urllib 自动检测系统 HTTP 代理
 # Clash Verge TUN 模式兼容:禁止 httpx/urllib 自动检测系统 HTTP 代理
 # TUN 虚拟网卡已在网络层接管所有流量,不需要应用层再走 HTTP 代理,
 # TUN 虚拟网卡已在网络层接管所有流量,不需要应用层再走 HTTP 代理,
 # 否则 httpx 检测到 macOS 系统代理 (127.0.0.1:7897) 会导致 ConnectError
 # 否则 httpx 检测到 macOS 系统代理 (127.0.0.1:7897) 会导致 ConnectError
-os.environ.setdefault("no_proxy", "*")
+# os.environ.setdefault("no_proxy", "*")
 
 
 # 添加项目根目录到 Python 路径
 # 添加项目根目录到 Python 路径
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
@@ -301,6 +301,8 @@ async def main():
     # 加载自定义工具
     # 加载自定义工具
     print("   - 加载自定义工具: nanobanana")
     print("   - 加载自定义工具: nanobanana")
     import examples.how.tool  # 导入自定义工具模块,触发 @tool 装饰器注册
     import examples.how.tool  # 导入自定义工具模块,触发 @tool 装饰器注册
+    print("   - 加载自定义工具: topic_search")
+    import examples.create.tool  # 选题检索工具,用于在数据库中匹配已有帖子选题
 
 
     store = FileSystemTraceStore(base_path=".trace")
     store = FileSystemTraceStore(base_path=".trace")
     runner = AgentRunner(
     runner = AgentRunner(

+ 1 - 13
examples/create/tool/__init__.py

@@ -3,17 +3,5 @@ Create 示例的自定义工具
 """
 """
 
 
 from examples.create.tool.topic_search import topic_search
 from examples.create.tool.topic_search import topic_search
-from examples.create.tool.search_library import (
-    search_point_by_element_from_full_all_levels,
-    search_point_by_path_from_full_all_levels
-)
-from examples.create.tool.search_person_tree import (
-    search_person_tree_constants
-)
 
 
-__all__ = [
-    "topic_search",
-    "search_point_by_element_from_full_all_levels",
-    "search_point_by_path_from_full_all_levels",
-    "search_person_tree_constants"
-]
+__all__ = ["topic_search"]

+ 57 - 10
examples/create/tool/topic_search.py

@@ -1,7 +1,7 @@
 """
 """
 选题检索工具 - 根据关键词在数据库中匹配已有帖子的选题
 选题检索工具 - 根据关键词在数据库中匹配已有帖子的选题
 
 
-用于 Agent 执行时自主调取参考数据。
+用于 Agent 执行时自主调取参考数据,并选择与当前人设最匹配的内容输出
 """
 """
 
 
 import json
 import json
@@ -41,33 +41,80 @@ async def _call_search_api(keywords: List[str]) -> Optional[List[Dict[str, Any]]
     return []
     return []
 
 
 
 
-def _pick_first(results: List[Dict[str, Any]]) -> Dict[str, Any]:
-    """从结果中取第一条。"""
+def _extract_text(obj: Any) -> str:
+    """从结果对象中提取可比较的文本。"""
+    if obj is None:
+        return ""
+    if isinstance(obj, str):
+        return obj
+    if isinstance(obj, dict):
+        text_parts = []
+        for k in ("title", "content", "主题", "选题", "描述", "description", "摘要"):
+            v = obj.get(k)
+            if v and isinstance(v, str):
+                text_parts.append(v)
+        if not text_parts:
+            text_parts = [str(v) for v in obj.values() if isinstance(v, str)]
+        return " ".join(text_parts)
+    return str(obj)
+
+
+def _score_match(result: Dict[str, Any], persona_summary: str) -> float:
+    """
+    计算单条结果与人设摘要的匹配度(简单关键词重叠)。
+    返回 0~1 之间的分数,越高表示越匹配。
+    """
+    if not persona_summary or not persona_summary.strip():
+        return 1.0
+
+    result_text = _extract_text(result).lower()
+    persona_words = set(
+        w for w in persona_summary.lower().replace(",", " ").replace(",", " ").split()
+        if len(w) > 1
+    )
+    if not persona_words:
+        return 1.0
+
+    hits = sum(1 for w in persona_words if w in result_text)
+    return hits / len(persona_words)
+
+
+def _pick_best_match(results: List[Dict[str, Any]], persona_summary: Optional[str]) -> Dict[str, Any]:
+    """从结果中选出与人设最匹配的一条。"""
     if not results:
     if not results:
         raise ValueError("无可用结果")
         raise ValueError("无可用结果")
-    return results[0]
+    if not persona_summary or len(results) == 1:
+        return results[0]
+
+    best = max(results, key=lambda r: _score_match(r, persona_summary))
+    return best
 
 
 
 
 @tool(
 @tool(
-    description="根据关键词在数据库中检索已有帖子的选题,用于创作参考。最多返回5条,取第一条输出。",
+    description="根据关键词在数据库中检索已有帖子的选题,用于创作参考。最多返回5条,自动选择与当前人设最匹配的一条输出。",
     display={
     display={
         "zh": {
         "zh": {
             "name": "爆款选题检索",
             "name": "爆款选题检索",
             "params": {
             "params": {
                 "keywords": "关键词列表",
                 "keywords": "关键词列表",
+                "persona_summary": "当前人设摘要(可选,用于筛选最匹配结果)",
             },
             },
         },
         },
     },
     },
 )
 )
-async def topic_search(keywords: List[str]) -> ToolResult:
+async def topic_search(
+    keywords: List[str],
+    persona_summary: Optional[str] = None,
+) -> ToolResult:
     """
     """
-    根据关键词检索数据库中已有帖子的选题,取第一条作为参考。
+    根据关键词检索数据库中已有帖子的选题,选择与人设最匹配的一条作为参考。
 
 
     Args:
     Args:
         keywords: 关键词列表,如 ["中老年健康养生", "爆款", "知识科普"]
         keywords: 关键词列表,如 ["中老年健康养生", "爆款", "知识科普"]
+        persona_summary: 当前人设摘要,用于从多条结果中筛选最匹配的(可选)
 
 
     Returns:
     Returns:
-        ToolResult: 选题参考内容
+        ToolResult: 最匹配的选题参考内容
     """
     """
     if not keywords:
     if not keywords:
         return ToolResult(
         return ToolResult(
@@ -92,7 +139,7 @@ async def topic_search(keywords: List[str]) -> ToolResult:
         )
         )
 
 
     try:
     try:
-        best = _pick_first(results)
+        best = _pick_best_match(results, persona_summary)
     except ValueError:
     except ValueError:
         return ToolResult(
         return ToolResult(
             title="选题检索",
             title="选题检索",
@@ -103,5 +150,5 @@ async def topic_search(keywords: List[str]) -> ToolResult:
     return ToolResult(
     return ToolResult(
         title="选题检索 - 参考数据",
         title="选题检索 - 参考数据",
         output=output,
         output=output,
-        long_term_memory=f"检索到选题参考,关键词: {', '.join(keywords)}",
+        long_term_memory=f"检索到与人设匹配的选题参考,关键词: {', '.join(keywords)}",
     )
     )