فهرست منبع

Merge branch 'master' into dev-xym-add-test-task

xueyiming 3 روز پیش
والد
کامیت
3df6d176e6

+ 1 - 0
.gitignore

@@ -1,3 +1,4 @@
+image_descriptions_cache/
 # ---> Python
 # Byte-compiled / optimized / DLL files
 __pycache__/

+ 8 - 7
pqai_agent/abtest/client.py

@@ -50,7 +50,7 @@ class ExperimentClient:
 
         for project_data in projects:
             project = Project(name=project_data.name, project_id=project_data.project_id)
-            logger.debug(f"[Project] {project_data}")
+            # logger.debug(f"[Project] {project_data}")
 
             # 获取项目的域
             list_domain_req = ListDomainsRequest()
@@ -66,7 +66,7 @@ class ExperimentClient:
                                 is_default_domain=domain_data.is_default_domain,
                                 exp_layer_id=domain_data.layer_id,
                                 debug_users=domain_data.debug_users)
-                logger.debug(f"[Domain] {domain_data}")
+                # logger.debug(f"[Domain] {domain_data}")
                 if domain.is_default_domain:
                     project.set_default_domain(domain)
                 domain.init()
@@ -84,7 +84,7 @@ class ExperimentClient:
                 list_layer_req.domain_id = str(domain.id)
                 layers_response = self.client.list_layers(list_layer_req)
                 for layer_data in layers_response.body.layers:
-                    logger.debug(f'[Layer] {layer_data}')
+                    # logger.debug(f'[Layer] {layer_data}')
                     layer = Layer(id=int(layer_data.layer_id), name=layer_data.name)
                     project.add_layer(layer)
 
@@ -96,7 +96,7 @@ class ExperimentClient:
                     experiments_response = self.client.list_experiments(list_experiment_req)
 
                     for experiment_data in experiments_response.body.experiments:
-                        logger.debug(f'[Experiment] {experiment_data}')
+                        # logger.debug(f'[Experiment] {experiment_data}')
                         # FIXME: Java SDK中有特殊处理
                         crowd_ids = experiment_data.crowd_ids if experiment_data.crowd_ids else ""
                         experiment = Experiment(id=int(experiment_data.experiment_id), bucket_type=experiment_data.bucket_type,
@@ -111,9 +111,8 @@ class ExperimentClient:
                         list_exp_ver_req = ListExperimentVersionsRequest()
                         list_exp_ver_req.experiment_id = int(experiment.id)
                         versions_response = self.client.list_experiment_versions(list_exp_ver_req)
-                        logger.debug(versions_response)
                         for version_data in versions_response.body.experiment_versions:
-                            logger.debug(f'[ExperimentVersion] {version_data}')
+                            # logger.debug(f'[ExperimentVersion] {version_data}')
                             version = ExperimentVersion(exp_version_id=version_data.experiment_version_id,
                                                         exp_id=experiment.id,
                                                         flow=int(version_data.flow),
@@ -150,6 +149,8 @@ class ExperimentClient:
         experiment_result = ExperimentResult(project=project, experiment_context=experiment_context)
 
         self._match_domain(project.default_domain, experiment_result)
+        matched_versions = [str(ver.id) for ver in experiment_result.experiment_versions]
+        logger.debug(f"Matched experiment, uid[{experiment_context.uid}], versions[{','.join(matched_versions)}], params: {experiment_result.params}")
         experiment_result.init()
         return experiment_result
 
@@ -176,7 +177,7 @@ class ExperimentClient:
 
         for domain in layer.domains:
             if domain.match_debug_users(experiment_result.experiment_context):
-                logger.debug(f"Matched debug user for domain: {domain.id}")
+                # logger.debug(f"Matched debug user for domain: {domain.id}")
                 self._match_domain(domain, experiment_result)
 
         hash_key = f"{experiment_result.experiment_context.uid}_LAYER{layer.id}"

+ 7 - 1
pqai_agent/abtest/models.py

@@ -3,6 +3,9 @@ import json
 from dataclasses import dataclass, field
 import hashlib
 
+from pqai_agent.logging_service import logger
+
+
 class FNV:
     INIT64 = int("cbf29ce484222325", 16)
     PRIME64 = int("100000001b3", 16)
@@ -232,7 +235,10 @@ class ExperimentResult:
         self.exp_id = ""
 
     def add_params(self, params: Dict[str, str]):
-        self.params.update(params)
+        for key, value in params.items():
+            if key in self.params:
+                logger.warning(f"Duplicate key '{key}' in params, overwriting value: {self.params[key]} with {value}")
+            self.params[key] = value
 
     def add_experiment_version(self, version):
         self.experiment_versions.append(version)

+ 7 - 0
pqai_agent/abtest/utils.py

@@ -0,0 +1,7 @@
+from pqai_agent.abtest.models import ExperimentContext
+from pqai_agent.abtest.client import get_client
+
+def get_abtest_info(uid: str):
+    client = get_client()
+    exp_ctx = ExperimentContext(uid=uid)
+    return client.match_experiment('PQAgent', exp_ctx)

+ 25 - 0
pqai_agent/agent_config_manager.py

@@ -0,0 +1,25 @@
+from typing import Dict, Optional
+
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.logging_service import logger
+
+class AgentConfigManager:
+    def __init__(self, session_maker):
+        self.session_maker = session_maker
+        self.agent_configs: Dict[int, AgentConfiguration] = {}
+        self.refresh_configs()
+
+    def refresh_configs(self):
+        try:
+            with self.session_maker() as session:
+                data = session.query(AgentConfiguration).filter_by(is_delete=False).all()
+                agent_configs = {}
+                for config in data:
+                    agent_configs[config.id] = config
+            self.agent_configs = agent_configs
+            logger.debug(f"Refreshed agent configurations: {list(self.agent_configs.keys())}")
+        except Exception as e:
+            logger.error(f"Failed to refresh agent configurations: {e}")
+
+    def get_config(self, agent_id: int) -> Optional[AgentConfiguration]:
+        return self.agent_configs.get(agent_id, None)

+ 34 - 4
pqai_agent/agent_service.py

@@ -18,6 +18,8 @@ from apscheduler.schedulers.background import BackgroundScheduler
 from sqlalchemy.orm import sessionmaker
 
 from pqai_agent import configs
+from pqai_agent.abtest.utils import get_abtest_info
+from pqai_agent.agent_config_manager import AgentConfigManager
 from pqai_agent.agents.message_reply_agent import MessageReplyAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.exceptions import NoRetryException
@@ -29,10 +31,13 @@ from pqai_agent.history_dialogue_service import HistoryDialogueDatabase
 from pqai_agent.push_service import PushScanThread, PushTaskWorkerPool
 from pqai_agent.rate_limiter import MessageSenderRateLimiter
 from pqai_agent.response_type_detector import ResponseTypeDetector
+from pqai_agent.service_module_manager import ServiceModuleManager
+from pqai_agent.toolkit import get_tools
 from pqai_agent.user_manager import UserManager, UserRelationManager
 from pqai_agent.message_queue_backend import MessageQueueBackend, AliyunRocketMQQueueBackend
 from pqai_agent.user_profile_extractor import UserProfileExtractor
 from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
+from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
 from pqai_agent.utils.db_utils import create_ai_agent_db_engine
 
 
@@ -61,7 +66,7 @@ class AgentService:
         self.agent_registry: Dict[str, DialogueManager] = {}
         self.history_dialogue_db = HistoryDialogueDatabase(self.config['database']['ai_agent'])
         self.agent_db_engine = create_ai_agent_db_engine()
-        self.AgentDBSession = sessionmaker(bind=self.agent_db_engine)
+        self.agent_db_session_maker = sessionmaker(bind=self.agent_db_engine)
 
         chat_config = self.config['chat_api']['openai_compatible']
         self.text_model_name = chat_config['text_model']
@@ -98,6 +103,10 @@ class AgentService:
 
         self.send_rate_limiter = MessageSenderRateLimiter()
 
+        # Agent配置和实验相关
+        self.service_module_manager = ServiceModuleManager(self.agent_db_session_maker)
+        self.agent_config_manager = AgentConfigManager(self.agent_db_session_maker)
+
     def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
         if not schedule_params:
             schedule_params = {'hour': '8,16,20'}
@@ -123,6 +132,11 @@ class AgentService:
             )
             self.msg_scheduler_thread = threading.Thread(target=self.process_scheduler_events)
             self.msg_scheduler_thread.start()
+        # 定时更新模块配置任务
+        self.scheduler.add_job(self.service_module_manager.refresh_configs, 'interval',
+                               seconds=60, id='refresh_module_configs')
+        self.scheduler.add_job(self.agent_config_manager.refresh_configs, 'interval',
+                               seconds=60, id='refresh_agent_configs')
         self.scheduler.start()
 
     def process_scheduler_events(self):
@@ -149,7 +163,7 @@ class AgentService:
         agent_key = 'agent_{}_{}'.format(staff_id, user_id)
         if agent_key not in self.agent_registry:
             self.agent_registry[agent_key] = DialogueManager(
-                staff_id, user_id, self.user_manager, self.agent_state_cache, self.AgentDBSession)
+                staff_id, user_id, self.user_manager, self.agent_state_cache, self.agent_db_session_maker)
         agent = self.agent_registry[agent_key]
         agent.refresh_profile()
         return agent
@@ -240,7 +254,12 @@ class AgentService:
             sys.exit(0)
 
     def _update_user_profile(self, user_id, user_profile, recent_dialogue: List[Dict]):
-        profile_to_update = self.user_profile_extractor.extract_profile_info_v2(user_profile, recent_dialogue)
+        agent_info = get_agent_abtest_config('profile_extractor', user_id, self.service_module_manager, self.agent_config_manager)
+        if agent_info:
+            prompt_template = agent_info.task_prompt
+        else:
+            prompt_template = None
+        profile_to_update = self.user_profile_extractor.extract_profile_info_v2(user_profile, recent_dialogue, prompt_template)
         if not profile_to_update:
             logger.debug("user_id: {}, no profile info extracted".format(user_id))
             return
@@ -396,8 +415,12 @@ class AgentService:
             return
 
         push_scan_threads = []
+        whitelist_staffs = apollo_config.get_json_value("agent_initiate_whitelist_staffs", [])
         for staff in self.user_relation_manager.list_staffs():
             staff_id = staff['third_party_user_id']
+            if staff_id not in whitelist_staffs:
+                logger.info(f"staff[{staff_id}] is not in whitelist, skip")
+                continue
             scan_thread = threading.Thread(target=PushScanThread(
                 staff_id, self, self.push_task_rmq_topic, self.push_task_producer).run)
             scan_thread.start()
@@ -435,7 +458,14 @@ class AgentService:
             return None
 
     def _get_chat_response_v2(self, main_agent: DialogueManager) -> List[Dict]:
-        chat_agent = MessageReplyAgent()
+        agent_config = get_agent_abtest_config('chat', main_agent.user_id,
+                                               self.service_module_manager, self.agent_config_manager)
+        if agent_config:
+            chat_agent = MessageReplyAgent(model=agent_config.execution_model,
+                                           system_prompt=agent_config.system_prompt,
+                                           tools=get_tools(agent_config.tools))
+        else:
+            chat_agent = MessageReplyAgent()
         chat_responses = chat_agent.generate_message(
             context=main_agent.get_prompt_context(None),
             dialogue_history=main_agent.dialogue_history[-100:]

+ 8 - 32
pqai_agent/agents/message_push_agent.py

@@ -1,10 +1,8 @@
-import datetime
 from typing import Optional, List, Dict
 
-from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
 from pqai_agent.logging_service import logger
-from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
@@ -120,46 +118,24 @@ QUERY_PROMPT_TEMPLATE = """现在,请通过多步思考,以客服的角色
 Now, start to process your task. Please think step by step.
  """
 
-class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
+class MessagePushAgent(MultiModalChatAgent):
     """A specialized agent for message push tasks."""
 
     def __init__(self, model: Optional[str] = VOLCENGINE_MODEL_DEEPSEEK_V3, system_prompt: Optional[str] = None,
                  tools: Optional[List[FunctionTool]] = None,
                  generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
         system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
-        tools = tools or []
-        tools = tools.copy()
-        tools.extend([
-            *ImageDescriber().get_tools(),
-            *MessageNotifier().get_tools(),
-        ])
+        if tools is None:
+            tools = [
+                *ImageDescriber().get_tools(),
+                *MessageNotifier().get_tools()
+            ]
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
     def generate_message(self, context: Dict, dialogue_history: List[Dict],
                          query_prompt_template: Optional[str] = None) -> List[Dict]:
-        formatted_dialogue = MessagePushAgent.compose_dialogue(dialogue_history)
         query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
-        query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
-        self.run(query)
-        result = []
-        for tool_call in self.tool_call_records:
-            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
-                result.append(tool_call['arguments']['message'])
-        return result
-
-    @staticmethod
-    def compose_dialogue(dialogue: List[Dict]) -> str:
-        role_map = {'user': '用户', 'assistant': '客服'}
-        messages = []
-        for msg in dialogue:
-            if not msg['content']:
-                continue
-            if msg['role'] not in role_map:
-                continue
-            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
-            msg_type = msg.get('type', MessageType.TEXT).description
-            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
-        return '\n'.join(messages)
+        return self._generate_message(context, dialogue_history, query_prompt_template)
 
 class DummyMessagePushAgent(MessagePushAgent):
     """A dummy agent for testing purposes."""

+ 12 - 34
pqai_agent/agents/message_reply_agent.py

@@ -1,10 +1,8 @@
-import datetime
 from typing import Optional, List, Dict
 
-from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
 from pqai_agent.logging_service import logger
-from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
@@ -86,44 +84,24 @@ QUERY_PROMPT_TEMPLATE = """现在,请以客服的角色分析以下会话并
 Now, start to process your task. Please think step by step.
  """
 
-class MessageReplyAgent(SimpleOpenAICompatibleChatAgent):
+class MessageReplyAgent(MultiModalChatAgent):
     """A specialized agent for message reply tasks."""
 
     def __init__(self, model: Optional[str] = VOLCENGINE_MODEL_DEEPSEEK_V3, system_prompt: Optional[str] = None,
                  tools: Optional[List[FunctionTool]] = None,
                  generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
         system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
-        tools = tools or []
-        tools = tools.copy()
-        tools.extend([
-            *ImageDescriber().get_tools(),
-            *MessageNotifier().get_tools()
-        ])
+        if tools is None:
+            tools = [
+                *ImageDescriber().get_tools(),
+                *MessageNotifier().get_tools()
+            ]
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
-        formatted_dialogue = MessageReplyAgent.compose_dialogue(dialogue_history)
-        query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
-        self.run(query)
-        result = []
-        for tool_call in self.tool_call_records:
-            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
-                result.append(tool_call['arguments']['message'])
-        return result
-
-    @staticmethod
-    def compose_dialogue(dialogue: List[Dict]) -> str:
-        role_map = {'user': '用户', 'assistant': '客服'}
-        messages = []
-        for msg in dialogue:
-            if not msg['content']:
-                continue
-            if msg['role'] not in role_map:
-                continue
-            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
-            msg_type = msg.get('type', MessageType.TEXT).description
-            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
-        return '\n'.join(messages)
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: Optional[str] = None) -> List[Dict]:
+        query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
+        return self._generate_message(context, dialogue_history, query_prompt_template)
 
 class DummyMessageReplyAgent(MessageReplyAgent):
     """A dummy agent for testing purposes."""
@@ -131,7 +109,7 @@ class DummyMessageReplyAgent(MessageReplyAgent):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict], query_prompt_template = None) -> List[Dict]:
         logger.debug(f"DummyMessageReplyAgent.generate_message called, context: {context}")
         result = [{"type": "text", "content": "测试消息: {agent_name} -> {nickname}".format(**context)},
                   {"type": "image", "content": "https://example.com/test_image.jpg"}]

+ 53 - 0
pqai_agent/agents/multimodal_chat_agent.py

@@ -0,0 +1,53 @@
+import datetime
+from abc import abstractmethod
+from typing import Optional, List, Dict
+
+from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.logging_service import logger
+from pqai_agent.mq_message import MessageType
+from pqai_agent.toolkit import get_tool
+from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.message_notifier import MessageNotifier
+
+
+class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
+    """A specialized agent for message reply tasks."""
+
+    def __init__(self, model: str, system_prompt: str,
+                 tools: Optional[List[FunctionTool]] = None,
+                 generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
+        super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
+        if 'output_multimodal_message' not in self.tool_map:
+            self.add_tool(get_tool('output_multimodal_message'))
+        if 'message_notify_user' not in self.tool_map:
+            self.add_tool(get_tool('message_notify_user'))
+
+    @abstractmethod
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: str) -> List[Dict]:
+        pass
+
+    def _generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: str) -> List[Dict]:
+        formatted_dialogue = MultiModalChatAgent.compose_dialogue(dialogue_history)
+        query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
+        self.run(query)
+        result = []
+        for tool_call in self.tool_call_records:
+            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
+                result.append(tool_call['arguments']['message'])
+        return result
+
+    @staticmethod
+    def compose_dialogue(dialogue: List[Dict]) -> str:
+        role_map = {'user': '用户', 'assistant': '客服'}
+        messages = []
+        for msg in dialogue:
+            if not msg['content']:
+                continue
+            if msg['role'] not in role_map:
+                continue
+            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
+            msg_type = msg.get('type', MessageType.TEXT).description
+            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
+        return '\n'.join(messages)

+ 11 - 1
pqai_agent/agents/simple_chat_agent.py

@@ -15,12 +15,22 @@ class SimpleOpenAICompatibleChatAgent:
         self.model = model
         self.llm_client = OpenAICompatible.create_client(model)
         self.system_prompt = system_prompt
-        self.tools = tools or []
+        if tools:
+            self.tools = [*tools]
+        else:
+            self.tools = []
         self.tool_map = {tool.name: tool for tool in self.tools}
         self.generate_cfg = generate_cfg or {}
         self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
         self.tool_call_records = []
 
+    def add_tool(self, tool: FunctionTool):
+        """添加一个工具到Agent中"""
+        if tool.name in self.tool_map:
+            logger.warning(f"Tool {tool.name} already exists, replacing it.")
+        self.tools.append(tool)
+        self.tool_map[tool.name] = tool
+
     def run(self, user_input: str) -> str:
         messages = [{"role": "system", "content": self.system_prompt}]
         tools = [tool.get_openai_tool_schema() for tool in self.tools]

+ 36 - 0
pqai_agent/clients/relation_stage_client.py

@@ -0,0 +1,36 @@
+from typing import Optional
+
+import requests
+
+from pqai_agent.logging_service import logger
+
+class RelationStageClient:
+    UNKNOWN_RELATION_STAGE = '未知'
+
+    def __init__(self, base_url: Optional[str] = None):
+        base_url = base_url or "http://ai-wechat-hook-internal.piaoquantv.com/analyse/getUserEmployeeRelStage"
+        self.base_url = base_url
+
+    def get_relation_stage(self, staff_id: str, user_id: str) -> str:
+        url = f"{self.base_url}?employeeId={staff_id}&userId={user_id}"
+        response = requests.get(url)
+        if response.status_code != 200:
+            logger.error(f"Request error [{response.status_code}]: {response.text}")
+            return self.UNKNOWN_RELATION_STAGE
+        data = response.json()
+        if not data.get('success', False):
+            logger.error(f"Error in response: {data.get('message', 'no message returned')}")
+            return self.UNKNOWN_RELATION_STAGE
+        if 'data' not in data:
+            logger.error("No 'data' field in response")
+            return self.UNKNOWN_RELATION_STAGE
+        return data.get('data')
+
+if __name__ == "__main__":
+    # Example usage
+    client = RelationStageClient()
+    stage = client.get_relation_stage("1688856125791790", "7881301780233975")
+    if stage:
+        print(f"Relation stage: {stage}")
+    else:
+        print("Failed to retrieve relation stage.")

+ 12 - 4
pqai_agent/dialogue_manager.py

@@ -14,6 +14,7 @@ import cozepy
 from sqlalchemy.orm import sessionmaker, Session
 
 from pqai_agent import configs
+from pqai_agent.clients.relation_stage_client import RelationStageClient
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
 from pqai_agent.logging_service import logger
 from pqai_agent.database import MySQLManager
@@ -102,7 +103,7 @@ class DialogueStateCache:
 
 class DialogueManager:
     def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache,
-                 AgentDBSession: sessionmaker[Session]):
+                 agent_db_session_maker: sessionmaker[Session]):
         config = configs.get()
 
         self.staff_id = staff_id
@@ -125,7 +126,10 @@ class DialogueManager:
         self.history_dialogue_service = HistoryDialogueService(
             config['storage']['history_dialogue']['api_base_url']
         )
-        self.AgentDBSession = AgentDBSession
+        # FIXME: 实际为无状态接口,不需要每个DialogueManager持有一个单独实例
+        self.relation_stage_client = RelationStageClient()
+        self.relation_stage = self.relation_stage_client.get_relation_stage(staff_id, user_id)
+        self.agent_db_session_maker = agent_db_session_maker
         self._recover_state()
         # 由于本地状态管理过于复杂,引入事务机制做状态回滚
         self._uncommited_state_change = []
@@ -155,6 +159,10 @@ class DialogueManager:
 
     def refresh_profile(self):
         self.staff_profile = self.user_manager.get_staff_profile(self.staff_id)
+        relation_stage = self.relation_stage_client.get_relation_stage(self.staff_id, self.user_id)
+        if relation_stage and relation_stage != self.relation_stage:
+            logger.info(f"staff[{self.staff_id}], user[{self.user_id}]: relation stage changed from {self.relation_stage} to {relation_stage}")
+            self.relation_stage = relation_stage
 
     def _recover_state(self):
         self.current_state, self.previous_state = self.state_cache.get_state(self.staff_id, self.user_id)
@@ -174,7 +182,7 @@ class DialogueManager:
         else:
             # 默认设置
             self.last_interaction_time_ms = int(time.time() * 1000) - minutes_to_get * 60 * 1000
-        with self.AgentDBSession() as session:
+        with self.agent_db_session_maker() as session:
             # 读取数据库中的最后一次交互时间
             query = session.query(AgentPushRecord).filter(
                 AgentPushRecord.staff_id == self.staff_id,
@@ -530,7 +538,6 @@ class DialogueManager:
             return True
         return False
 
-
     def is_in_human_intervention(self) -> bool:
         """检查是否处于人工介入状态"""
         return self.current_state == DialogueState.HUMAN_INTERVENTION
@@ -559,6 +566,7 @@ class DialogueManager:
             "last_interaction_interval": self._get_hours_since_last_interaction(2),
             "if_first_interaction": True if self.previous_state == DialogueState.INITIALIZED else False,
             "if_active_greeting": False if user_message else True,
+            "relation_stage": self.relation_stage,
             "formatted_staff_profile": prompt_utils.format_agent_profile(self.staff_profile),
             "formatted_user_profile": prompt_utils.format_user_profile(self.user_profile),
             **self.user_profile,

+ 31 - 7
pqai_agent/push_service.py

@@ -12,11 +12,14 @@ import rocketmq
 from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression
 
 from pqai_agent import configs
+from pqai_agent.abtest.utils import get_abtest_info
 from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
 from pqai_agent.logging_service import logger
 from pqai_agent.mq_message import MessageType
+from pqai_agent.toolkit import get_tools
+from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
 
 
 class TaskType(Enum):
@@ -54,12 +57,17 @@ class PushScanThread:
         for staff_user in self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id):
             staff_id = staff_user['staff_id']
             user_id = staff_user['user_id']
-            agent = self.service.get_agent_instance(staff_id, user_id)
-            should_initiate = agent.should_initiate_conversation()
+            # 通过AB实验配置控制用户组是否启用push
+            # abtest_params = get_abtest_info(user_id).params
+            # if abtest_params.get('agent_push_enabled', 'false').lower() != 'true':
+            #     logger.debug(f"User {user_id} not enabled agent push, skipping.")
+            #     continue
             user_tags = self.service.user_relation_manager.get_user_tags(user_id)
-
             if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
                 should_initiate = False
+            else:
+                agent = self.service.get_agent_instance(staff_id, user_id)
+                should_initiate = agent.should_initiate_conversation()
             if should_initiate:
                 logger.info(f"user[{user_id}], tags{user_tags}: generate a generation task for conversation initiation")
                 rmq_msg = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.GENERATE)
@@ -168,7 +176,7 @@ class PushTaskWorkerPool:
                     if response:
                         item["type"] = message_type
                         messages_to_send.append(item)
-            with self.agent_service.AgentDBSession() as session:
+            with self.agent_service.agent_db_session_maker() as session:
                 msg_list = [{"type": msg["type"].value, "content": msg["content"]} for msg in messages_to_send]
                 record = AgentPushRecord(staff_id=staff_id, user_id=user_id,
                                          content=json.dumps(msg_list, ensure_ascii=False),
@@ -192,12 +200,28 @@ class PushTaskWorkerPool:
             staff_id = task['staff_id']
             user_id = task['user_id']
             main_agent = self.agent_service.get_agent_instance(staff_id, user_id)
-            push_agent = MessagePushAgent()
+            agent_config = get_agent_abtest_config('push', user_id,
+                                                   self.agent_service.service_module_manager,
+                                                   self.agent_service.agent_config_manager)
+            if agent_config:
+                try:
+                    tool_names = json.loads(agent_config.tools)
+                except json.JSONDecodeError:
+                    logger.error(f"Invalid JSON in agent tools: {agent_config.tools}")
+                    tool_names = []
+                push_agent = MessagePushAgent(model=agent_config.execution_model,
+                                              system_prompt=agent_config.system_prompt,
+                                              tools=get_tools(tool_names))
+                query_prompt_template = agent_config.task_prompt
+            else:
+                push_agent = MessagePushAgent()
+                query_prompt_template = None
             message_to_user = push_agent.generate_message(
                 context=main_agent.get_prompt_context(None),
                 dialogue_history=self.agent_service.history_dialogue_db.get_dialogue_history_backward(
-                    staff_id, user_id, main_agent.last_interaction_time_ms, limit=100
-                )
+                    staff_id, user_id, main_agent.last_interaction_time_ms, limit=30
+                ),
+                query_prompt_template=query_prompt_template
             )
             if message_to_user:
                 rmq_message = generate_task_rmq_message(

+ 3 - 2
pqai_agent/response_type_detector.py

@@ -38,7 +38,8 @@ class ResponseTypeDetector:
         )
         self.model_name = chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
 
-    def detect_type(self, dialogue_history: List[Dict], next_message: Dict, enable_random=False):
+    def detect_type(self, dialogue_history: List[Dict], next_message: Dict, enable_random=False,
+                    random_rate=0.25):
         if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
             return MessageType.TEXT
         composed_dialogue = self.compose_dialogue(dialogue_history)
@@ -62,7 +63,7 @@ class ResponseTypeDetector:
             suitable_for_voice = self.if_message_suitable_for_voice(next_message_content)
             logger.debug(f"voice suitable[{suitable_for_voice}], message: {next_message_content}")
             if suitable_for_voice:
-                if random.random() < 0.6:
+                if random.random() < random_rate:
                     logger.info(f"enable voice response randomly for message: {next_message_content}")
                     return MessageType.VOICE
         return MessageType.TEXT

+ 27 - 0
pqai_agent/service_module_manager.py

@@ -0,0 +1,27 @@
+from pqai_agent.data_models.service_module import ServiceModule, ModuleAgentType
+from pqai_agent.logging_service import logger
+
+class ServiceModuleManager:
+    def __init__(self, session_maker):
+        self.session_maker = session_maker
+        self.module_configs = {}
+        self.refresh_configs()
+
+    def refresh_configs(self):
+        try:
+            with self.session_maker() as session:
+                data = session.query(ServiceModule).filter_by(is_delete=False).all()
+                module_configs = {}
+                for module in data:
+                    module_configs[module.name] = {
+                        'display_name': module.display_name,
+                        'default_agent_type': ModuleAgentType(module.default_agent_type),
+                        'default_agent_id': module.default_agent_id
+                    }
+                self.module_configs = module_configs
+                logger.debug(f"Refreshed module configurations: {module_configs}")
+        except Exception as e:
+            logger.error(f"Error refreshing module configs: {e}")
+
+    def get_module_config(self, module_name: str):
+        return self.module_configs.get(module_name)

+ 42 - 0
pqai_agent/toolkit/__init__.py

@@ -0,0 +1,42 @@
+# 必须要在这里导入模块,以便对应的模块执行register_toolkit
+from typing import Sequence, List
+
+from pqai_agent.logging_service import logger
+from pqai_agent.toolkit.tool_registry import ToolRegistry
+from pqai_agent.toolkit.image_describer import ImageDescriber
+from pqai_agent.toolkit.message_notifier import MessageNotifier
+from pqai_agent.toolkit.pq_video_searcher import PQVideoSearcher
+from pqai_agent.toolkit.search_toolkit import SearchToolkit
+
+global_tool_map = ToolRegistry.tool_map
+
+def get_tool(tool_name: str) -> 'FunctionTool':
+    """
+    Retrieve a tool by its name from the global tool map.
+
+    Args:
+        tool_name (str): The name of the tool to retrieve.
+
+    Returns:
+        FunctionTool: The tool instance if found, otherwise None.
+    """
+    return global_tool_map.get(tool_name, None)
+
+def get_tools(tool_names: Sequence[str]) -> List['FunctionTool']:
+    """
+    Retrieve multiple tools by their names from the global tool map.
+
+    Args:
+        tool_names (Sequence[str]): A sequence of tool names to retrieve.
+
+    Returns:
+        Sequence[FunctionTool]: A sequence of tool instances corresponding to the provided names.
+    """
+    ret = []
+    for name in tool_names:
+        tool = get_tool(name)
+        if tool is not None:
+            ret.append(tool)
+        else:
+            logger.warning(f"Tool '{name}' not found in the global tool map.")
+    return ret

+ 2 - 0
pqai_agent/toolkit/image_describer.py

@@ -6,11 +6,13 @@ from pqai_agent.chat_service import VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO
 from pqai_agent.logging_service import logger
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
 # 不同实例间复用cache,但不是很好的实践
 _image_describer_caches = {}
 _cache_mutex = threading.Lock()
 
+@register_toolkit
 class ImageDescriber(BaseToolkit):
     def __init__(self, cache_dir: str = None):
         self.model = VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO

+ 2 - 0
pqai_agent/toolkit/message_notifier.py

@@ -3,8 +3,10 @@ from typing import List, Dict
 from pqai_agent.logging_service import logger
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
 
+@register_toolkit
 class MessageNotifier(BaseToolkit):
     def __init__(self):
         super().__init__()

+ 3 - 0
pqai_agent/toolkit/pq_video_searcher.py

@@ -3,7 +3,10 @@ import requests
 
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
+
+@register_toolkit
 class PQVideoSearcher(BaseToolkit):
     API_URL = "https://vlogapi.piaoquantv.com/longvideoapi/search/userandvideo/list"
     def search_pq_video(self, keywords: List[str]) -> List[Dict]:

+ 2 - 0
pqai_agent/toolkit/search_toolkit.py

@@ -4,8 +4,10 @@ import requests
 
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
 
+@register_toolkit
 class SearchToolkit(BaseToolkit):
     r"""A class representing a toolkit for web search.
     """

+ 27 - 0
pqai_agent/toolkit/tool_registry.py

@@ -0,0 +1,27 @@
+from typing import Type, Dict
+from pqai_agent.toolkit.function_tool import FunctionTool
+
+class ToolRegistry:
+    tool_map: Dict[str, FunctionTool] = {}
+
+    @classmethod
+    def register_tools(cls, toolkit_class: Type):
+        """
+        Register tools from a toolkit class into the global tool_map.
+
+        Args:
+            toolkit_class (Type): A class that implements a `get_tools` method.
+        """
+        instance = toolkit_class()
+        if not hasattr(instance, 'get_tools') or not callable(instance.get_tools):
+            raise ValueError(f"{toolkit_class.__name__} must implement a callable `get_tools` method.")
+
+        tools = instance.get_tools()
+        for tool in tools:
+            if not hasattr(tool, 'name'):
+                raise ValueError(f"Tool {tool} must have a `name` attribute.")
+            cls.tool_map[tool.name] = tool
+
+def register_toolkit(cls):
+    ToolRegistry.register_tools(cls)
+    return cls

+ 0 - 23
pqai_agent/user_manager.py

@@ -292,7 +292,6 @@ class MySQLUserManager(UserManager):
             "data": staff_list
         }
 
-
 class LocalUserRelationManager(UserRelationManager):
     def __init__(self):
         pass
@@ -434,25 +433,3 @@ class MySQLUserRelationManager(UserRelationManager):
         except Exception as e:
             logger.error(f"stop_user_daily_push failed: {e}")
             return False
-
-
-if __name__ == '__main__':
-    config = configs.get()
-    user_db_config = config['storage']['user']
-    staff_db_config = config['storage']['staff']
-    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
-    user_profile = user_manager.get_user_profile('7881301263964433')
-    print(user_profile)
-
-    wecom_db_config = config['storage']['user_relation']
-    user_relation_manager = MySQLUserRelationManager(
-        user_db_config['mysql'], wecom_db_config['mysql'],
-        config['storage']['staff']['table'],
-        user_db_config['table'],
-        wecom_db_config['table']['staff'],
-        wecom_db_config['table']['relation'],
-        wecom_db_config['table']['user']
-    )
-    # all_staff_users = user_relation_manager.list_staff_users()
-    user_tags = user_relation_manager.get_user_tags('7881302078008656')
-    print(user_tags)

+ 1 - 0
pqai_agent/user_profile_extractor.py

@@ -156,6 +156,7 @@ class UserProfileExtractor:
         :return:
         """
         if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
+            logger.debug("skip LLM API call.")
             return None
 
         try:

+ 18 - 0
pqai_agent/utils/agent_abtest_utils.py

@@ -0,0 +1,18 @@
+from typing import Optional
+
+from pqai_agent.abtest.utils import get_abtest_info
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.service_module_manager import ServiceModuleManager
+from pqai_agent.agent_config_manager import AgentConfigManager
+
+def get_agent_abtest_config(module_name: str, uid: str,
+                            service_module_manager: ServiceModuleManager,
+                            agent_config_manager: AgentConfigManager) -> Optional[AgentConfiguration]:
+    abtest_info = get_abtest_info(uid)
+    module_config = service_module_manager.get_module_config(f'{module_name}_module')
+    agent_id = module_config['default_agent_id']
+    param_key = f'module_{module_name}_agent_id'
+    if param_key in abtest_info.params:
+        agent_id = int(abtest_info.params[param_key])
+    agent_config = agent_config_manager.get_config(agent_id)
+    return agent_config