Bladeren bron

Merge branch 'feature/202506-exp-system' of Server/AgentCoreService into master

fengzhoutian 4 dagen geleden
bovenliggende
commit
352f26fe22

+ 1 - 0
.gitignore

@@ -1,3 +1,4 @@
+image_descriptions_cache/
 # ---> Python
 # Byte-compiled / optimized / DLL files
 __pycache__/

+ 7 - 1
pqai_agent/abtest/models.py

@@ -3,6 +3,9 @@ import json
 from dataclasses import dataclass, field
 import hashlib
 
+from pqai_agent.logging_service import logger
+
+
 class FNV:
     INIT64 = int("cbf29ce484222325", 16)
     PRIME64 = int("100000001b3", 16)
@@ -232,7 +235,10 @@ class ExperimentResult:
         self.exp_id = ""
 
     def add_params(self, params: Dict[str, str]):
-        self.params.update(params)
+        for key, value in params.items():
+            if key in self.params:
+                logger.warning(f"Duplicate key '{key}' in params, overwriting value: {self.params[key]} with {value}")
+            self.params[key] = value
 
     def add_experiment_version(self, version):
         self.experiment_versions.append(version)

+ 7 - 0
pqai_agent/abtest/utils.py

@@ -0,0 +1,7 @@
+from pqai_agent.abtest.models import ExperimentContext
+from pqai_agent.abtest.client import get_client
+
+def get_abtest_info(uid: str):
+    client = get_client()
+    exp_ctx = ExperimentContext(uid=uid)
+    return client.match_experiment('PQAgent', exp_ctx)

+ 25 - 0
pqai_agent/agent_config_manager.py

@@ -0,0 +1,25 @@
+from typing import Dict, Optional
+
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.logging_service import logger
+
+class AgentConfigManager:
+    def __init__(self, session_maker):
+        self.session_maker = session_maker
+        self.agent_configs: Dict[int, AgentConfiguration] = {}
+        self.refresh_configs()
+
+    def refresh_configs(self):
+        try:
+            with self.session_maker() as session:
+                data = session.query(AgentConfiguration).filter_by(is_delete=False).all()
+                agent_configs = {}
+                for config in data:
+                    agent_configs[config.id] = config
+            self.agent_configs = agent_configs
+            logger.debug(f"Refreshed agent configurations: {list(self.agent_configs.keys())}")
+        except Exception as e:
+            logger.error(f"Failed to refresh agent configurations: {e}")
+
+    def get_config(self, agent_id: int) -> Optional[AgentConfiguration]:
+        return self.agent_configs.get(agent_id, None)

+ 29 - 4
pqai_agent/agent_service.py

@@ -18,6 +18,8 @@ from apscheduler.schedulers.background import BackgroundScheduler
 from sqlalchemy.orm import sessionmaker
 
 from pqai_agent import configs
+from pqai_agent.abtest.utils import get_abtest_info
+from pqai_agent.agent_config_manager import AgentConfigManager
 from pqai_agent.agents.message_reply_agent import MessageReplyAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.exceptions import NoRetryException
@@ -29,10 +31,12 @@ from pqai_agent.history_dialogue_service import HistoryDialogueDatabase
 from pqai_agent.push_service import PushScanThread, PushTaskWorkerPool
 from pqai_agent.rate_limiter import MessageSenderRateLimiter
 from pqai_agent.response_type_detector import ResponseTypeDetector
+from pqai_agent.service_module_manager import ServiceModuleManager
 from pqai_agent.user_manager import UserManager, UserRelationManager
 from pqai_agent.message_queue_backend import MessageQueueBackend, AliyunRocketMQQueueBackend
 from pqai_agent.user_profile_extractor import UserProfileExtractor
 from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
+from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
 from pqai_agent.utils.db_utils import create_ai_agent_db_engine
 
 
@@ -61,7 +65,7 @@ class AgentService:
         self.agent_registry: Dict[str, DialogueManager] = {}
         self.history_dialogue_db = HistoryDialogueDatabase(self.config['database']['ai_agent'])
         self.agent_db_engine = create_ai_agent_db_engine()
-        self.AgentDBSession = sessionmaker(bind=self.agent_db_engine)
+        self.agent_db_session_maker = sessionmaker(bind=self.agent_db_engine)
 
         chat_config = self.config['chat_api']['openai_compatible']
         self.text_model_name = chat_config['text_model']
@@ -98,6 +102,10 @@ class AgentService:
 
         self.send_rate_limiter = MessageSenderRateLimiter()
 
+        # Agent配置和实验相关
+        self.service_module_manager = ServiceModuleManager(self.agent_db_session_maker)
+        self.agent_config_manager = AgentConfigManager(self.agent_db_session_maker)
+
     def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
         if not schedule_params:
             schedule_params = {'hour': '8,16,20'}
@@ -123,6 +131,11 @@ class AgentService:
             )
             self.msg_scheduler_thread = threading.Thread(target=self.process_scheduler_events)
             self.msg_scheduler_thread.start()
+        # 定时更新模块配置任务
+        self.scheduler.add_job(self.service_module_manager.refresh_configs, 'interval',
+                               seconds=60, id='refresh_module_configs')
+        self.scheduler.add_job(self.agent_config_manager.refresh_configs, 'interval',
+                               seconds=60, id='refresh_agent_configs')
         self.scheduler.start()
 
     def process_scheduler_events(self):
@@ -149,7 +162,7 @@ class AgentService:
         agent_key = 'agent_{}_{}'.format(staff_id, user_id)
         if agent_key not in self.agent_registry:
             self.agent_registry[agent_key] = DialogueManager(
-                staff_id, user_id, self.user_manager, self.agent_state_cache, self.AgentDBSession)
+                staff_id, user_id, self.user_manager, self.agent_state_cache, self.agent_db_session_maker)
         agent = self.agent_registry[agent_key]
         agent.refresh_profile()
         return agent
@@ -240,7 +253,12 @@ class AgentService:
             sys.exit(0)
 
     def _update_user_profile(self, user_id, user_profile, recent_dialogue: List[Dict]):
-        profile_to_update = self.user_profile_extractor.extract_profile_info_v2(user_profile, recent_dialogue)
+        agent_info = get_agent_abtest_config('profile_extractor', user_id, self.service_module_manager, self.agent_config_manager)
+        if agent_info:
+            prompt_template = agent_info.task_prompt
+        else:
+            prompt_template = None
+        profile_to_update = self.user_profile_extractor.extract_profile_info_v2(user_profile, recent_dialogue, prompt_template)
         if not profile_to_update:
             logger.debug("user_id: {}, no profile info extracted".format(user_id))
             return
@@ -435,7 +453,14 @@ class AgentService:
             return None
 
     def _get_chat_response_v2(self, main_agent: DialogueManager) -> List[Dict]:
-        chat_agent = MessageReplyAgent()
+        agent_config = get_agent_abtest_config('chat', main_agent.user_id,
+                                               self.service_module_manager, self.agent_config_manager)
+        if agent_config:
+            chat_agent = MessageReplyAgent(model=agent_config.execution_model,
+                                           system_prompt=agent_config.system_prompt,
+                                           tools=None)
+        else:
+            chat_agent = MessageReplyAgent()
         chat_responses = chat_agent.generate_message(
             context=main_agent.get_prompt_context(None),
             dialogue_history=main_agent.dialogue_history[-100:]

+ 8 - 32
pqai_agent/agents/message_push_agent.py

@@ -1,10 +1,8 @@
-import datetime
 from typing import Optional, List, Dict
 
-from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
 from pqai_agent.logging_service import logger
-from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
@@ -120,46 +118,24 @@ QUERY_PROMPT_TEMPLATE = """现在,请通过多步思考,以客服的角色
 Now, start to process your task. Please think step by step.
  """
 
-class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
+class MessagePushAgent(MultiModalChatAgent):
     """A specialized agent for message push tasks."""
 
     def __init__(self, model: Optional[str] = VOLCENGINE_MODEL_DEEPSEEK_V3, system_prompt: Optional[str] = None,
                  tools: Optional[List[FunctionTool]] = None,
                  generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
         system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
-        tools = tools or []
-        tools = tools.copy()
-        tools.extend([
-            *ImageDescriber().get_tools(),
-            *MessageNotifier().get_tools(),
-        ])
+        if tools is None:
+            tools = [
+                *ImageDescriber().get_tools(),
+                *MessageNotifier().get_tools()
+            ]
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
     def generate_message(self, context: Dict, dialogue_history: List[Dict],
                          query_prompt_template: Optional[str] = None) -> List[Dict]:
-        formatted_dialogue = MessagePushAgent.compose_dialogue(dialogue_history)
         query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
-        query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
-        self.run(query)
-        result = []
-        for tool_call in self.tool_call_records:
-            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
-                result.append(tool_call['arguments']['message'])
-        return result
-
-    @staticmethod
-    def compose_dialogue(dialogue: List[Dict]) -> str:
-        role_map = {'user': '用户', 'assistant': '客服'}
-        messages = []
-        for msg in dialogue:
-            if not msg['content']:
-                continue
-            if msg['role'] not in role_map:
-                continue
-            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
-            msg_type = msg.get('type', MessageType.TEXT).description
-            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
-        return '\n'.join(messages)
+        return self._generate_message(context, dialogue_history, query_prompt_template)
 
 class DummyMessagePushAgent(MessagePushAgent):
     """A dummy agent for testing purposes."""

+ 12 - 34
pqai_agent/agents/message_reply_agent.py

@@ -1,10 +1,8 @@
-import datetime
 from typing import Optional, List, Dict
 
-from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
 from pqai_agent.logging_service import logger
-from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
@@ -86,44 +84,24 @@ QUERY_PROMPT_TEMPLATE = """现在,请以客服的角色分析以下会话并
 Now, start to process your task. Please think step by step.
  """
 
-class MessageReplyAgent(SimpleOpenAICompatibleChatAgent):
+class MessageReplyAgent(MultiModalChatAgent):
     """A specialized agent for message reply tasks."""
 
     def __init__(self, model: Optional[str] = VOLCENGINE_MODEL_DEEPSEEK_V3, system_prompt: Optional[str] = None,
                  tools: Optional[List[FunctionTool]] = None,
                  generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
         system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
-        tools = tools or []
-        tools = tools.copy()
-        tools.extend([
-            *ImageDescriber().get_tools(),
-            *MessageNotifier().get_tools()
-        ])
+        if tools is None:
+            tools = [
+                *ImageDescriber().get_tools(),
+                *MessageNotifier().get_tools()
+            ]
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
-        formatted_dialogue = MessageReplyAgent.compose_dialogue(dialogue_history)
-        query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
-        self.run(query)
-        result = []
-        for tool_call in self.tool_call_records:
-            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
-                result.append(tool_call['arguments']['message'])
-        return result
-
-    @staticmethod
-    def compose_dialogue(dialogue: List[Dict]) -> str:
-        role_map = {'user': '用户', 'assistant': '客服'}
-        messages = []
-        for msg in dialogue:
-            if not msg['content']:
-                continue
-            if msg['role'] not in role_map:
-                continue
-            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
-            msg_type = msg.get('type', MessageType.TEXT).description
-            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
-        return '\n'.join(messages)
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: Optional[str] = None) -> List[Dict]:
+        query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
+        return self._generate_message(context, dialogue_history, query_prompt_template)
 
 class DummyMessageReplyAgent(MessageReplyAgent):
     """A dummy agent for testing purposes."""
@@ -131,7 +109,7 @@ class DummyMessageReplyAgent(MessageReplyAgent):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict], query_prompt_template = None) -> List[Dict]:
         logger.debug(f"DummyMessageReplyAgent.generate_message called, context: {context}")
         result = [{"type": "text", "content": "测试消息: {agent_name} -> {nickname}".format(**context)},
                   {"type": "image", "content": "https://example.com/test_image.jpg"}]

+ 52 - 0
pqai_agent/agents/multimodal_chat_agent.py

@@ -0,0 +1,52 @@
+import datetime
+from abc import abstractmethod
+from typing import Optional, List, Dict
+
+from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.logging_service import logger
+from pqai_agent.mq_message import MessageType
+from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.message_notifier import MessageNotifier
+
+
+class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
+    """A specialized agent for message reply tasks."""
+
+    def __init__(self, model: str, system_prompt: str,
+                 tools: Optional[List[FunctionTool]] = None,
+                 generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
+        super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
+        if 'output_multimodal_message' not in self.tool_map:
+            self.add_tool(FunctionTool(MessageNotifier.output_multimodal_message))
+        if 'message_notify_user' not in self.tool_map:
+            self.add_tool(FunctionTool(MessageNotifier.message_notify_user))
+
+    @abstractmethod
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: str) -> List[Dict]:
+        pass
+
+    def _generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: str) -> List[Dict]:
+        formatted_dialogue = MultiModalChatAgent.compose_dialogue(dialogue_history)
+        query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
+        self.run(query)
+        result = []
+        for tool_call in self.tool_call_records:
+            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
+                result.append(tool_call['arguments']['message'])
+        return result
+
+    @staticmethod
+    def compose_dialogue(dialogue: List[Dict]) -> str:
+        role_map = {'user': '用户', 'assistant': '客服'}
+        messages = []
+        for msg in dialogue:
+            if not msg['content']:
+                continue
+            if msg['role'] not in role_map:
+                continue
+            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
+            msg_type = msg.get('type', MessageType.TEXT).description
+            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
+        return '\n'.join(messages)

+ 11 - 1
pqai_agent/agents/simple_chat_agent.py

@@ -15,12 +15,22 @@ class SimpleOpenAICompatibleChatAgent:
         self.model = model
         self.llm_client = OpenAICompatible.create_client(model)
         self.system_prompt = system_prompt
-        self.tools = tools or []
+        if tools:
+            self.tools = [*tools]
+        else:
+            self.tools = []
         self.tool_map = {tool.name: tool for tool in self.tools}
         self.generate_cfg = generate_cfg or {}
         self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
         self.tool_call_records = []
 
+    def add_tool(self, tool: FunctionTool):
+        """添加一个工具到Agent中"""
+        if tool.name in self.tool_map:
+            logger.warning(f"Tool {tool.name} already exists, replacing it.")
+        self.tools.append(tool)
+        self.tool_map[tool.name] = tool
+
     def run(self, user_input: str) -> str:
         messages = [{"role": "system", "content": self.system_prompt}]
         tools = [tool.get_openai_tool_schema() for tool in self.tools]

+ 3 - 3
pqai_agent/dialogue_manager.py

@@ -102,7 +102,7 @@ class DialogueStateCache:
 
 class DialogueManager:
     def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache,
-                 AgentDBSession: sessionmaker[Session]):
+                 agent_db_session_maker: sessionmaker[Session]):
         config = configs.get()
 
         self.staff_id = staff_id
@@ -125,7 +125,7 @@ class DialogueManager:
         self.history_dialogue_service = HistoryDialogueService(
             config['storage']['history_dialogue']['api_base_url']
         )
-        self.AgentDBSession = AgentDBSession
+        self.agent_db_session_maker = agent_db_session_maker
         self._recover_state()
         # 由于本地状态管理过于复杂,引入事务机制做状态回滚
         self._uncommited_state_change = []
@@ -174,7 +174,7 @@ class DialogueManager:
         else:
             # 默认设置
             self.last_interaction_time_ms = int(time.time() * 1000) - minutes_to_get * 60 * 1000
-        with self.AgentDBSession() as session:
+        with self.agent_db_session_maker() as session:
             # 读取数据库中的最后一次交互时间
             query = session.query(AgentPushRecord).filter(
                 AgentPushRecord.staff_id == self.staff_id,

+ 15 - 3
pqai_agent/push_service.py

@@ -17,6 +17,7 @@ from pqai_agent.configs import apollo_config
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
 from pqai_agent.logging_service import logger
 from pqai_agent.mq_message import MessageType
+from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
 
 
 class TaskType(Enum):
@@ -168,7 +169,7 @@ class PushTaskWorkerPool:
                     if response:
                         item["type"] = message_type
                         messages_to_send.append(item)
-            with self.agent_service.AgentDBSession() as session:
+            with self.agent_service.agent_db_session_maker() as session:
                 msg_list = [{"type": msg["type"].value, "content": msg["content"]} for msg in messages_to_send]
                 record = AgentPushRecord(staff_id=staff_id, user_id=user_id,
                                          content=json.dumps(msg_list, ensure_ascii=False),
@@ -192,12 +193,23 @@ class PushTaskWorkerPool:
             staff_id = task['staff_id']
             user_id = task['user_id']
             main_agent = self.agent_service.get_agent_instance(staff_id, user_id)
-            push_agent = MessagePushAgent()
+            agent_config = get_agent_abtest_config('push', user_id,
+                                                   self.agent_service.service_module_manager,
+                                                   self.agent_service.agent_config_manager)
+            if agent_config:
+                push_agent = MessagePushAgent(model=agent_config.execution_model,
+                                              system_prompt=agent_config.system_prompt,
+                                              tools=None)
+                query_prompt_template = agent_config.task_prompt
+            else:
+                push_agent = MessagePushAgent()
+                query_prompt_template = None
             message_to_user = push_agent.generate_message(
                 context=main_agent.get_prompt_context(None),
                 dialogue_history=self.agent_service.history_dialogue_db.get_dialogue_history_backward(
                     staff_id, user_id, main_agent.last_interaction_time_ms, limit=100
-                )
+                ),
+                query_prompt_template=query_prompt_template
             )
             if message_to_user:
                 rmq_message = generate_task_rmq_message(

+ 27 - 0
pqai_agent/service_module_manager.py

@@ -0,0 +1,27 @@
+from pqai_agent.data_models.service_module import ServiceModule, ModuleAgentType
+from pqai_agent.logging_service import logger
+
+class ServiceModuleManager:
+    def __init__(self, session_maker):
+        self.session_maker = session_maker
+        self.module_configs = {}
+        self.refresh_configs()
+
+    def refresh_configs(self):
+        try:
+            with self.session_maker() as session:
+                data = session.query(ServiceModule).filter_by(is_delete=False).all()
+                module_configs = {}
+                for module in data:
+                    module_configs[module.name] = {
+                        'display_name': module.display_name,
+                        'default_agent_type': ModuleAgentType(module.default_agent_type),
+                        'default_agent_id': module.default_agent_id
+                    }
+                self.module_configs = module_configs
+                logger.debug(f"Refreshed module configurations: {module_configs}")
+        except Exception as e:
+            logger.error(f"Error refreshing module configs: {e}")
+
+    def get_module_config(self, module_name: str):
+        return self.module_configs.get(module_name)

+ 0 - 23
pqai_agent/user_manager.py

@@ -292,7 +292,6 @@ class MySQLUserManager(UserManager):
             "data": staff_list
         }
 
-
 class LocalUserRelationManager(UserRelationManager):
     def __init__(self):
         pass
@@ -434,25 +433,3 @@ class MySQLUserRelationManager(UserRelationManager):
         except Exception as e:
             logger.error(f"stop_user_daily_push failed: {e}")
             return False
-
-
-if __name__ == '__main__':
-    config = configs.get()
-    user_db_config = config['storage']['user']
-    staff_db_config = config['storage']['staff']
-    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
-    user_profile = user_manager.get_user_profile('7881301263964433')
-    print(user_profile)
-
-    wecom_db_config = config['storage']['user_relation']
-    user_relation_manager = MySQLUserRelationManager(
-        user_db_config['mysql'], wecom_db_config['mysql'],
-        config['storage']['staff']['table'],
-        user_db_config['table'],
-        wecom_db_config['table']['staff'],
-        wecom_db_config['table']['relation'],
-        wecom_db_config['table']['user']
-    )
-    # all_staff_users = user_relation_manager.list_staff_users()
-    user_tags = user_relation_manager.get_user_tags('7881302078008656')
-    print(user_tags)

+ 1 - 0
pqai_agent/user_profile_extractor.py

@@ -156,6 +156,7 @@ class UserProfileExtractor:
         :return:
         """
         if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
+            logger.debug("skip LLM API call.")
             return None
 
         try: