|
@@ -20,7 +20,7 @@ from openai import OpenAI
|
|
|
from message_queue_backend import MessageQueueBackend, MemoryQueueBackend
|
|
|
from user_profile_extractor import UserProfileExtractor
|
|
|
import threading
|
|
|
-from message import MessageType
|
|
|
+from message import MessageType, Message, MessageChannel
|
|
|
from logging_service import ColoredFormatter
|
|
|
|
|
|
|
|
@@ -92,19 +92,15 @@ class AgentService:
|
|
|
|
|
|
def _schedule_aggregation_trigger(self, user_id: str, delay_sec: int):
|
|
|
logging.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
|
|
|
- message = {
|
|
|
- 'user_id': user_id,
|
|
|
- 'type': MessageType.AGGREGATION_TRIGGER,
|
|
|
- 'text': None,
|
|
|
- 'timestamp': int(time.time() * 1000) + delay_sec * 1000
|
|
|
- }
|
|
|
+ message_ts = int((time.time() + delay_sec) * 1000)
|
|
|
+ message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, None, user_id, None, message_ts)
|
|
|
+ message.id = -MessageType.AGGREGATION_TRIGGER.code
|
|
|
self.scheduler.add_job(lambda: self.receive_queue.produce(message),
|
|
|
'date',
|
|
|
run_date=datetime.now() + timedelta(seconds=delay_sec))
|
|
|
|
|
|
- def process_single_message(self, message: Dict):
|
|
|
- user_id = message['user_id']
|
|
|
- message_text = message.get('text', None)
|
|
|
+ def process_single_message(self, message: Message):
|
|
|
+ user_id = message.user_id
|
|
|
|
|
|
# 获取用户信息和Agent实例
|
|
|
user_profile = self.user_manager.get_user_profile(user_id)
|
|
@@ -117,9 +113,9 @@ class AgentService:
|
|
|
|
|
|
# 根据状态路由消息
|
|
|
if agent.is_in_human_intervention():
|
|
|
- self._route_to_human_intervention(user_id, message_text, dialogue_state)
|
|
|
+ self._route_to_human_intervention(user_id, message)
|
|
|
elif dialogue_state == DialogueState.MESSAGE_AGGREGATING:
|
|
|
- if message['type'] != MessageType.AGGREGATION_TRIGGER:
|
|
|
+ if message.type != MessageType.AGGREGATION_TRIGGER:
|
|
|
# 产生一个触发器,但是不能由触发器递归产生
|
|
|
logging.debug("user: {}, waiting next message for aggregation".format(user_id))
|
|
|
self._schedule_aggregation_trigger(user_id, agent.message_aggregation_sec)
|
|
@@ -129,13 +125,16 @@ class AgentService:
|
|
|
self._update_user_profile(user_id, user_profile, message_text)
|
|
|
self._get_chat_response(user_id, agent, message_text)
|
|
|
|
|
|
- def _route_to_human_intervention(self, user_id: str, user_message: str, state: DialogueState):
|
|
|
+ def _route_to_human_intervention(self, user_id: str, origin_message: Message):
|
|
|
"""路由到人工干预"""
|
|
|
- self.human_queue.produce({
|
|
|
- 'user_id': user_id,
|
|
|
- 'state': state,
|
|
|
- 'timestamp': datetime.now().isoformat()
|
|
|
- })
|
|
|
+ self.human_queue.produce(Message.build(
|
|
|
+ MessageType.TEXT,
|
|
|
+ origin_message.channel,
|
|
|
+ origin_message.staff_id,
|
|
|
+ origin_message.user_id,
|
|
|
+ "用户对话需人工介入,用户名:{}".format(user_id),
|
|
|
+ int(time.time() * 1000)
|
|
|
+ ))
|
|
|
|
|
|
def _check_initiative_conversations(self):
|
|
|
"""定时检查主动发起对话"""
|
|
@@ -154,7 +153,6 @@ class AgentService:
|
|
|
"""处理LLM响应"""
|
|
|
chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
|
|
|
logging.debug(chat_config)
|
|
|
- # FIXME(zhoutian): 这里的抽象不够好,DialogueManager和AgentService有耦合
|
|
|
chat_response = self._call_chat_api(chat_config)
|
|
|
|
|
|
if response := agent.generate_response(chat_response):
|
|
@@ -187,7 +185,7 @@ class AgentService:
|
|
|
if __name__ == "__main__":
|
|
|
logging.getLogger().setLevel(logging.DEBUG)
|
|
|
console_handler = logging.StreamHandler()
|
|
|
- console_handler.setLevel(logging.INFO)
|
|
|
+ console_handler.setLevel(logging.DEBUG)
|
|
|
formatter = ColoredFormatter(
|
|
|
'%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
|
|
|
)
|
|
@@ -220,15 +218,16 @@ if __name__ == "__main__":
|
|
|
process_thread = threading.Thread(target=service.process_messages)
|
|
|
process_thread.start()
|
|
|
|
|
|
+ message_id = 0
|
|
|
while True:
|
|
|
print("Input next message: ")
|
|
|
- message = sys.stdin.readline().strip()
|
|
|
- message_dict = {
|
|
|
- "user_id": "user_id_1",
|
|
|
- "type": MessageType.TEXT,
|
|
|
- "text": message,
|
|
|
- "timestamp": int(time.time() * 1000)
|
|
|
- }
|
|
|
- if message:
|
|
|
- receive_queue.produce(message_dict)
|
|
|
+ text = sys.stdin.readline().strip()
|
|
|
+ if not text:
|
|
|
+ continue
|
|
|
+ message_id += 1
|
|
|
+ message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
|
|
|
+ 'staff_id_1','user_id_1', text, int(time.time() * 1000)
|
|
|
+ )
|
|
|
+ message.id = message_id
|
|
|
+ receive_queue.produce(message)
|
|
|
time.sleep(0.1)
|