浏览代码

Add tool_registry

StrayWarrior 4 天之前
父节点
当前提交
30905d705c
共有 2 个文件被更改,包括 60 次插入0 次删除
  1. 33 0
      pqai_agent/toolkit/__init__.py
  2. 27 0
      pqai_agent/toolkit/tool_registry.py

+ 33 - 0
pqai_agent/toolkit/__init__.py

@@ -0,0 +1,33 @@
+# 必须要在这里导入模块,以便对应的模块执行register_toolkit
+from typing import Sequence, List
+
+from pqai_agent.toolkit.tool_registry import ToolRegistry
+from pqai_agent.toolkit.image_describer import ImageDescriber
+from pqai_agent.toolkit.message_notifier import MessageNotifier
+from pqai_agent.toolkit.pq_video_searcher import PQVideoSearcher
+
+global_tool_map = ToolRegistry.tool_map
+
+def get_tool(tool_name: str) -> 'FunctionTool':
+    """
+    Retrieve a tool by its name from the global tool map.
+
+    Args:
+        tool_name (str): The name of the tool to retrieve.
+
+    Returns:
+        FunctionTool: The tool instance if found, otherwise None.
+    """
+    return global_tool_map.get(tool_name, None)
+
+def get_tools(tool_names: Sequence[str]) -> List['FunctionTool']:
+    """
+    Retrieve multiple tools by their names from the global tool map.
+
+    Args:
+        tool_names (Sequence[str]): A sequence of tool names to retrieve.
+
+    Returns:
+        Sequence[FunctionTool]: A sequence of tool instances corresponding to the provided names.
+    """
+    return [get_tool(name) for name in tool_names if get_tool(name) is not None]

+ 27 - 0
pqai_agent/toolkit/tool_registry.py

@@ -0,0 +1,27 @@
+from typing import Type, Dict
+from pqai_agent.toolkit.function_tool import FunctionTool
+
+class ToolRegistry:
+    tool_map: Dict[str, FunctionTool] = {}
+
+    @classmethod
+    def register_tools(cls, toolkit_class: Type):
+        """
+        Register tools from a toolkit class into the global tool_map.
+
+        Args:
+            toolkit_class (Type): A class that implements a `get_tools` method.
+        """
+        instance = toolkit_class()
+        if not hasattr(instance, 'get_tools') or not callable(instance.get_tools):
+            raise ValueError(f"{toolkit_class.__name__} must implement a callable `get_tools` method.")
+
+        tools = instance.get_tools()
+        for tool in tools:
+            if not hasattr(tool, 'name'):
+                raise ValueError(f"Tool {tool} must have a `name` attribute.")
+            cls.tool_map[tool.name] = tool
+
+def register_toolkit(cls):
+    ToolRegistry.register_tools(cls)
+    return cls