@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
 
+import sys
 import time
 from typing import Dict, List, Tuple, Any
 import logging
@@ -10,7 +11,9 @@ from datetime import datetime, timedelta
 import apscheduler.triggers.cron
 from apscheduler.schedulers.background import BackgroundScheduler
 
+import chat_service
 import global_flags
+from chat_service import CozeChat, ChatServiceType
 from dialogue_manager import DialogueManager, DialogueState
 from user_manager import UserManager, LocalUserManager
 from openai import OpenAI
@@ -27,7 +30,8 @@ class AgentService:
         receive_backend: MessageQueueBackend,
         send_backend: MessageQueueBackend,
         human_backend: MessageQueueBackend,
-        user_manager: UserManager
+        user_manager: UserManager,
+        chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
     ):
         self.receive_queue = receive_backend
         self.send_queue = send_backend
@@ -42,7 +46,14 @@ class AgentService:
             api_key='5e275c38-44fd-415f-abcf-4b59f6377f72',
             base_url="https://ark.cn-beijing.volces.com/api/v3"
         )
-        self.model_name = "ep-20250213194558-rrmr2"  # DeepSeek on Volces
+        # DeepSeek on Volces
+        self.model_name = "ep-20250213194558-rrmr2"
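+        # Both chat clients are constructed eagerly; _call_chat_api dispatches on self.chat_service_type.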
+        self.coze_client = CozeChat(
+            token=chat_service.COZE_API_TOKEN,
+            base_url=chat_service.COZE_CN_BASE_URL
+        )
+        self.chat_service_type = chat_service_type
 
         # Scheduler for periodic tasks
         self.scheduler = BackgroundScheduler()
@@ -67,7 +78,7 @@ class AgentService:
         while True:
             message = self.receive_queue.consume()
             if message:
-                self._process_single_message(message)
+                self.process_single_message(message)
             time.sleep(1)  # avoid busy-waiting the CPU
 
     def _update_user_profile(self, user_id, user_profile, message: str):
@@ -92,7 +103,7 @@ class AgentService:
             'date',
             run_date=datetime.now() + timedelta(seconds=delay_sec))
 
-    def _process_single_message(self, message: Dict):
+    def process_single_message(self, message: Dict):
         user_id = message['user_id']
         message_text = message.get('text', None)
 
@@ -117,7 +128,7 @@ class AgentService:
         else:
            # Update the user profile first, then handle the reply
             self._update_user_profile(user_id, user_profile, message_text)
-            self._process_llm_response(user_id, agent, message_text)
+            self._get_chat_response(user_id, agent, message_text)
 
     def _route_to_human_intervention(self, user_id: str, user_message: str, state: DialogueState):
         """Route the conversation to human intervention."""
@@ -127,20 +138,6 @@ class AgentService:
             'timestamp': datetime.now().isoformat()
         })
 
-    def _process_llm_response(self, user_id: str, agent: DialogueManager,
-                              user_message: str):
-        """Handle the LLM response."""
-        messages = agent.make_llm_messages(user_message)
-        logging.debug(messages)
-        llm_response = self._call_llm_api(messages)
-
-        if response := agent.generate_response(llm_response):
-            logging.warning("user: {}, response: {}".format(user_id, response))
-            self.send_queue.produce({
-                'user_id': user_id,
-                'text': response,
-            })
-
     def _check_initiative_conversations(self):
         """Periodically check whether to proactively initiate conversations."""
         for user_id in self.user_manager.list_all_users():
@@ -149,24 +146,52 @@ class AgentService:
 
             if should_initiate:
                 logging.warning("user: {}, initiate conversation".format(user_id))
-                self._process_llm_response(user_id, agent, None)
+                self._get_chat_response(user_id, agent, None)
             else:
                 logging.debug("user: {}, do not initiate conversation".format(user_id))
 
-    def _call_llm_api(self, messages: List[Dict]) -> str:
+    def _get_chat_response(self, user_id: str, agent: DialogueManager,
+                           user_message: str):
+        """Handle the LLM response."""
+        chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
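+        # chat_config carries the backend-specific payload: 'messages', plus 'bot_id' and 'custom_variables' for Coze.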
+        logging.debug(chat_config)
+        # FIXME(zhoutian): the abstraction here is not good enough; DialogueManager and AgentService are coupled
+        chat_response = self._call_chat_api(chat_config)
+
+        if response := agent.generate_response(chat_response):
+            logging.warning("user: {}, response: {}".format(user_id, response))
+            self.send_queue.produce({
+                'user_id': user_id,
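+                # Tag the payload with an explicit message type.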
+                'type': MessageType.TEXT,
+                'text': response,
+            })
+
+    def _call_chat_api(self, chat_config: Dict) -> str:
         if global_flags.DISABLE_LLM_API_CALL:
             return 'Simulated LLM reply'
-        chat_completion = self.llm_client.chat.completions.create(
-            messages=messages,
-            model=self.model_name,
-        )
-        response = chat_completion.choices[0].message.content
+        if self.chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
+            chat_completion = self.llm_client.chat.completions.create(
+                messages=chat_config['messages'],
+                model=self.model_name,
+            )
+            response = chat_completion.choices[0].message.content
+        elif self.chat_service_type == ChatServiceType.COZE_CHAT:
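+            # CozeChat.create requires a user id; a fixed placeholder is used for now.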
+            bot_user_id = 'dev_user'
+            response = self.coze_client.create(
+                chat_config['bot_id'], bot_user_id, chat_config['messages'],
+                chat_config['custom_variables']
+            )
+        else:
+            raise Exception('Unsupported chat service type: {}'.format(self.chat_service_type))
         return response
 
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.DEBUG)
     console_handler = logging.StreamHandler()
-    console_handler.setLevel(logging.WARNING)
+    console_handler.setLevel(logging.INFO)
     formatter = ColoredFormatter(
         '%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
     )
@@ -192,7 +217,8 @@ if __name__ == "__main__":
         receive_backend=receive_queue,
         send_backend=send_queue,
         human_backend=human_queue,
-        user_manager=user_manager
+        user_manager=user_manager,
+        chat_service_type=ChatServiceType.COZE_CHAT
     )
 
     process_thread = threading.Thread(target=service.process_messages)