|
@@ -18,6 +18,8 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
|
|
from pqai_agent import configs
|
|
|
+from pqai_agent.abtest.utils import get_abtest_info
|
|
|
+from pqai_agent.agent_config_manager import AgentConfigManager
|
|
|
from pqai_agent.agents.message_reply_agent import MessageReplyAgent
|
|
|
from pqai_agent.configs import apollo_config
|
|
|
from pqai_agent.exceptions import NoRetryException
|
|
@@ -29,10 +31,13 @@ from pqai_agent.history_dialogue_service import HistoryDialogueDatabase
|
|
|
from pqai_agent.push_service import PushScanThread, PushTaskWorkerPool
|
|
|
from pqai_agent.rate_limiter import MessageSenderRateLimiter
|
|
|
from pqai_agent.response_type_detector import ResponseTypeDetector
|
|
|
+from pqai_agent.service_module_manager import ServiceModuleManager
|
|
|
+from pqai_agent.toolkit import get_tools
|
|
|
from pqai_agent.user_manager import UserManager, UserRelationManager
|
|
|
from pqai_agent.message_queue_backend import MessageQueueBackend, AliyunRocketMQQueueBackend
|
|
|
from pqai_agent.user_profile_extractor import UserProfileExtractor
|
|
|
from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
|
|
|
+from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
|
|
|
from pqai_agent.utils.db_utils import create_ai_agent_db_engine
|
|
|
|
|
|
|
|
@@ -61,7 +66,7 @@ class AgentService:
|
|
|
self.agent_registry: Dict[str, DialogueManager] = {}
|
|
|
self.history_dialogue_db = HistoryDialogueDatabase(self.config['database']['ai_agent'])
|
|
|
self.agent_db_engine = create_ai_agent_db_engine()
|
|
|
- self.AgentDBSession = sessionmaker(bind=self.agent_db_engine)
|
|
|
+ self.agent_db_session_maker = sessionmaker(bind=self.agent_db_engine)
|
|
|
|
|
|
chat_config = self.config['chat_api']['openai_compatible']
|
|
|
self.text_model_name = chat_config['text_model']
|
|
@@ -98,6 +103,10 @@ class AgentService:
|
|
|
|
|
|
self.send_rate_limiter = MessageSenderRateLimiter()
|
|
|
|
|
|
+ # Agent配置和实验相关
|
|
|
+ self.service_module_manager = ServiceModuleManager(self.agent_db_session_maker)
|
|
|
+ self.agent_config_manager = AgentConfigManager(self.agent_db_session_maker)
|
|
|
+
|
|
|
def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
|
|
|
if not schedule_params:
|
|
|
schedule_params = {'hour': '8,16,20'}
|
|
@@ -123,6 +132,11 @@ class AgentService:
|
|
|
)
|
|
|
self.msg_scheduler_thread = threading.Thread(target=self.process_scheduler_events)
|
|
|
self.msg_scheduler_thread.start()
|
|
|
+ # 定时更新模块配置任务
|
|
|
+ self.scheduler.add_job(self.service_module_manager.refresh_configs, 'interval',
|
|
|
+ seconds=60, id='refresh_module_configs')
|
|
|
+ self.scheduler.add_job(self.agent_config_manager.refresh_configs, 'interval',
|
|
|
+ seconds=60, id='refresh_agent_configs')
|
|
|
self.scheduler.start()
|
|
|
|
|
|
def process_scheduler_events(self):
|
|
@@ -149,7 +163,7 @@ class AgentService:
|
|
|
agent_key = 'agent_{}_{}'.format(staff_id, user_id)
|
|
|
if agent_key not in self.agent_registry:
|
|
|
self.agent_registry[agent_key] = DialogueManager(
|
|
|
- staff_id, user_id, self.user_manager, self.agent_state_cache, self.AgentDBSession)
|
|
|
+ staff_id, user_id, self.user_manager, self.agent_state_cache, self.agent_db_session_maker)
|
|
|
agent = self.agent_registry[agent_key]
|
|
|
agent.refresh_profile()
|
|
|
return agent
|
|
@@ -240,7 +254,12 @@ class AgentService:
|
|
|
sys.exit(0)
|
|
|
|
|
|
def _update_user_profile(self, user_id, user_profile, recent_dialogue: List[Dict]):
|
|
|
- profile_to_update = self.user_profile_extractor.extract_profile_info_v2(user_profile, recent_dialogue)
|
|
|
+ agent_info = get_agent_abtest_config('profile_extractor', user_id, self.service_module_manager, self.agent_config_manager)
|
|
|
+ if agent_info:
|
|
|
+ prompt_template = agent_info.task_prompt
|
|
|
+ else:
|
|
|
+ prompt_template = None
|
|
|
+ profile_to_update = self.user_profile_extractor.extract_profile_info_v2(user_profile, recent_dialogue, prompt_template)
|
|
|
if not profile_to_update:
|
|
|
logger.debug("user_id: {}, no profile info extracted".format(user_id))
|
|
|
return
|
|
@@ -396,8 +415,12 @@ class AgentService:
|
|
|
return
|
|
|
|
|
|
push_scan_threads = []
|
|
|
+ whitelist_staffs = apollo_config.get_json_value("agent_initiate_whitelist_staffs", [])
|
|
|
for staff in self.user_relation_manager.list_staffs():
|
|
|
staff_id = staff['third_party_user_id']
|
|
|
+ if staff_id not in whitelist_staffs:
|
|
|
+ logger.info(f"staff[{staff_id}] is not in whitelist, skip")
|
|
|
+ continue
|
|
|
scan_thread = threading.Thread(target=PushScanThread(
|
|
|
staff_id, self, self.push_task_rmq_topic, self.push_task_producer).run)
|
|
|
scan_thread.start()
|
|
@@ -435,7 +458,14 @@ class AgentService:
|
|
|
return None
|
|
|
|
|
|
def _get_chat_response_v2(self, main_agent: DialogueManager) -> List[Dict]:
|
|
|
- chat_agent = MessageReplyAgent()
|
|
|
+ agent_config = get_agent_abtest_config('chat', main_agent.user_id,
|
|
|
+ self.service_module_manager, self.agent_config_manager)
|
|
|
+ if agent_config:
|
|
|
+ chat_agent = MessageReplyAgent(model=agent_config.execution_model,
|
|
|
+ system_prompt=agent_config.system_prompt,
|
|
|
+ tools=get_tools(agent_config.tools))
|
|
|
+ else:
|
|
|
+ chat_agent = MessageReplyAgent()
|
|
|
chat_responses = chat_agent.generate_message(
|
|
|
context=main_agent.get_prompt_context(None),
|
|
|
dialogue_history=main_agent.dialogue_history[-100:]
|