|
@@ -15,6 +15,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|
|
import chat_service
|
|
|
import configs
|
|
|
import logging_service
|
|
|
+from logging_service import logger
|
|
|
from chat_service import CozeChat, ChatServiceType
|
|
|
from dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
|
|
|
from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
|
|
@@ -23,7 +24,6 @@ from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, Aliyu
|
|
|
from user_profile_extractor import UserProfileExtractor
|
|
|
import threading
|
|
|
from message import MessageType, Message, MessageChannel
|
|
|
-from logging_service import ColoredFormatter
|
|
|
|
|
|
|
|
|
class AgentService:
|
|
@@ -93,22 +93,22 @@ class AgentService:
|
|
|
self.process_single_message(message)
|
|
|
self.receive_queue.ack(message)
|
|
|
except Exception as e:
|
|
|
- logging.error("Error processing message: {}".format(e))
|
|
|
+ logger.error("Error processing message: {}".format(e))
|
|
|
traceback.print_exc()
|
|
|
time.sleep(1)
|
|
|
|
|
|
def _update_user_profile(self, user_id, user_profile, message: str):
|
|
|
profile_to_update = self.user_profile_extractor.extract_profile_info(user_profile, message)
|
|
|
if not profile_to_update:
|
|
|
- logging.debug("user_id: {}, no profile info extracted".format(user_id))
|
|
|
+ logger.debug("user_id: {}, no profile info extracted".format(user_id))
|
|
|
return
|
|
|
- logging.warning("update user profile: {}".format(profile_to_update))
|
|
|
+ logger.warning("update user profile: {}".format(profile_to_update))
|
|
|
merged_profile = self.user_profile_extractor.merge_profile_info(user_profile, profile_to_update)
|
|
|
self.user_manager.save_user_profile(user_id, merged_profile)
|
|
|
return merged_profile
|
|
|
|
|
|
def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
|
|
|
- logging.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
|
|
|
+ logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
|
|
|
message_ts = int((time.time() + delay_sec) * 1000)
|
|
|
message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
|
|
|
# 系统消息使用特定的msgId,无实际意义
|
|
@@ -126,9 +126,9 @@ class AgentService:
|
|
|
agent = self._get_agent_instance(staff_id, user_id)
|
|
|
|
|
|
# 更新对话状态
|
|
|
- logging.debug("process message: {}".format(message))
|
|
|
+ logger.debug("process message: {}".format(message))
|
|
|
need_response, message_text = agent.update_state(message)
|
|
|
- logging.debug("user: {}, next state: {}".format(user_id, agent.current_state))
|
|
|
+ logger.debug("user: {}, next state: {}".format(user_id, agent.current_state))
|
|
|
|
|
|
# 根据状态路由消息
|
|
|
if agent.is_in_human_intervention():
|
|
@@ -136,7 +136,7 @@ class AgentService:
|
|
|
elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
|
|
|
if message.type != MessageType.AGGREGATION_TRIGGER:
|
|
|
# 产生一个触发器,但是不能由触发器递归产生
|
|
|
- logging.debug("user: {}, waiting next message for aggregation".format(user_id))
|
|
|
+ logger.debug("user: {}, waiting next message for aggregation".format(user_id))
|
|
|
self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
|
|
|
return
|
|
|
elif need_response:
|
|
@@ -144,7 +144,7 @@ class AgentService:
|
|
|
self._update_user_profile(user_id, user_profile, message_text)
|
|
|
self._get_chat_response(user_id, agent, message_text)
|
|
|
else:
|
|
|
- logging.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
|
|
|
+ logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
|
|
|
|
|
|
def _route_to_human_intervention(self, user_id: str, origin_message: Message):
|
|
|
"""路由到人工干预"""
|
|
@@ -166,33 +166,33 @@ class AgentService:
|
|
|
should_initiate = agent.should_initiate_conversation()
|
|
|
|
|
|
if should_initiate:
|
|
|
- logging.warning("user: {}, initiate conversation".format(user_id))
|
|
|
+ logger.warning("user: {}, initiate conversation".format(user_id))
|
|
|
self._get_chat_response(user_id, agent, None)
|
|
|
else:
|
|
|
- logging.debug("user: {}, do not initiate conversation".format(user_id))
|
|
|
+ logger.debug("user: {}, do not initiate conversation".format(user_id))
|
|
|
|
|
|
def _get_chat_response(self, user_id: str, agent: DialogueManager,
|
|
|
user_message: Optional[str]):
|
|
|
"""处理LLM响应"""
|
|
|
chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
|
|
|
- logging.debug(chat_config)
|
|
|
+ logger.debug(chat_config)
|
|
|
# FIXME(zhoutian): 临时处理去除头尾的空格
|
|
|
chat_response = self._call_chat_api(chat_config).strip()
|
|
|
|
|
|
if response := agent.generate_response(chat_response):
|
|
|
- logging.warning(f"staff[{agent.staff_id}] user[{user_id}]: response: {response}")
|
|
|
+ logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: response: {response}")
|
|
|
current_ts = int(time.time() * 1000)
|
|
|
# FIXME(zhoutian)
|
|
|
# 测试期间临时逻辑,只发送特定的用户
|
|
|
if agent.staff_id not in set(['1688854492669990']):
|
|
|
- logging.warning(f"skip message from sender [{agent.staff_id}]")
|
|
|
+ logger.warning(f"skip message from sender [{agent.staff_id}]")
|
|
|
return
|
|
|
self.send_queue.produce(
|
|
|
Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
|
|
|
agent.staff_id, user_id, response, current_ts)
|
|
|
)
|
|
|
else:
|
|
|
- logging.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
|
|
|
+ logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
|
|
|
|
|
|
def _call_chat_api(self, chat_config: Dict) -> str:
|
|
|
if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
|
|
@@ -216,7 +216,7 @@ class AgentService:
|
|
|
if __name__ == "__main__":
|
|
|
config = configs.get()
|
|
|
logging_service.setup_root_logger()
|
|
|
- logging.warning("current env: {}".format(configs.get_env()))
|
|
|
+ logger.warning("current env: {}".format(configs.get_env()))
|
|
|
scheduler_logger = logging.getLogger('apscheduler')
|
|
|
scheduler_logger.setLevel(logging.WARNING)
|
|
|
|