Browse Source

增加agent提交执行和查询任务功能

xueyiming 2 days ago
parent
commit
1c4481aed5

+ 26 - 2
pqai_agent/agents/simple_chat_agent.py

@@ -4,9 +4,10 @@ from typing import List, Optional
 import pqai_agent.utils
 from pqai_agent.agent import DEFAULT_MAX_RUN_STEPS
 from pqai_agent.chat_service import OpenAICompatible
+from pqai_agent.data_models.agent_task_detail import AgentTaskDetail
 from pqai_agent.logging import logger
 from pqai_agent.toolkit.function_tool import FunctionTool
-
+from pqai_agent_server.const.status_enum import AgentTaskDetailStatus
 
 
 class SimpleOpenAICompatibleChatAgent:
@@ -24,6 +25,7 @@ class SimpleOpenAICompatibleChatAgent:
         self.generate_cfg = generate_cfg or {}
         self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
         self.tool_call_records = []
+        self.agent_task_details: list[AgentTaskDetail] = []
         self.total_input_tokens = 0
         self.total_output_tokens = 0
         logger.debug(self.tool_map)
@@ -57,8 +59,15 @@ class SimpleOpenAICompatibleChatAgent:
                     arguments = json.loads(tool_call.function.arguments)
                     logger.debug(f"run_id[{run_id}] call function[{function_name}], parameter: {arguments}")
 
+                    agent_task_detail = AgentTaskDetail()
+                    agent_task_detail.executor_type = 'tool'
+                    agent_task_detail.executor_name = function_name
+                    agent_task_detail.input_data = tool_call.function.arguments
+                    self.agent_task_details.append(agent_task_detail)
+
                     if function_name in self.tool_map:
-                        result = self.tool_map[function_name](**arguments)
+                        # result = self.tool_map[function_name](**arguments)
+                        result = "success"
                         messages.append({
                             "role": "tool",
                             "tool_call_id": tool_call.id,
@@ -69,15 +78,30 @@ class SimpleOpenAICompatibleChatAgent:
                             "arguments": arguments,
                             "result": result
                         })
+                        agent_task_detail.output_data = json.dumps(result, ensure_ascii=False)
+                        agent_task_detail.status = AgentTaskDetailStatus.SUCCESS.value
                     else:
+                        agent_task_detail.error_message = f"Function {function_name} not found in tool map."
+                        agent_task_detail.status = AgentTaskDetailStatus.FAILED.value
                         logger.error(f"run_id[{run_id}] Function {function_name} not found in tool map.")
                         raise Exception(f"Function {function_name} not found in tool map.")
             else:
+                agent_task_detail = AgentTaskDetail()
+                agent_task_detail.executor_type = 'llm'
+                agent_task_detail.executor_name = self.model
+                agent_task_detail.output_data = message.content
+                agent_task_detail.status = AgentTaskDetailStatus.SUCCESS.value
+                self.agent_task_details.append(agent_task_detail)
                 return message.content
             n_steps += 1
 
         raise Exception("Max run steps exceeded")
 
+    # 新增方法:获取步骤记录
+    def get_agent_task_details(self) -> list:
+        """返回代理运行过程中的详细步骤记录"""
+        return self.agent_task_details
+
     def get_total_input_tokens(self) -> int:
         """获取总输入token数"""
         return self.total_input_tokens

+ 22 - 0
pqai_agent/data_models/agent_task.py

@@ -0,0 +1,22 @@
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class AgentTask(Base):
+    __tablename__ = "agent_task"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    agent_id = Column(BigInteger, nullable=False, comment="agent主键")
+    status = Column(Integer, nullable=False, default=0, comment="状态(0:未开始, 1:进行中, 2:已完成, 3:失败)")
+    start_time = Column(TIMESTAMP, nullable=True, comment="任务开始执行时间")
+    end_time = Column(TIMESTAMP, nullable=True, comment="任务结束执行时间")
+    create_user = Column(String(32), nullable=True, comment="创建用户")
+    input = Column(Text, nullable=True, comment="任务执行输入")
+    tools = Column(Text, nullable=True, comment="任务使用的工具")
+    output = Column(Text, nullable=True, comment="任务执行输出")
+    error_message = Column(Text, nullable=True, comment="错误详情(失败时记录)")
+    create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")
+    update_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
+                         comment="更新时间")

+ 23 - 0
pqai_agent/data_models/agent_task_detail.py

@@ -0,0 +1,23 @@
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
+from sqlalchemy.dialects.mysql import VARCHAR
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class AgentTaskDetail(Base):
+    __tablename__ = "agent_task_detail"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    agent_task_id = Column(BigInteger, nullable=False, comment="agent执行任务id")
+    parent_execution_id = Column(BigInteger, nullable=False, comment="父级执行ID(用于构建调用树)")
+    executor_type = Column(VARCHAR(32), nullable=True, comment="执行体类型(LLM/agent/tool)")
+    status = Column(Integer, nullable=False, default=0, comment="执行状态(0-执行中 1-成功 2-失败)")
+    input_data = Column(Text, nullable=True, comment="执行输入参数(结构化存储)")
+    executor_name = Column(Text, nullable=True, comment="执行模型名/工具名/子Agent名")
+    reasoning = Column(Text, nullable=True, comment="思考过程(仅适用于LLM步骤)")
+    output_data = Column(Text, nullable=True, comment="当前执行体的原始输出")
+    error_message = Column(Text, nullable=True, comment="错误详情(失败时记录)")
+    create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")
+    update_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
+                         comment="更新时间")

+ 206 - 0
pqai_agent_server/agent_task_server.py

@@ -0,0 +1,206 @@
+import concurrent.futures
+import json
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+from typing import Dict
+
+from sqlalchemy import func
+
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
+from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.data_models import agent_task_detail
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.data_models.agent_task import AgentTask
+from pqai_agent.data_models.agent_task_detail import AgentTaskDetail
+from pqai_agent.data_models.agent_test_task import AgentTestTask
+from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
+from pqai_agent.data_models.service_module import ServiceModule
+from pqai_agent.logging import logger
+from pqai_agent.toolkit import get_tools
+from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc, \
+    AgentTaskStatus, get_agent_task_detail_status_desc, AgentTaskDetailStatus
+from scripts.evaluate_agent import evaluate_agent
+
+
+class AgentTaskManager:
+    """任务管理器"""
+
+    def __init__(self, session_maker):
+        self.session_maker = session_maker
+        self.task_events = {}  # 任务ID -> Event (用于取消任务)
+        self.task_locks = {}  # 任务ID -> Lock (用于任务状态同步)
+        self.running_tasks = set()
+        self.executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='AgentTaskWorker')
+        self.task_futures = {}  # 任务ID -> Future
+
+    def get_agent_task(self, agent_task_id: int):
+        """获取任务信息"""
+        with self.session_maker() as session:
+            return session.query(AgentTask).filter(AgentTask.id == agent_task_id).one()
+
+    def get_agent_config(self, agent_id: int):
+        """获取任务信息"""
+        with self.session_maker() as session:
+            return session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).one()
+
+    def get_in_progress_task(self):
+        """获取执行中任务"""
+        with self.session_maker() as session:
+            return session.query(AgentTask).filter(AgentTask.status.in_([
+                AgentTaskStatus.NOT_STARTED.value,
+                AgentTaskStatus.IN_PROGRESS.value
+            ])).all()
+
+    def get_agent_task_details(self, task_id):
+        """更新任务状态"""
+        with self.session_maker() as session:
+            return session.query(AgentTaskDetail).filter(AgentTaskDetail.agent_task_id == task_id).all()
+
+    def save_agent_task_details_batch(self, agent_task_details: list, agent_task_id: int, message: str):
+        """批量保存子任务到数据库"""
+        try:
+            with self.session_maker() as session:
+                with session.begin():
+                    session.add_all(agent_task_details)
+                    session.query(AgentTask).filter(
+                        AgentTask.id == agent_task_id).update(
+                        {"status": AgentTaskStatus.COMPLETED.value, "output": message, "update_time": datetime.now()})
+                    session.commit()
+        except Exception as e:
+            logger.error(e)
+            raise Exception(e)
+
+    def update_task_failed(self, task_id, error_message: str):
+        """更新任务状态"""
+        with self.session_maker() as session:
+            session.query(AgentTask).filter(AgentTask.id == task_id).update(
+                {"status": AgentTaskStatus.FAILED, "error_message": error_message, "update_time": datetime.now()})
+            session.commit()
+
+    def update_task_status(self, task_id, status):
+        """更新任务状态"""
+        with self.session_maker() as session:
+            session.query(AgentTask).filter(AgentTask.id == task_id).update(
+                {"status": status, "update_time": datetime.now()})
+            session.commit()
+
+    def get_agent_task_list(self, page_num: int, page_size: int) -> Dict:
+        with self.session_maker() as session:
+            # 计算偏移量
+            offset = (page_num - 1) * page_size
+            # 查询分页数据
+            result = (session.query(AgentTestTask, AgentConfiguration)
+                      .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
+                      .limit(page_size).offset(offset).all())
+            # 查询总记录数
+            total = session.query(func.count(AgentTestTask.id)).scalar()
+
+            total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
+            total_page = 1 if total_page <= 0 else total_page
+            response_data = [
+                {
+                    "id": agent_test_task.id,
+                    "agentName": agent_configuration.name,
+                    "createUser": agent_test_task.create_user,
+                    "updateUser": agent_test_task.update_user,
+                    "statusName": get_test_task_status_desc(agent_test_task.status),
+                    "createTime": agent_test_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
+                    "updateTime": agent_test_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
+                }
+                for agent_test_task, agent_configuration in result
+            ]
+            return {
+                "currentPage": page_num,
+                "pageSize": page_size,
+                "totalSize": total_page,
+                "total": total,
+                "list": response_data,
+            }
+
+    def create_task(self, agent_id: int, task_prompt: str):
+        """创建新任务"""
+        with self.session_maker() as session:
+            agent_config = session.get(AgentConfiguration, agent_id)
+            agent_task = AgentTask(agent_id=agent_id,
+                                   status=AgentTaskStatus.NOT_STARTED.value,
+                                   start_time=datetime.now(),
+                                   input=task_prompt,
+                                   tools=agent_config.tools)
+            session.add(agent_task)
+            session.commit()  # 显式提交
+            task_id = agent_task.id
+        # 异步执行创建任务
+        self.executor.submit(self._execute_task, task_id)
+
+    def _process_task(self, task_id: int):
+        try:
+            self.update_task_status(task_id, AgentTaskStatus.IN_PROGRESS.value)
+            agent_task = self.get_agent_task(task_id)
+            agent_config = self.get_agent_config(agent_task.agent_id)
+            tools = get_tools(json.loads(agent_config.tools))
+            chat_agent = SimpleOpenAICompatibleChatAgent(model=agent_config.execution_model,
+                                                         system_prompt=agent_config.system_prompt,
+                                                         tools=tools)
+            message = chat_agent.run(agent_task.input)
+            agent_task_details = chat_agent.get_agent_task_details()
+            for agent_task_detail in agent_task_details:
+                agent_task_detail.agent_task_id = task_id
+                agent_task_detail.status = AgentTaskDetailStatus.SUCCESS.value
+            self.save_agent_task_details_batch(agent_task_details, task_id, message)
+        except Exception as e:
+            logger.error(e)
+            self.update_task_failed(task_id, str(e))
+
+    def recover_tasks(self):
+        """服务启动时恢复未完成的任务"""
+
+        in_progress_tasks = self.get_in_progress_task()
+
+        for task in in_progress_tasks:
+            task_id = task.id
+            # 重新提交任务
+            self._execute_task(task_id)
+
+    def _execute_task(self, task_id: int):
+        """提交任务到线程池执行"""
+        # 确保任务状态一致性
+        if task_id in self.running_tasks:
+            return
+
+        # 创建任务事件和锁
+        if task_id not in self.task_events:
+            self.task_events[task_id] = threading.Event()
+        if task_id not in self.task_locks:
+            self.task_locks[task_id] = threading.Lock()
+
+        # 提交到线程池
+        future = self.executor.submit(self._process_task, task_id)
+        self.task_futures[task_id] = future
+
+        # 标记任务为运行中
+        with self.task_locks[task_id]:
+            self.running_tasks.add(task_id)
+
+    def get_agent_task_detail(self, agent_task_id):
+        agent_task = self.get_agent_task(agent_task_id)
+        agent_task_details = self.get_agent_task_details(agent_task_id)
+        agent_task_detail_datas = []
+        for agent_task_detail in agent_task_details:
+            data = {}
+            data["id"] = agent_task_detail.id
+            data["executorType"] = agent_task_detail.executor_type
+            data["status"] = get_agent_task_detail_status_desc(agent_task_detail.status)
+            data["inputData"] = agent_task_detail.input_data
+            data["executorName"] = agent_task_detail.executor_name
+            data["reasoning"] = agent_task_detail.reasoning
+            data["outputData"] = agent_task_detail.output_data
+            data["errorMessage"] = agent_task_detail.error_message
+            data["createTime"]: agent_task_detail.create_time.strftime("%Y-%m-%d %H:%M:%S")
+            data["updateTime"]: agent_task_detail.update_time.strftime("%Y-%m-%d %H:%M:%S")
+            agent_task_detail_datas.append(data)
+        return {
+            "input": agent_task.input,
+            "tools": agent_task.tools,
+            "agentTaskDetails": agent_task_detail_datas
+        }

+ 36 - 1
pqai_agent_server/api_server.py

@@ -20,6 +20,7 @@ from pqai_agent.toolkit import global_tool_map
 from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
 from pqai_agent.utils.db_utils import create_ai_agent_db_engine
 from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
+from pqai_agent_server.agent_task_server import AgentTaskManager
 from pqai_agent_server.const import AgentApiConst
 from pqai_agent_server.const.status_enum import TestTaskStatus
 from pqai_agent_server.const.type_enum import EvaluateType
@@ -764,6 +765,36 @@ def get_agent_types():
     ]
     return wrap_response(200, data=agent_types)
 
+@app.route("/api/createAgentTask", methods=["POST"])
+def create_agent_task():
+    """
+       创建agent执行任务
+       :return:
+    """
+    req_data = request.json
+    agent_id = req_data.get('agentId', None)
+    task_prompt = req_data.get('taskPrompt', None)
+    if not agent_id:
+        return wrap_response(404, msg='agent id is required')
+    if not task_prompt:
+        return wrap_response(404, msg='task_prompt is required')
+    app.agent_task_manager.create_task(agent_id, task_prompt)
+    return wrap_response(200)
+
+
+@app.route("/api/getAgentTaskDetail", methods=["GET"])
+def get_agent_task_detail():
+    """
+       查询agent执行任务详情
+       :return:
+    """
+    agent_task_id = request.args.get("agentTaskId", None)
+    if not agent_task_id:
+        return wrap_response(404, msg='agent_task_id is required')
+    response = app.agent_task_manager.get_agent_task_detail(int(agent_task_id))
+    return wrap_response(200, data=response)
+
+
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 def handle_bad_request(e):
     logger.error(e)
@@ -811,7 +842,11 @@ if __name__ == '__main__':
 
     task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine), dataset_service=dataset_service)
     app.task_manager = task_manager
-    task_manager.recover_tasks()
+    app.task_manager.recover_tasks()
+
+    agent_task_manager = AgentTaskManager(session_maker=sessionmaker(bind=agent_db_engine))
+    app.agent_task_manager = agent_task_manager
+    app.agent_task_manager.recover_tasks()
 
     wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(

+ 49 - 0
pqai_agent_server/const/status_enum.py

@@ -64,3 +64,52 @@ def get_test_task_conversations_status_desc(status_code):
         return status.description
     except ValueError:
         return f"未知状态: {status_code}"
+
+
+class AgentTaskStatus(Enum):
+    NOT_STARTED = 0
+    IN_PROGRESS = 1
+    COMPLETED = 2
+    FAILED = 3
+
+    @property
+    def description(self):
+        descriptions = {
+            self.NOT_STARTED: "未开始",
+            self.IN_PROGRESS: "进行中",
+            self.COMPLETED: "已完成",
+            self.FAILED: "已失败"
+        }
+        return descriptions.get(self)
+
+
+# 使用示例
+def get_agent_task_status_desc(status_code):
+    try:
+        status = AgentTaskStatus(status_code)
+        return status.description
+    except ValueError:
+        return f"未知状态: {status_code}"
+
+class AgentTaskDetailStatus(Enum):
+    IN_PROGRESS = 0
+    SUCCESS = 1
+    FAILED = 2
+
+    @property
+    def description(self):
+        descriptions = {
+            self.IN_PROGRESS: "执行中",
+            self.SUCCESS: "成功",
+            self.FAILED: "失败"
+        }
+        return descriptions.get(self)
+
+
+# 使用示例
+def get_agent_task_detail_status_desc(status_code):
+    try:
+        status = AgentTaskDetailStatus(status_code)
+        return status.description
+    except ValueError:
+        return f"未知状态: {status_code}"