3 Commits d7265c8674 ... 9361cf8807

Autore SHA1 Messaggio Data
  StrayWarrior 9361cf8807 Update agent_task_server: use create_agent_from_config 1 settimana fa
  StrayWarrior 7e7efd868b Update base agent class 1 settimana fa
  StrayWarrior 114df71a23 Add agent_utils 1 settimana fa

+ 1 - 3
pqai_agent/agent.py

@@ -7,13 +7,11 @@ class BaseAgent(ABC):
     r"""An abstract base class for all agents."""
 
     @abstractmethod
-    def run(self, user_input: str, **kwargs) -> Any:
+    def run(self, user_input: str) -> Any:
         """Run the agent with the given user input.
 
         Args:
             user_input (str): The input from the user.
-            **kwargs: Additional keyword arguments.
-
         Returns:
             Any: The output from the agent.
         """

+ 24 - 0
pqai_agent/utils/agent_utils.py

@@ -0,0 +1,24 @@
+import json
+
+from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.toolkit import get_tools
+from pqai_agent.toolkit.sub_agent_toolkit import SubAgentToolkit
+
+
+def create_agent_from_config(agent_config: AgentConfiguration, session_maker) -> SimpleOpenAICompatibleChatAgent:
+    tools = get_tools(json.loads(agent_config.tools))
+    sub_agent_ids = json.loads(agent_config.sub_agents)
+    if sub_agent_ids:
+        # 查询子Agent配置
+        with session_maker() as session:
+            sub_agent_configs = session.query(AgentConfiguration).filter(
+                AgentConfiguration.id.in_(sub_agent_ids)).all()
+        # 将子Agent配置转换为工具
+        for sub_agent_config in sub_agent_configs:
+            sub_agent_tool = SubAgentToolkit.create_tool_from_agent(sub_agent_config)
+            tools.append(sub_agent_tool)
+    chat_agent = SimpleOpenAICompatibleChatAgent(model=agent_config.execution_model,
+                                                 system_prompt=agent_config.system_prompt,
+                                                 tools=tools)
+    return chat_agent

+ 2 - 6
pqai_agent_server/agent_task_server.py

@@ -6,12 +6,11 @@ from typing import Dict
 
 from sqlalchemy import func, select
 
-from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
 from pqai_agent.data_models.agent_configuration import AgentConfiguration
 from pqai_agent.data_models.agent_task import AgentTask
 from pqai_agent.data_models.agent_task_detail import AgentTaskDetail
 from pqai_agent.logging import logger
-from pqai_agent.toolkit import get_tools
+from pqai_agent.utils.agent_utils import create_agent_from_config
 from pqai_agent_server.const.status_enum import AgentTaskStatus, get_agent_task_detail_status_desc, \
     AgentTaskDetailStatus, get_agent_task_status_desc
 
@@ -160,10 +159,7 @@ class AgentTaskManager:
             self.update_task_status(task_id, AgentTaskStatus.IN_PROGRESS.value)
             agent_task = self.get_agent_task(task_id)
             agent_config = self.get_agent_config(agent_task.agent_id)
-            tools = get_tools(json.loads(agent_config.tools))
-            chat_agent = SimpleOpenAICompatibleChatAgent(model=agent_config.execution_model,
-                                                         system_prompt=agent_config.system_prompt,
-                                                         tools=tools)
+            chat_agent = create_agent_from_config(agent_config, self.session_maker)
             message = chat_agent.run(agent_task.input)
             agent_task_details = chat_agent.get_agent_task_details()
             for agent_task_detail in agent_task_details: