Browse Source

Merge branch 'feature/202506-exp-tools' of Server/AgentCoreService into master

fengzhoutian 4 days ago
parent
commit
3fbea2d5b4

+ 2 - 1
pqai_agent/agent_service.py

@@ -32,6 +32,7 @@ from pqai_agent.push_service import PushScanThread, PushTaskWorkerPool
 from pqai_agent.rate_limiter import MessageSenderRateLimiter
 from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.service_module_manager import ServiceModuleManager
+from pqai_agent.toolkit import get_tools
 from pqai_agent.user_manager import UserManager, UserRelationManager
 from pqai_agent.message_queue_backend import MessageQueueBackend, AliyunRocketMQQueueBackend
 from pqai_agent.user_profile_extractor import UserProfileExtractor
@@ -458,7 +459,7 @@ class AgentService:
         if agent_config:
             chat_agent = MessageReplyAgent(model=agent_config.execution_model,
                                            system_prompt=agent_config.system_prompt,
-                                           tools=None)
+                                           tools=get_tools(agent_config.tools))
         else:
             chat_agent = MessageReplyAgent()
         chat_responses = chat_agent.generate_message(

+ 3 - 2
pqai_agent/agents/multimodal_chat_agent.py

@@ -5,6 +5,7 @@ from typing import Optional, List, Dict
 from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
 from pqai_agent.logging_service import logger
 from pqai_agent.mq_message import MessageType
+from pqai_agent.toolkit import get_tool
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.message_notifier import MessageNotifier
 
@@ -17,9 +18,9 @@ class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
                  generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
         if 'output_multimodal_message' not in self.tool_map:
-            self.add_tool(FunctionTool(MessageNotifier.output_multimodal_message))
+            self.add_tool(get_tool('output_multimodal_message'))
         if 'message_notify_user' not in self.tool_map:
-            self.add_tool(FunctionTool(MessageNotifier.message_notify_user))
+            self.add_tool(get_tool('message_notify_user'))
 
     @abstractmethod
     def generate_message(self, context: Dict, dialogue_history: List[Dict],

+ 2 - 1
pqai_agent/push_service.py

@@ -17,6 +17,7 @@ from pqai_agent.configs import apollo_config
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
 from pqai_agent.logging_service import logger
 from pqai_agent.mq_message import MessageType
+from pqai_agent.toolkit import get_tools
 from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
 
 
@@ -199,7 +200,7 @@ class PushTaskWorkerPool:
             if agent_config:
                 push_agent = MessagePushAgent(model=agent_config.execution_model,
                                               system_prompt=agent_config.system_prompt,
-                                              tools=None)
+                                              tools=get_tools(agent_config.tools))
                 query_prompt_template = agent_config.task_prompt
             else:
                 push_agent = MessagePushAgent()

+ 33 - 0
pqai_agent/toolkit/__init__.py

@@ -0,0 +1,33 @@
+# 必须要在这里导入模块,以便对应的模块执行register_toolkit
+from typing import Sequence, List
+
+from pqai_agent.toolkit.tool_registry import ToolRegistry
+from pqai_agent.toolkit.image_describer import ImageDescriber
+from pqai_agent.toolkit.message_notifier import MessageNotifier
+from pqai_agent.toolkit.pq_video_searcher import PQVideoSearcher
+
+global_tool_map = ToolRegistry.tool_map
+
+def get_tool(tool_name: str) -> 'FunctionTool':
+    """
+    Retrieve a tool by its name from the global tool map.
+
+    Args:
+        tool_name (str): The name of the tool to retrieve.
+
+    Returns:
+        FunctionTool: The tool instance if found, otherwise None.
+    """
+    return global_tool_map.get(tool_name, None)
+
+def get_tools(tool_names: Sequence[str]) -> List['FunctionTool']:
+    """
+    Retrieve multiple tools by their names from the global tool map.
+
+    Args:
+        tool_names (Sequence[str]): A sequence of tool names to retrieve.
+
+    Returns:
+        Sequence[FunctionTool]: A sequence of tool instances corresponding to the provided names.
+    """
+    return [get_tool(name) for name in tool_names if get_tool(name) is not None]

+ 2 - 0
pqai_agent/toolkit/image_describer.py

@@ -6,11 +6,13 @@ from pqai_agent.chat_service import VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO
 from pqai_agent.logging_service import logger
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
 # 不同实例间复用cache,但不是很好的实践
 _image_describer_caches = {}
 _cache_mutex = threading.Lock()
 
+@register_toolkit
 class ImageDescriber(BaseToolkit):
     def __init__(self, cache_dir: str = None):
         self.model = VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO

+ 2 - 0
pqai_agent/toolkit/message_notifier.py

@@ -3,8 +3,10 @@ from typing import List, Dict
 from pqai_agent.logging_service import logger
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
 
+@register_toolkit
 class MessageNotifier(BaseToolkit):
     def __init__(self):
         super().__init__()

+ 3 - 0
pqai_agent/toolkit/pq_video_searcher.py

@@ -3,7 +3,10 @@ import requests
 
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
+
+@register_toolkit
 class PQVideoSearcher(BaseToolkit):
     API_URL = "https://vlogapi.piaoquantv.com/longvideoapi/search/userandvideo/list"
     def search_pq_video(self, keywords: List[str]) -> List[Dict]:

+ 2 - 0
pqai_agent/toolkit/search_toolkit.py

@@ -4,8 +4,10 @@ import requests
 
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
 
+@register_toolkit
 class SearchToolkit(BaseToolkit):
     r"""A class representing a toolkit for web search.
     """

+ 27 - 0
pqai_agent/toolkit/tool_registry.py

@@ -0,0 +1,27 @@
+from typing import Type, Dict
+from pqai_agent.toolkit.function_tool import FunctionTool
+
+class ToolRegistry:
+    tool_map: Dict[str, FunctionTool] = {}
+
+    @classmethod
+    def register_tools(cls, toolkit_class: Type):
+        """
+        Register tools from a toolkit class into the global tool_map.
+
+        Args:
+            toolkit_class (Type): A class that implements a `get_tools` method.
+        """
+        instance = toolkit_class()
+        if not hasattr(instance, 'get_tools') or not callable(instance.get_tools):
+            raise ValueError(f"{toolkit_class.__name__} must implement a callable `get_tools` method.")
+
+        tools = instance.get_tools()
+        for tool in tools:
+            if not hasattr(tool, 'name'):
+                raise ValueError(f"Tool {tool} must have a `name` attribute.")
+            cls.tool_map[tool.name] = tool
+
+def register_toolkit(cls):
+    ToolRegistry.register_tools(cls)
+    return cls