|
@@ -19,6 +19,7 @@ import logging_service
|
|
|
from logging_service import logger
|
|
|
from chat_service import CozeChat, ChatServiceType
|
|
|
from dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
|
|
|
+from response_type_detector import ResponseTypeDetector
|
|
|
from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
|
|
|
from openai import OpenAI
|
|
|
from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
|
|
@@ -46,6 +47,7 @@ class AgentService:
|
|
|
self.user_manager = user_manager
|
|
|
self.user_relation_manager = user_relation_manager
|
|
|
self.user_profile_extractor = UserProfileExtractor()
|
|
|
+ self.response_type_detector = ResponseTypeDetector()
|
|
|
self.agent_registry: Dict[str, DialogueManager] = {}
|
|
|
|
|
|
self.llm_client = OpenAI(
|
|
@@ -145,7 +147,12 @@ class AgentService:
|
|
|
self._update_user_profile(user_id, user_profile, agent.dialogue_history[-10:])
|
|
|
resp = self._get_chat_response(user_id, agent, message_text)
|
|
|
if resp:
|
|
|
- self._send_response(staff_id, user_id, resp)
|
|
|
+ recent_dialogue = agent.dialogue_history[-10:]
|
|
|
+ if len(recent_dialogue) < 2:
|
|
|
+ message_type = MessageType.TEXT
|
|
|
+ else:
|
|
|
+ message_type = self.response_type_detector.detect_type(recent_dialogue[:-1], recent_dialogue[-1])
|
|
|
+ self._send_response(staff_id, user_id, resp, message_type)
|
|
|
else:
|
|
|
logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
|
|
|
# 当前消息处理成功,持久化agent状态
|
|
@@ -154,8 +161,8 @@ class AgentService:
|
|
|
agent.rollback_state()
|
|
|
raise e
|
|
|
|
|
|
- def _send_response(self, staff_id, user_id, response):
|
|
|
- logger.warning(f"staff[{staff_id}] user[{user_id}]: response: {response}")
|
|
|
+ def _send_response(self, staff_id, user_id, response, message_type: MessageType):
|
|
|
+ logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
|
|
|
current_ts = int(time.time() * 1000)
|
|
|
user_tags = self.user_relation_manager.get_user_tags(user_id)
|
|
|
# FIXME(zhoutian)
|
|
@@ -166,7 +173,7 @@ class AgentService:
|
|
|
logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
|
|
|
return None
|
|
|
self.send_queue.produce(
|
|
|
- Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
|
|
|
+ Message.build(message_type, MessageChannel.CORP_WECHAT,
|
|
|
staff_id, user_id, response, current_ts)
|
|
|
)
|
|
|
|
|
@@ -193,7 +200,7 @@ class AgentService:
|
|
|
logger.warning("user: {}, initiate conversation".format(user_id))
|
|
|
resp = self._get_chat_response(user_id, agent, None)
|
|
|
if resp:
|
|
|
- self._send_response(staff_id, user_id, resp)
|
|
|
+ self._send_response(staff_id, user_id, resp, MessageType.TEXT)
|
|
|
time.sleep(random.randint(10,20))
|
|
|
else:
|
|
|
logger.debug("user: {}, do not initiate conversation".format(user_id))
|