agent_service.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import json
  5. import re
  6. import signal
  7. import sys
  8. import time
  9. from typing import Dict, List, Optional, Union
  10. import logging
  11. from datetime import datetime, timedelta
  12. import threading
  13. import traceback
  14. import apscheduler.triggers.cron
  15. import rocketmq
  16. from apscheduler.schedulers.background import BackgroundScheduler
  17. from rocketmq import FilterExpression
  18. from sqlalchemy.orm import sessionmaker
  19. from pqai_agent import configs, push_service
  20. from pqai_agent.abtest.utils import get_abtest_info
  21. from pqai_agent.agent_config_manager import AgentConfigManager
  22. from pqai_agent.agents.message_reply_agent import MessageReplyAgent
  23. from pqai_agent.configs import apollo_config
  24. from pqai_agent.exceptions import NoRetryException
  25. from pqai_agent.logging import logger
  26. from pqai_agent import chat_service
  27. from pqai_agent.chat_service import CozeChat, ChatServiceType
  28. from pqai_agent.dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
  29. from pqai_agent.history_dialogue_service import HistoryDialogueDatabase
  30. from pqai_agent.push_service import PushScanThread, PushTaskWorkerPool
  31. from pqai_agent.rate_limiter import MessageSenderRateLimiter
  32. from pqai_agent.response_type_detector import ResponseTypeDetector
  33. from pqai_agent.service_module_manager import ServiceModuleManager
  34. from pqai_agent.toolkit import get_tools
  35. from pqai_agent.user_manager import UserManager, UserRelationManager
  36. from pqai_agent.message_queue_backend import MessageQueueBackend, AliyunRocketMQQueueBackend
  37. from pqai_agent.user_profile_extractor import UserProfileExtractor
  38. from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
  39. from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
  40. from pqai_agent.utils.db_utils import create_ai_agent_db_engine
  41. class AgentService:
  def __init__(
          self,
          receive_backend: MessageQueueBackend,
          send_backend: MessageQueueBackend,
          human_backend: MessageQueueBackend,
          user_manager: UserManager,
          user_relation_manager: UserRelationManager,
          chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
  ):
      """Wire up queues, managers, LLM clients and the push infrastructure.

      Construction has side effects: ``_init_push_task_queue`` starts a
      RocketMQ producer and two consumers, and ``_resume_unfinished_push_task``
      starts a background thread.

      Args:
          receive_backend: Queue that incoming messages are consumed from.
          send_backend: Queue that outgoing replies are produced to.
          human_backend: Queue that receives a notice when a dialogue is
              routed to human intervention.
          user_manager: Used to read/save user profiles and read user tags.
          user_relation_manager: Used to list staffs and to stop a user's
              daily push.
          chat_service_type: Chat backend used by the v1 reply path.
      """
      self.config = configs.get()
      self.receive_queue = receive_backend
      self.send_queue = send_backend
      self.human_queue = human_backend
      # Core service modules
      self.agent_state_cache = DialogueStateCache()
      self.user_manager = user_manager
      self.user_relation_manager = user_relation_manager
      self.user_profile_extractor = UserProfileExtractor()
      self.response_type_detector = ResponseTypeDetector()
      # DialogueManager instances keyed by 'agent_<staff_id>_<user_id>'
      self.agent_registry: Dict[str, DialogueManager] = {}
      self.history_dialogue_db = HistoryDialogueDatabase(self.config['database']['ai_agent'])
      self.agent_db_engine = create_ai_agent_db_engine()
      self.agent_db_session_maker = sessionmaker(bind=self.agent_db_engine)
      # OpenAI-compatible clients: one for the text model, one for the multimodal model
      chat_config = self.config['chat_api']['openai_compatible']
      self.text_model_name = chat_config['text_model']
      self.multimodal_model_name = chat_config['multimodal_model']
      self.text_model_client = chat_service.OpenAICompatible.create_client(self.text_model_name)
      self.multimodal_model_client = chat_service.OpenAICompatible.create_client(self.multimodal_model_name)
      # Coze client authenticated through an OAuth app
      coze_config = configs.get()['chat_api']['coze']
      coze_oauth_app = CozeChat.get_oauth_app(
          coze_config['oauth_client_id'], coze_config['private_key_path'], str(coze_config['public_key_id']),
          account_id=coze_config.get('account_id', None)
      )
      self.coze_client = CozeChat(
          base_url=chat_service.COZE_CN_BASE_URL,
          auth_app=coze_oauth_app
      )
      self.chat_service_type = chat_service_type
      # Scheduled-task scheduler (created later by setup_scheduler)
      self.scheduler = None
      # 'mq' makes delayed triggers go through a message queue; any other value uses the in-process scheduler
      self.scheduler_mode = self.config.get('system', {}).get('scheduler_mode', 'local')
      self.scheduler_queue = None
      self.msg_scheduler_thread = None
      self.running = False
      self.process_threads = []
      # Number of SIGINTs received so far; the second one forces an exit
      self._sigint_cnt = 0
      # Push related
      self.push_task_producer = None
      self.push_generate_task_consumer = None
      self.push_send_task_consumer = None
      self._init_push_task_queue()
      # Stays True until the background thread started below has drained leftover push tasks
      self.next_push_disabled = True
      self._resume_unfinished_push_task()
      self.send_rate_limiter = MessageSenderRateLimiter()
      # Agent configuration and experiment (A/B test) related
      self.service_module_manager = ServiceModuleManager(self.agent_db_session_maker)
      self.agent_config_manager = AgentConfigManager(self.agent_db_session_maker)
  def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
      """Register the cron job that checks for agent-initiated conversations.

      Args:
          schedule_params: Keyword arguments for ``CronTrigger``. When empty
              or ``None`` the job runs at hours 8, 16 and 20.
      """
      cron_kwargs = schedule_params if schedule_params else {'hour': '8,16,20'}
      self.scheduler.add_job(
          self._check_initiative_conversations,
          apscheduler.triggers.cron.CronTrigger(**cron_kwargs),
      )
  def setup_scheduler(self):
      """Create and start the background scheduler.

      In ``'mq'`` scheduler mode this also creates the delay-topic queue
      backend used for scheduled events and starts the thread that consumes
      it. In every mode, two interval jobs refresh the service-module and
      agent configurations every 60 seconds.
      """
      self.scheduler = BackgroundScheduler()
      if self.scheduler_mode == 'mq':
          # Use the project logger like the rest of this class, not the stdlib root logger.
          logger.info("setup event message scheduler with MQ")
          mq_conf = self.config['mq']
          topic = mq_conf['scheduler_topic']
          self.scheduler_queue = AliyunRocketMQQueueBackend(
              mq_conf['endpoints'],
              mq_conf['instance_id'],
              topic,
              has_consumer=True, has_producer=True,
              group_id=mq_conf['scheduler_group'],
              topic_type='DELAY',
              await_duration=5
          )
          self.msg_scheduler_thread = threading.Thread(target=self.process_scheduler_events)
          self.msg_scheduler_thread.start()
      # Periodic jobs that refresh module and agent configurations
      self.scheduler.add_job(self.service_module_manager.refresh_configs, 'interval',
                             seconds=60, id='refresh_module_configs')
      self.scheduler.add_job(self.agent_config_manager.refresh_configs, 'interval',
                             seconds=60, id='refresh_agent_configs')
      self.scheduler.start()
  def process_scheduler_events(self):
      """Consume scheduler-queue messages until the service stops running.

      Each message is handed to ``process_scheduler_event`` and acked on
      success. On failure the error is logged and the message is not acked.
      """
      while self.running:
          msg = self.scheduler_queue.consume()
          if msg:
              try:
                  self.process_scheduler_event(msg)
                  self.scheduler_queue.ack(msg)
              except Exception as e:
                  # No ack on failure — presumably the broker redelivers; TODO confirm.
                  logger.error("Error processing scheduler event: {}".format(e))
          # NOTE(review): pause placed at loop level; the original nesting was lost in the paste — confirm.
          time.sleep(1)
      logger.info("Scheduler event processing thread exit")
  def process_scheduler_event(self, msg: MqMessage):
      """Dispatch one scheduler message.

      Aggregation triggers are forwarded to the receive queue; any other
      message type is only logged.
      """
      if msg.type != MessageType.AGGREGATION_TRIGGER:
          logger.warning(f"Unknown message type: {msg.type}")
          return
      # A delayed trigger must enter the receive queue to drive the agent forward.
      self.receive_queue.produce(msg)
  def get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
      """Return the cached agent for a staff/user pair, creating it on first use.

      The agent's profile is refreshed on every call before it is returned.
      """
      agent_key = 'agent_{}_{}'.format(staff_id, user_id)
      if agent_key in self.agent_registry:
          agent = self.agent_registry[agent_key]
      else:
          agent = DialogueManager(
              staff_id, user_id, self.user_manager, self.agent_state_cache, self.agent_db_session_maker)
          self.agent_registry[agent_key] = agent
      agent.refresh_profile()
      return agent
  def create_queue_consumer(self) -> MessageQueueBackend:
      """Return the queue a worker thread should consume from.

      A dedicated consumer-only FIFO backend is created when the
      ``use_aliyun_mq`` debug flag is on; otherwise the shared receive queue
      is returned.
      """
      debug_flags = self.config.get('debug_flags', {})
      # Multiple consumers are only needed in MQ mode.
      if not debug_flags.get('use_aliyun_mq', False):
          logger.warning("Do not create queue consumer in local mode")
          return self.receive_queue
      mq_config = self.config['mq']
      return AliyunRocketMQQueueBackend(
          endpoints=mq_config['endpoints'],
          instance_id=mq_config['instance_id'],
          topic=mq_config['receive_topic'],
          has_consumer=True,
          has_producer=False,
          group_id=mq_config['receive_group'],
          topic_type='FIFO',
          await_duration=10,
      )
  def process_messages(self):
      """Continuously process messages from the receive queue.

      Ordering for the same <user, staff> pair is guaranteed by the message
      group of the ordered messages, so several workers can run concurrently.
      """
      receive_queue = self.create_queue_consumer()
      # Wait briefly after creating the consumer; otherwise the remote side may not be ready and errors out.
      time.sleep(1)
      while self.running:
          message = receive_queue.consume()
          if message:
              try:
                  self.process_single_message(message)
                  receive_queue.ack(message)
              except NoRetryException as e:
                  # Ack anyway: this failure is marked as not worth retrying.
                  logger.error("Error processing message and skip retry: {}".format(e))
                  receive_queue.ack(message)
              except Exception as e:
                  # No ack here — presumably the message is redelivered later; TODO confirm.
                  error_stack = traceback.format_exc()
                  logger.error("Error processing message: {}, {}".format(e, error_stack))
          # NOTE(review): pause placed at loop level; the original nesting was lost in the paste — confirm.
          time.sleep(0.1)
      # NOTE(review): outside MQ mode this is the shared self.receive_queue, which shutdown() also shuts down.
      receive_queue.shutdown()
      logger.info("MqMessage processing thread exit")
  def start(self, blocking=False):
      """Start worker threads, the scheduler and the SIGINT handler.

      Args:
          blocking: When true, block until every message-processing thread
              has exited.
      """
      self.running = True
      max_reply_workers = self.config.get('system', {}).get('max_reply_workers', 1)
      self.process_threads = []
      for _ in range(max_reply_workers):
          thread = threading.Thread(target=self.process_messages)
          thread.start()
          self.process_threads.append(thread)
      self.setup_scheduler()
      # Only the WeCom (enterprise WeChat) scenario needs agent-initiated conversations.
      # Treat both config sections as optional, as the other methods of this
      # class do, instead of raising KeyError when one is absent.
      debug_flags = self.config.get('debug_flags', {})
      if not debug_flags.get('disable_active_conversation', False):
          agent_behavior = self.config.get('agent_behavior', {})
          schedule_param = agent_behavior.get('active_conversation_schedule_param', None)
          self.setup_initiative_conversations(schedule_param)
      signal.signal(signal.SIGINT, self._handle_sigint)
      if blocking:
          for thread in self.process_threads:
              thread.join()
          logger.debug("process threads finished")
  def shutdown(self, sync=True):
      """Stop the service and release the scheduler and queues.

      Args:
          sync: When true, wait for the message-processing threads to exit
              before shutting the queues down.

      Raises:
          Exception: If the service is not currently running.
      """
      if not self.running:
          raise Exception("Service is not running")
      # Clearing the flag makes the worker loops exit after their current iteration.
      self.running = False
      self.scheduler.shutdown()
      logger.debug("scheduler shutdown")
      if sync:
          for thread in self.process_threads:
              thread.join()
          logger.debug("message processing threads finished")
      # The scheduler thread only exists in 'mq' scheduler mode.
      if self.msg_scheduler_thread:
          self.msg_scheduler_thread.join()
          self.scheduler_queue.shutdown()
          logger.debug("scheduler message processing thread finished")
      # NOTE(review): human_queue is not shut down here — confirm whether that is intended.
      self.receive_queue.shutdown()
      self.send_queue.shutdown()
      logger.debug("receive and send queues shutdown")
  def _handle_sigint(self, signum, frame):
      """SIGINT handler: graceful shutdown on the first signal, exit on later ones."""
      self._sigint_cnt += 1
      if self._sigint_cnt != 1:
          logger.warning("Forcing exit")
          sys.exit(0)
      logger.warning("Try to shutdown gracefully...")
      self.shutdown(sync=True)
  def _update_user_profile(self, user_id, user_profile, recent_dialogue: List[Dict]):
      """Extract profile updates from recent dialogue and persist the merged profile.

      Returns:
          The merged profile, or ``None`` when nothing was extracted.
      """
      agent_info = get_agent_abtest_config(
          'profile_extractor', user_id, self.service_module_manager, self.agent_config_manager)
      # Use the experiment's prompt when one is configured.
      prompt_template = agent_info.task_prompt if agent_info else None
      profile_to_update = self.user_profile_extractor.extract_profile_info_v2(
          user_profile, recent_dialogue, prompt_template)
      if not profile_to_update:
          logger.debug("user_id: {}, no profile info extracted".format(user_id))
          return
      logger.warning("update user profile: {}".format(profile_to_update))
      # Coordinate with the WeCom daily push to reduce disturbance to the user.
      if (profile_to_update.get('interaction_frequency', None) == 'stopped'
              and self.user_relation_manager.stop_user_daily_push(user_id)):
          logger.warning(f"user[{user_id}]: daily push set to be stopped")
      merged_profile = self.user_profile_extractor.merge_profile_info(user_profile, profile_to_update)
      self.user_manager.save_user_profile(user_id, merged_profile)
      return merged_profile
  def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
      """Schedule an aggregation-trigger message ``delay_sec`` seconds from now.

      In ``'mq'`` scheduler mode the message goes to the scheduler queue;
      otherwise a one-off job on the local scheduler produces it to the
      receive queue.
      """
      logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
      trigger_ts_ms = int((time.time() + delay_sec) * 1000)
      msg = MqMessage.build(
          MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, trigger_ts_ms)
      # System messages use a fixed msgId that carries no real meaning.
      msg.msgId = -MessageType.AGGREGATION_TRIGGER.value
      if self.scheduler_mode != 'mq':
          fire_at = datetime.now() + timedelta(seconds=delay_sec)
          self.scheduler.add_job(lambda: self.receive_queue.produce(msg), 'date', run_date=fire_at)
      else:
          self.scheduler_queue.produce(msg, msg_group='agent_system')
  def process_single_message(self, message: MqMessage):
      """Update the dialogue state for one message and route it.

      The message goes to human intervention, waits for aggregation, or gets
      an LLM reply. The agent state is persisted on success and rolled back
      when routing raises.

      Raises:
          Exception: If the agent is invalid, or whatever routing raised.
      """
      user_id, staff_id = message.sender, message.receiver
      # Load the user profile and the agent instance.
      user_profile = self.user_manager.get_user_profile(user_id)
      agent = self.get_agent_instance(staff_id, user_id)
      if not agent.is_valid():
          logger.error(f"staff[{staff_id}] user[{user_id}]: agent is invalid")
          raise Exception('agent is invalid')

      # Advance the dialogue state.
      logger.debug("process message: {}".format(message))
      need_response, message_text = agent.update_state(message)
      logger.debug("user: {}, next state: {}".format(user_id, agent.current_state))

      # Route the message according to the new state.
      try:
          if agent.is_in_human_intervention():
              self._route_to_human_intervention(user_id, message)
          elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
              # Create a trigger, but never recursively from a trigger itself.
              if message.type != MessageType.AGGREGATION_TRIGGER:
                  logger.debug("user: {}, waiting next message for aggregation".format(user_id))
                  self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
          elif need_response:
              # Update the user profile first, then produce the reply.
              self._update_user_profile(user_id, user_profile, agent.dialogue_history[-10:])
              replies = self.get_chat_response(agent, message_text)
              self.send_responses(agent, replies)
          else:
              logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
          # Handling succeeded: commit and persist the agent state.
          agent.persist_state()
      except Exception as e:
          agent.rollback_state()
          raise e
  def send_responses(self, agent: DialogueManager, contents: List[Dict]):
      """Stamp, type-detect and send each reply, then record the interaction time.

      Args:
          agent: Dialogue manager the replies belong to.
          contents: Reply dicts with ``"type"`` and ``"content"``; each one
              is mutated in place (``"timestamp"`` and possibly ``"type"``).
      """
      staff_id, user_id = agent.staff_id, agent.user_id
      recent_dialogue = agent.dialogue_history[-10:]
      agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
      current_ts = int(time.time())
      timestamp_ms = current_ts * 1000
      for item in contents:
          item["timestamp"] = timestamp_ms
          if item["type"] != MessageType.TEXT:
              continue
          # NOTE(review): `or True` makes the whitelist check always pass — confirm this is intentional.
          if staff_id in agent_voice_whitelist or True:
              item["type"] = self.response_type_detector.detect_type(
                  recent_dialogue, item, enable_random=True)
      if not contents:
          logger.debug(f"staff[{staff_id}], user[{user_id}]: no messages to send")
          return
      for response in contents:
          self.send_multimodal_response(staff_id, user_id, response)
      agent.update_last_active_interaction_time(current_ts)
  def can_send_to_user(self, staff_id, user_id) -> bool:
      """Return whether replies may be sent for this staff/user pair.

      Sending is allowed when the staff is on the staff whitelist or the
      user carries at least one whitelisted tag.
      """
      tags_by_user = self.user_manager.get_user_tags([user_id])
      user_tags = tags_by_user.get(user_id, [])
      white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags", []))
      hit_white_list_tags = bool(white_list_tags.intersection(user_tags))
      staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs", []))
      if staff_id in staff_white_lists or hit_white_list_tags:
          return True
      logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
      return False
  def send_multimodal_response(self, staff_id, user_id, response: Dict, skip_check=False):
      """Produce one text/image/voice reply to the send queue.

      Unsupported message types are logged and dropped. Unless
      ``skip_check`` is true, the reply is also dropped when
      ``can_send_to_user`` rejects the pair.
      """
      message_type = response["type"]
      logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
      supported_types = (MessageType.TEXT, MessageType.IMAGE_QW, MessageType.VOICE)
      if message_type not in supported_types:
          logger.error(f"staff[{staff_id}] user[{user_id}]: unsupported message type {message_type}")
          return
      if not (skip_check or self.can_send_to_user(staff_id, user_id)):
          return
      # The timestamp is taken before waiting on the rate limiter.
      current_ts = int(time.time() * 1000)
      self.send_rate_limiter.wait_for_sending(staff_id, response)
      outgoing = MqMessage.build(
          message_type, MessageChannel.CORP_WECHAT,
          staff_id, user_id, response["content"], current_ts)
      self.send_queue.produce(outgoing)
  def _route_to_human_intervention(self, user_id: str, origin_message: MqMessage):
      """Route to human intervention by producing a text notice to the human queue."""
      notice_text = "用户对话需人工介入,用户名:{}".format(user_id)
      notice_ts = int(time.time() * 1000)
      notice = MqMessage.build(
          MessageType.TEXT,
          origin_message.channel,
          origin_message.sender,
          origin_message.receiver,
          notice_text,
          notice_ts,
      )
      self.human_queue.produce(notice)
  def _init_push_task_queue(self):
      """Create and start the RocketMQ producer and consumers for push tasks.

      One producer and two simple consumers share the push-task topic; each
      consumer subscribes with a filter expression for one task type
      (GENERATE or SEND).
      """
      credentials = rocketmq.Credentials()
      mq_conf = configs.get()['mq']
      rmq_client_conf = rocketmq.ClientConfiguration(mq_conf['endpoints'], credentials, mq_conf['instance_id'])
      rmq_topic = mq_conf['push_tasks_topic']
      rmq_group_generate = mq_conf['push_generate_task_group']
      rmq_group_send = mq_conf['push_send_task_group']
      self.push_task_rmq_topic = rmq_topic
      self.push_task_producer = rocketmq.Producer(rmq_client_conf, (rmq_topic,))
      self.push_task_producer.startup()
      # FIXME: this should not be exposed in the agent service
      # Consumer for GENERATE tasks
      self.push_generate_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group_generate, await_duration=5)
      self.push_generate_task_consumer.startup()
      self.push_generate_task_consumer.subscribe(
          rmq_topic, filter_expression=FilterExpression(push_service.TaskType.GENERATE.value)
      )
      # Consumer for SEND tasks
      self.push_send_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group_send, await_duration=5)
      self.push_send_task_consumer.startup()
      self.push_send_task_consumer.subscribe(
          rmq_topic, filter_expression=FilterExpression(push_service.TaskType.SEND.value)
      )
  def _resume_unfinished_push_task(self):
      """Drain leftover push tasks in a background thread.

      ``next_push_disabled`` is set to ``False`` once the worker pool has
      finished.
      """
      def run_unfinished_push_task():
          logger.info("start to resume unfinished push task")
          worker_pool = PushTaskWorkerPool(
              self,
              self.push_task_rmq_topic,
              self.push_generate_task_consumer,
              self.push_send_task_consumer,
              self.push_task_producer,
          )
          worker_pool.start()
          worker_pool.wait_to_finish()
          self.next_push_disabled = False
          logger.info("unfinished push tasks should be finished")

      threading.Thread(target=run_unfinished_push_task).start()
  def _check_initiative_conversations(self):
      """Scan whitelisted staffs for push candidates and process the push tasks.

      Does nothing while the previous push round is still marked as in
      progress, or when the current time is not suitable for an active
      conversation.
      """
      logger.info("start to check initiative conversations")
      if self.next_push_disabled:
          logger.info("previous push tasks in processing, next push is disabled")
          return
      if not DialogueManager.is_time_suitable_for_active_conversation():
          logger.info("time is not suitable for active conversation")
          return

      whitelist_staffs = apollo_config.get_json_value("agent_initiate_whitelist_staffs", [])
      push_scan_threads = []
      for staff in self.user_relation_manager.list_staffs():
          staff_id = staff['third_party_user_id']
          if staff_id in whitelist_staffs:
              scanner = PushScanThread(
                  staff_id, self, self.push_task_rmq_topic, self.push_task_producer)
              scan_thread = threading.Thread(target=scanner.run)
              scan_thread.start()
              push_scan_threads.append(scan_thread)
          else:
              logger.info(f"staff[{staff_id}] is not in whitelist, skip")

      worker_pool = PushTaskWorkerPool(
          self, self.push_task_rmq_topic,
          self.push_generate_task_consumer, self.push_send_task_consumer, self.push_task_producer)
      worker_pool.start()
      for scan_thread in push_scan_threads:
          scan_thread.join()
      # Scanning and generation are asynchronous, so messages may still be
      # unprocessed between two scans and tasks could be generated twice;
      # therefore wait for the previous round to finish.
      # The catch: if a new PushTaskWorkerPool is created each time and the
      # previous one exits with unprocessed messages, those messages pile up.
      worker_pool.wait_to_finish()
  def get_chat_response(self, agent: DialogueManager, user_message: Optional[str]) -> List[Dict]:
      """Produce reply messages using the configured chat-agent version.

      Version 2 delegates to ``_get_chat_response_v2``. Any other version
      wraps the v1 text reply in a one-element list, or returns an empty
      list when v1 produced no text.
      """
      chat_agent_ver = self.config.get('system', {}).get('chat_agent_version', 1)
      if chat_agent_ver == 2:
          return self._get_chat_response_v2(agent)
      text_resp = self._get_chat_response_v1(agent, user_message)
      if not text_resp:
          return []
      return [{"type": MessageType.TEXT, "content": text_resp}]
  def _get_chat_response_v1(self, agent: DialogueManager, user_message: Optional[str]) -> Optional[str]:
      """Build the chat request, call the chat API and post-process the reply.

      Returns:
          The agent's final response text, or ``None`` when none was generated.
      """
      chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
      # Log a shallow copy limited to the last 20 messages.
      config_for_logging = chat_config.copy()
      config_for_logging['messages'] = config_for_logging['messages'][-20:]
      logger.debug(config_for_logging)

      raw_reply = self._call_chat_api(chat_config, self.chat_service_type)
      cleaned_reply = self.sanitize_response(raw_reply)
      response = agent.generate_response(cleaned_reply)
      if response:
          return response
      logger.warning(f"staff[{agent.staff_id}] user[{agent.user_id}]: no response generated")
      return None
  def _get_chat_response_v2(self, main_agent: DialogueManager) -> List[Dict]:
      """Generate replies with ``MessageReplyAgent`` (chat agent version 2).

      The reply agent is configured from the 'chat' A/B-test configuration
      when one exists for the user; otherwise its defaults are used.

      Returns:
          The responses accepted by ``generate_multimodal_response``; empty
          when nothing was generated.
      """
      agent_config = get_agent_abtest_config('chat', main_agent.user_id,
                                             self.service_module_manager, self.agent_config_manager)
      if agent_config:
          try:
              tool_names = json.loads(agent_config.tools)
          except (json.JSONDecodeError, TypeError):
              # TypeError covers a missing (None) or non-string `tools`
              # field, which json.loads rejects before parsing.
              logger.error(f"Invalid JSON in agent tools: {agent_config.tools}")
              tool_names = []
          chat_agent = MessageReplyAgent(model=agent_config.execution_model,
                                         system_prompt=agent_config.system_prompt,
                                         tools=get_tools(tool_names))
      else:
          chat_agent = MessageReplyAgent()
      chat_responses = chat_agent.generate_message(
          context=main_agent.get_prompt_context(None),
          dialogue_history=main_agent.dialogue_history[-100:]
      )
      if not chat_responses:
          logger.warning(f"staff[{main_agent.staff_id}] user[{main_agent.user_id}]: no response generated")
          return []
      final_responses = []
      for chat_response in chat_responses:
          if response := main_agent.generate_multimodal_response(chat_response):
              final_responses.append(response)
          else:
              # An invalid/terminating message was found: clear the pending messages.
              # NOTE(review): the loop continues, so later valid responses are still collected — confirm intent.
              final_responses.clear()
      return final_responses
  def _call_chat_api(self, chat_config: Dict, chat_service_type: ChatServiceType) -> str:
      """Call the selected chat backend and return its text reply.

      Args:
          chat_config: Request configuration. Always read: ``messages``.
              For the OpenAI-compatible backend, ``model_name`` and
              ``use_multimodal_model`` are optional. For Coze, ``user_id``,
              ``bot_id`` and ``custom_variables`` are required.
          chat_service_type: Which backend to call.

      Returns:
          The reply text, or a mock reply when the ``disable_llm_api_call``
          debug flag is set.

      Raises:
          Exception: If ``chat_service_type`` is not supported.
      """
      if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
          return 'LLM模拟回复 {}'.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
      if chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
          # An explicitly specified LLM model takes precedence.
          if chat_config.get('model_name', None):
              llm_client = chat_service.OpenAICompatible.create_client(chat_config['model_name'])
              chat_completion = llm_client.chat.completions.create(
                  messages=chat_config['messages'],
                  model=chat_config['model_name'],
              )
          elif chat_config.get('use_multimodal_model', False):
              chat_completion = self.multimodal_model_client.chat.completions.create(
                  messages=chat_config['messages'],
                  model=self.multimodal_model_name,
              )
          else:
              chat_completion = self.text_model_client.chat.completions.create(
                  messages=chat_config['messages'],
                  model=self.text_model_name,
              )
          response = chat_completion.choices[0].message.content
      elif chat_service_type == ChatServiceType.COZE_CHAT:
          bot_user_id = 'qywx_{}'.format(chat_config['user_id'])
          response = self.coze_client.create(
              chat_config['bot_id'], bot_user_id, chat_config['messages'],
              chat_config['custom_variables']
          )
      else:
          # Report the type that was dispatched on, not the instance default.
          raise Exception('Unsupported chat service type: {}'.format(chat_service_type))
      return response
  492. @staticmethod
  493. def sanitize_response(response: str):
  494. pattern = r'\[?\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]?'
  495. response = re.sub(pattern, '', response)
  496. response = response.strip()
  497. return response