Procházet zdrojové kódy

Update agent_service: add get_chat_response_v2

StrayWarrior před 2 týdny
rodič
revize
a95c020375
1 změnil soubory, kde provedl 27 přidání a 6 odebrání
  1. 27 6
      pqai_agent/agent_service.py

+ 27 - 6
pqai_agent/agent_service.py

@@ -6,7 +6,7 @@ import re
 import signal
 import sys
 import time
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 import logging
 from datetime import datetime, timedelta
 import threading
@@ -18,6 +18,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
 from sqlalchemy.orm import sessionmaker
 
 from pqai_agent import configs
+from pqai_agent.agents.message_reply_agent import MessageReplyAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.exceptions import NoRetryException
 from pqai_agent.logging_service import logger
@@ -293,7 +294,7 @@ class AgentService:
             elif need_response:
                 # 先更新用户画像再处理回复
                 self._update_user_profile(user_id, user_profile, agent.dialogue_history[-10:])
-                resp = self._get_chat_response(user_id, agent, message_text)
+                resp = self.get_chat_response(agent, message_text)
                 if resp:
                     recent_dialogue = agent.dialogue_history[-10:]
                     agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
@@ -409,9 +410,14 @@ class AgentService:
         # 问题在于,如果每次创建出新的PushTaskWorkerPool,在上次任务有未处理完的消息即退出时,会有未处理的消息堆积
         push_task_worker_pool.wait_to_finish()
 
-    def _get_chat_response(self, user_id: str, agent: DialogueManager,
-                           user_message: Optional[str]):
-        """处理LLM响应"""
+    def get_chat_response(self, agent: DialogueManager, user_message: Optional[str]) -> Union[str, List[Dict]]:
+        chat_agent_ver = self.config.get('system', {}).get('chat_agent_version', 1)
+        if chat_agent_ver == 2:
+            return self._get_chat_response_v2(agent)
+        else:
+            return self._get_chat_response_v1(agent, user_message)
+
+    def _get_chat_response_v1(self, agent: DialogueManager, user_message: Optional[str]) -> str:
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
         config_for_logging = chat_config.copy()
         config_for_logging['messages'] = config_for_logging['messages'][-20:]
@@ -422,9 +428,24 @@ class AgentService:
         if response := agent.generate_response(chat_response):
             return response
         else:
-            logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
+            logger.warning(f"staff[{agent.staff_id}] user[{agent.user_id}]: no response generated")
             return None
 
+    def _get_chat_response_v2(self, main_agent: DialogueManager) -> List[Dict]:
+        chat_agent = MessageReplyAgent()
+        chat_responses = chat_agent.generate_message(
+            context=main_agent.get_prompt_context(None),
+            dialogue_history=main_agent.dialogue_history[-100:]
+        )
+        if not chat_responses:
+            logger.warning(f"staff[{main_agent.staff_id}] user[{main_agent.user_id}]: no response generated")
+            return []
+        final_responses = []
+        for chat_response in chat_responses:
+            if response := main_agent.generate_multimodal_response(chat_response):
+                final_responses.append(response)
+        return final_responses
+
     def _call_chat_api(self, chat_config: Dict, chat_service_type: ChatServiceType) -> str:
         if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
             return 'LLM模拟回复 {}'.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))