Browse Source

Update: use module-level logger

StrayWarrior 2 weeks ago
parent
commit
df268c17f4

+ 16 - 16
agent_service.py

@@ -15,6 +15,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
 import chat_service
 import chat_service
 import configs
 import configs
 import logging_service
 import logging_service
+from logging_service import logger
 from chat_service import CozeChat, ChatServiceType
 from chat_service import CozeChat, ChatServiceType
 from dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
 from dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
 from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
 from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
@@ -23,7 +24,6 @@ from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, Aliyu
 from user_profile_extractor import UserProfileExtractor
 from user_profile_extractor import UserProfileExtractor
 import threading
 import threading
 from message import MessageType, Message, MessageChannel
 from message import MessageType, Message, MessageChannel
-from logging_service import ColoredFormatter
 
 
 
 
 class AgentService:
 class AgentService:
@@ -93,22 +93,22 @@ class AgentService:
                     self.process_single_message(message)
                     self.process_single_message(message)
                     self.receive_queue.ack(message)
                     self.receive_queue.ack(message)
                 except Exception as e:
                 except Exception as e:
-                    logging.error("Error processing message: {}".format(e))
+                    logger.error("Error processing message: {}".format(e))
                     traceback.print_exc()
                     traceback.print_exc()
             time.sleep(1)
             time.sleep(1)
 
 
     def _update_user_profile(self, user_id, user_profile, message: str):
     def _update_user_profile(self, user_id, user_profile, message: str):
         profile_to_update = self.user_profile_extractor.extract_profile_info(user_profile, message)
         profile_to_update = self.user_profile_extractor.extract_profile_info(user_profile, message)
         if not profile_to_update:
         if not profile_to_update:
-            logging.debug("user_id: {}, no profile info extracted".format(user_id))
+            logger.debug("user_id: {}, no profile info extracted".format(user_id))
             return
             return
-        logging.warning("update user profile: {}".format(profile_to_update))
+        logger.warning("update user profile: {}".format(profile_to_update))
         merged_profile = self.user_profile_extractor.merge_profile_info(user_profile, profile_to_update)
         merged_profile = self.user_profile_extractor.merge_profile_info(user_profile, profile_to_update)
         self.user_manager.save_user_profile(user_id, merged_profile)
         self.user_manager.save_user_profile(user_id, merged_profile)
         return merged_profile
         return merged_profile
 
 
     def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
     def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
-        logging.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
+        logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
         message_ts = int((time.time() + delay_sec) * 1000)
         message_ts = int((time.time() + delay_sec) * 1000)
         message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
         message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
         # 系统消息使用特定的msgId,无实际意义
         # 系统消息使用特定的msgId,无实际意义
@@ -126,9 +126,9 @@ class AgentService:
         agent = self._get_agent_instance(staff_id, user_id)
         agent = self._get_agent_instance(staff_id, user_id)
 
 
         # 更新对话状态
         # 更新对话状态
-        logging.debug("process message: {}".format(message))
+        logger.debug("process message: {}".format(message))
         need_response, message_text = agent.update_state(message)
         need_response, message_text = agent.update_state(message)
-        logging.debug("user: {}, next state: {}".format(user_id, agent.current_state))
+        logger.debug("user: {}, next state: {}".format(user_id, agent.current_state))
 
 
         # 根据状态路由消息
         # 根据状态路由消息
         if agent.is_in_human_intervention():
         if agent.is_in_human_intervention():
@@ -136,7 +136,7 @@ class AgentService:
         elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
         elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
             if message.type != MessageType.AGGREGATION_TRIGGER:
             if message.type != MessageType.AGGREGATION_TRIGGER:
                 # 产生一个触发器,但是不能由触发器递归产生
                 # 产生一个触发器,但是不能由触发器递归产生
-                logging.debug("user: {}, waiting next message for aggregation".format(user_id))
+                logger.debug("user: {}, waiting next message for aggregation".format(user_id))
                 self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
                 self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
             return
             return
         elif need_response:
         elif need_response:
@@ -144,7 +144,7 @@ class AgentService:
             self._update_user_profile(user_id, user_profile, message_text)
             self._update_user_profile(user_id, user_profile, message_text)
             self._get_chat_response(user_id, agent, message_text)
             self._get_chat_response(user_id, agent, message_text)
         else:
         else:
-            logging.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
+            logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
 
 
     def _route_to_human_intervention(self, user_id: str, origin_message: Message):
     def _route_to_human_intervention(self, user_id: str, origin_message: Message):
         """路由到人工干预"""
         """路由到人工干预"""
@@ -166,33 +166,33 @@ class AgentService:
             should_initiate = agent.should_initiate_conversation()
             should_initiate = agent.should_initiate_conversation()
 
 
             if should_initiate:
             if should_initiate:
-                logging.warning("user: {}, initiate conversation".format(user_id))
+                logger.warning("user: {}, initiate conversation".format(user_id))
                 self._get_chat_response(user_id, agent, None)
                 self._get_chat_response(user_id, agent, None)
             else:
             else:
-                logging.debug("user: {}, do not initiate conversation".format(user_id))
+                logger.debug("user: {}, do not initiate conversation".format(user_id))
 
 
     def _get_chat_response(self, user_id: str, agent: DialogueManager,
     def _get_chat_response(self, user_id: str, agent: DialogueManager,
                            user_message: Optional[str]):
                            user_message: Optional[str]):
         """处理LLM响应"""
         """处理LLM响应"""
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
-        logging.debug(chat_config)
+        logger.debug(chat_config)
         # FIXME(zhoutian): 临时处理去除头尾的空格
         # FIXME(zhoutian): 临时处理去除头尾的空格
         chat_response = self._call_chat_api(chat_config).strip()
         chat_response = self._call_chat_api(chat_config).strip()
 
 
         if response := agent.generate_response(chat_response):
         if response := agent.generate_response(chat_response):
-            logging.warning(f"staff[{agent.staff_id}] user[{user_id}]: response: {response}")
+            logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: response: {response}")
             current_ts = int(time.time() * 1000)
             current_ts = int(time.time() * 1000)
             # FIXME(zhoutian)
             # FIXME(zhoutian)
             # 测试期间临时逻辑,只发送特定的用户
             # 测试期间临时逻辑,只发送特定的用户
             if agent.staff_id not in set(['1688854492669990']):
             if agent.staff_id not in set(['1688854492669990']):
-                logging.warning(f"skip message from sender [{agent.staff_id}]")
+                logger.warning(f"skip message from sender [{agent.staff_id}]")
                 return
                 return
             self.send_queue.produce(
             self.send_queue.produce(
                 Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
                 Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
                               agent.staff_id, user_id, response, current_ts)
                               agent.staff_id, user_id, response, current_ts)
             )
             )
         else:
         else:
-            logging.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
+            logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
 
 
     def _call_chat_api(self, chat_config: Dict) -> str:
     def _call_chat_api(self, chat_config: Dict) -> str:
         if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
         if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
@@ -216,7 +216,7 @@ class AgentService:
 if __name__ == "__main__":
 if __name__ == "__main__":
     config = configs.get()
     config = configs.get()
     logging_service.setup_root_logger()
     logging_service.setup_root_logger()
-    logging.warning("current env: {}".format(configs.get_env()))
+    logger.warning("current env: {}".format(configs.get_env()))
     scheduler_logger = logging.getLogger('apscheduler')
     scheduler_logger = logging.getLogger('apscheduler')
     scheduler_logger.setLevel(logging.WARNING)
     scheduler_logger.setLevel(logging.WARNING)
 
 

+ 3 - 3
chat_service.py

@@ -7,7 +7,7 @@ import os
 import threading
 import threading
 from typing import List, Dict, Optional
 from typing import List, Dict, Optional
 from enum import Enum, auto
 from enum import Enum, auto
-import logging
+from logging_service import logger
 import cozepy
 import cozepy
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 import time
 import time
@@ -57,9 +57,9 @@ class CozeChat:
         response = self.coze.chat.create_and_poll(
         response = self.coze.chat.create_and_poll(
             bot_id=bot_id, user_id=user_id, additional_messages=messages,
             bot_id=bot_id, user_id=user_id, additional_messages=messages,
             custom_variables=custom_variables)
             custom_variables=custom_variables)
-        logging.debug("Coze response size: {}".format(len(response.messages)))
+        logger.debug("Coze response size: {}".format(len(response.messages)))
         if response.chat.status != ChatStatus.COMPLETED:
         if response.chat.status != ChatStatus.COMPLETED:
-            logging.error("Coze chat not completed: {}".format(response.chat.status))
+            logger.error("Coze chat not completed: {}".format(response.chat.status))
             return None
             return None
         final_response = None
         final_response = None
         for message in response.messages:
         for message in response.messages:

+ 1 - 1
configs/dev.yaml

@@ -42,7 +42,7 @@ chat_api:
     account_id: 649175100044793
     account_id: 649175100044793
 
 
 debug_flags:
 debug_flags:
-  disable_llm_api_call: False
+  disable_llm_api_call: True
   use_local_user_storage: False
   use_local_user_storage: False
   console_input: True
   console_input: True
   disable_active_conversation: False
   disable_active_conversation: False

+ 2 - 2
database.py

@@ -5,7 +5,7 @@
 # Copyright © 2024 StrayWarrior <i@straywarrior.com>
 # Copyright © 2024 StrayWarrior <i@straywarrior.com>
 
 
 import pymysql
 import pymysql
-import logging
+from logging_service import logger
 
 
 class MySQLManager:
 class MySQLManager:
     def __init__(self, config):
     def __init__(self, config):
@@ -64,7 +64,7 @@ class MySQLManager:
                 conn.commit()
                 conn.commit()
                 return rows
                 return rows
         except pymysql.MySQLError as e:
         except pymysql.MySQLError as e:
-            logging.error(f"Error in batch_insert: {e}")
+            logger.error(f"Error in batch_insert: {e}")
             conn.rollback()
             conn.rollback()
             raise e
             raise e
         conn.close()
         conn.close()

+ 8 - 12
dialogue_manager.py

@@ -6,7 +6,7 @@ from enum import Enum, auto
 from typing import Dict, List, Optional, Tuple, Any
 from typing import Dict, List, Optional, Tuple, Any
 from datetime import datetime
 from datetime import datetime
 import time
 import time
-import logging
+from logging_service import logger
 
 
 import pymysql.cursors
 import pymysql.cursors
 
 
@@ -21,10 +21,6 @@ from message import MessageType, Message
 from user_manager import UserManager
 from user_manager import UserManager
 from prompt_templates import *
 from prompt_templates import *
 
 
-# 配置日志
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
 class DummyVectorMemoryManager:
 class DummyVectorMemoryManager:
     def __init__(self, user_id):
     def __init__(self, user_id):
         pass
         pass
@@ -68,7 +64,7 @@ class DialogueStateCache:
         query = f"SELECT current_state, previous_state FROM {self.table} WHERE staff_id=%s AND user_id=%s"
         query = f"SELECT current_state, previous_state FROM {self.table} WHERE staff_id=%s AND user_id=%s"
         data = self.db.select(query, pymysql.cursors.DictCursor, (staff_id, user_id))
         data = self.db.select(query, pymysql.cursors.DictCursor, (staff_id, user_id))
         if not data:
         if not data:
-            logging.warning(f"staff[{staff_id}], user[{user_id}]: agent state not found")
+            logger.warning(f"staff[{staff_id}], user[{user_id}]: agent state not found")
             state = DialogueState.CHITCHAT
             state = DialogueState.CHITCHAT
             previous_state = DialogueState.INITIALIZED
             previous_state = DialogueState.INITIALIZED
             self.set_state(staff_id, user_id, state, previous_state)
             self.set_state(staff_id, user_id, state, previous_state)
@@ -82,7 +78,7 @@ class DialogueStateCache:
                 f" VALUES (%s, %s, %s, %s) " \
                 f" VALUES (%s, %s, %s, %s) " \
                 f"ON DUPLICATE KEY UPDATE current_state=%s, previous_state=%s"
                 f"ON DUPLICATE KEY UPDATE current_state=%s, previous_state=%s"
         rows = self.db.execute(query, (staff_id, user_id, state.value, previous_state.value, state.value, previous_state.value))
         rows = self.db.execute(query, (staff_id, user_id, state.value, previous_state.value, state.value, previous_state.value))
-        logging.debug("staff[{}], user[{}]: set state: {}, previous state: {}, rows affected: {}"
+        logger.debug("staff[{}], user[{}]: set state: {}, previous state: {}, rows affected: {}"
                       .format(staff_id, user_id, state, previous_state, rows))
                       .format(staff_id, user_id, state, previous_state, rows))
 
 
 class DialogueManager:
 class DialogueManager:
@@ -122,7 +118,7 @@ class DialogueManager:
             # 默认设置为24小时前
             # 默认设置为24小时前
             self.last_interaction_time = int(time.time() * 1000) - 24 * 3600 * 1000
             self.last_interaction_time = int(time.time() * 1000) - 24 * 3600 * 1000
         time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
         time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
-        logging.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
+        logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
 
 
     def persist_state(self):
     def persist_state(self):
         """持久化对话状态"""
         """持久化对话状态"""
@@ -165,7 +161,7 @@ class DialogueManager:
             # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,恢复之前状态,继续处理
             # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,恢复之前状态,继续处理
             if message.type == MessageType.AGGREGATION_TRIGGER \
             if message.type == MessageType.AGGREGATION_TRIGGER \
                     and message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
                     and message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
-                logging.debug("user_id: {}, last interaction time: {}".format(
+                logger.debug("user_id: {}, last interaction time: {}".format(
                     self.user_id, datetime.fromtimestamp(self.last_interaction_time / 1000)))
                     self.user_id, datetime.fromtimestamp(self.last_interaction_time / 1000)))
                 self.current_state = self.previous_state
                 self.current_state = self.previous_state
             else:
             else:
@@ -459,7 +455,7 @@ class DialogueManager:
             消息列表
             消息列表
         """
         """
         dialogue_history = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id)
         dialogue_history = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id)
-        logging.debug("staff[{}], user[{}], dialogue_history: {}".format(
+        logger.debug("staff[{}], user[{}], dialogue_history: {}".format(
             self.staff_id, self.user_id, dialogue_history
             self.staff_id, self.user_id, dialogue_history
         ))
         ))
         messages = []
         messages = []
@@ -478,7 +474,7 @@ class DialogueManager:
         elif chat_service_type == ChatServiceType.COZE_CHAT:
         elif chat_service_type == ChatServiceType.COZE_CHAT:
             for entry in dialogue_history:
             for entry in dialogue_history:
                 if not entry['content']:
                 if not entry['content']:
-                    logging.warning("staff[{}], user[{}], role[{}]: empty content in dialogue history".format(
+                    logger.warning("staff[{}], user[{}], role[{}]: empty content in dialogue history".format(
                         self.staff_id, self.user_id, entry['role']
                         self.staff_id, self.user_id, entry['role']
                     ))
                     ))
                     continue
                     continue
@@ -498,7 +494,7 @@ class DialogueManager:
                 messages.append(cozepy.Message.build_user_question_text('请开始对话'))
                 messages.append(cozepy.Message.build_user_question_text('请开始对话'))
             #FIXME(zhoutian): 临时报警
             #FIXME(zhoutian): 临时报警
             if user_message and not messages:
             if user_message and not messages:
-                logging.error(f"staff[{self.staff_id}], user[{self.user_id}]: inconsistency in messages")
+                logger.error(f"staff[{self.staff_id}], user[{self.user_id}]: inconsistency in messages")
         config['messages'] = messages
         config['messages'] = messages
 
 
         return config
         return config

+ 2 - 3
history_dialogue_service.py

@@ -3,8 +3,7 @@
 # vim:fenc=utf-8
 # vim:fenc=utf-8
 
 
 import requests
 import requests
-import logging
-import json
+from logging_service import logger
 
 
 import configs
 import configs
 
 
@@ -32,7 +31,7 @@ class HistoryDialogueService:
             elif sender == staff_id:
             elif sender == staff_id:
                 role = 'assistant'
                 role = 'assistant'
             else:
             else:
-                logging.warning("Unknown sender in dialogue history: {}".format(sender))
+                logger.warning("Unknown sender in dialogue history: {}".format(sender))
                 continue
                 continue
             ret.append({
             ret.append({
                 'role': role,
                 'role': role,

+ 8 - 4
logging_service.py

@@ -22,14 +22,18 @@ class ColoredFormatter(logging.Formatter):
             message = f"{COLORS[record.levelname]}{message}{COLORS['RESET']}"
             message = f"{COLORS[record.levelname]}{message}{COLORS['RESET']}"
         return message
         return message
 
 
-def setup_root_logger():
-    logging.getLogger().setLevel(logging.DEBUG)
+def setup_root_logger(level=logging.DEBUG):
     console_handler = logging.StreamHandler()
     console_handler = logging.StreamHandler()
     console_handler.setLevel(logging.DEBUG)
     console_handler.setLevel(logging.DEBUG)
     formatter = ColoredFormatter(
     formatter = ColoredFormatter(
-        '%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
+        '%(asctime)s - %(name)s %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
     )
     )
     console_handler.setFormatter(formatter)
     console_handler.setFormatter(formatter)
     root_logger = logging.getLogger()
     root_logger = logging.getLogger()
     root_logger.handlers.clear()
     root_logger.handlers.clear()
-    root_logger.addHandler(console_handler)
+    root_logger.addHandler(console_handler)
+
+    agent_logger = logging.getLogger('agent')
+    agent_logger.setLevel(level)
+
+logger = logging.getLogger('agent')

+ 4 - 4
message_queue_backend.py

@@ -4,7 +4,7 @@
 
 
 import abc
 import abc
 import time
 import time
-import logging
+from logging_service import logger
 from typing import Dict, Any, Optional
 from typing import Dict, Any, Optional
 import configs
 import configs
 
 
@@ -84,12 +84,12 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
             return None
             return None
         rmq_message = messages[0]
         rmq_message = messages[0]
         body = rmq_message.body.decode('utf-8')
         body = rmq_message.body.decode('utf-8')
-        logging.debug("recv message body: {}".format(body))
+        logger.debug("recv message body: {}".format(body))
         try:
         try:
             message = Message.from_json(body)
             message = Message.from_json(body)
             message._rmq_message = rmq_message
             message._rmq_message = rmq_message
         except Exception as e:
         except Exception as e:
-            logging.error("Invalid message: {}. Parsing error: {}".format(body, e))
+            logger.error("Invalid message: {}. Parsing error: {}".format(body, e))
             # 如果消息非法,直接ACK,避免死信
             # 如果消息非法,直接ACK,避免死信
             self.consumer.ack(rmq_message)
             self.consumer.ack(rmq_message)
             return None
             return None
@@ -98,7 +98,7 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
     def ack(self, message: Message):
     def ack(self, message: Message):
         if not message._rmq_message:
         if not message._rmq_message:
             raise ValueError("Message not set with _rmq_message.")
             raise ValueError("Message not set with _rmq_message.")
-        logging.debug("ack message: {}".format(message))
+        logger.debug("ack message: {}".format(message))
         self.consumer.ack(message._rmq_message)
         self.consumer.ack(message._rmq_message)
 
 
     def produce(self, message: Message) -> None:
     def produce(self, message: Message) -> None:

+ 9 - 8
user_manager.py

@@ -1,7 +1,8 @@
 #! /usr/bin/env python
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
 # vim:fenc=utf-8
-import logging
+
+from logging_service import logger
 from typing import Dict, Optional, Tuple, Any, List
 from typing import Dict, Optional, Tuple, Any, List
 import json
 import json
 import time
 import time
@@ -114,11 +115,11 @@ class MySQLUserManager(UserManager):
         sql = f"SELECT name, wxid, profile_data_v1 FROM {self.table_name} WHERE third_party_user_id = {user_id}"
         sql = f"SELECT name, wxid, profile_data_v1 FROM {self.table_name} WHERE third_party_user_id = {user_id}"
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         if not data:
         if not data:
-            logging.error(f"user[{user_id}] not found")
+            logger.error(f"user[{user_id}] not found")
             return {}
             return {}
         data = data[0]
         data = data[0]
         if not data['profile_data_v1']:
         if not data['profile_data_v1']:
-            logging.warning(f"user[{user_id}] profile not found, create a default one")
+            logger.warning(f"user[{user_id}] profile not found, create a default one")
             default_profile = self.get_default_profile(nickname=data['name'])
             default_profile = self.get_default_profile(nickname=data['name'])
             self.save_user_profile(user_id, default_profile)
             self.save_user_profile(user_id, default_profile)
             return default_profile
             return default_profile
@@ -143,7 +144,7 @@ class MySQLUserManager(UserManager):
               f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
               f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         if not data:
         if not data:
-            logging.error(f"staff[{staff_id}] not found")
+            logger.error(f"staff[{staff_id}] not found")
             return {}
             return {}
         profile = data[0]
         profile = data[0]
         return profile
         return profile
@@ -180,19 +181,19 @@ class MySQLUserRelationManager(UserRelationManager):
             sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = '{wxid}'"
             sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = '{wxid}'"
             staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             if not staff_data:
             if not staff_data:
-                logging.error(f"staff[{wxid}] not found in wecom database")
+                logger.error(f"staff[{wxid}] not found in wecom database")
                 continue
                 continue
             staff_id = staff_data[0]['id']
             staff_id = staff_data[0]['id']
             sql = f"SELECT user_id FROM {self.relation_table} WHERE staff_id = '{staff_id}' AND is_delete = 0"
             sql = f"SELECT user_id FROM {self.relation_table} WHERE staff_id = '{staff_id}' AND is_delete = 0"
             user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             if not user_data:
             if not user_data:
-                logging.warning(f"staff[{wxid}] has no user")
+                logger.warning(f"staff[{wxid}] has no user")
                 continue
                 continue
             user_ids = tuple(user['user_id'] for user in user_data)
             user_ids = tuple(user['user_id'] for user in user_data)
             sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)} AND union_id is not null"
             sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)} AND union_id is not null"
             user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             if not user_data:
             if not user_data:
-                logging.warning(f"staff[{wxid}] users not found in wecom database")
+                logger.warning(f"staff[{wxid}] users not found in wecom database")
                 continue
                 continue
             user_union_ids = tuple(user['union_id'] for user in user_data)
             user_union_ids = tuple(user['union_id'] for user in user_data)
             batch_size = 100
             batch_size = 100
@@ -205,7 +206,7 @@ class MySQLUserRelationManager(UserRelationManager):
                 sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
                 sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
                 batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
                 batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
                 if len(agent_user_data) != len(batch_union_ids):
                 if len(agent_user_data) != len(batch_union_ids):
-                    # logging.debug(f"staff[{wxid}] some users not found in agent database")
+                    # logger.debug(f"staff[{wxid}] some users not found in agent database")
                     pass
                     pass
                 agent_user_data.extend(batch_agent_user_data)
                 agent_user_data.extend(batch_agent_user_data)
             staff_user_pairs = [
             staff_user_pairs = [

+ 5 - 5
user_profile_extractor.py

@@ -9,7 +9,7 @@ import chat_service
 import configs
 import configs
 from prompt_templates import USER_PROFILE_EXTRACT_PROMPT
 from prompt_templates import USER_PROFILE_EXTRACT_PROMPT
 from openai import OpenAI
 from openai import OpenAI
-import logging
+from logging_service import logger
 
 
 import global_flags
 import global_flags
 
 
@@ -82,7 +82,7 @@ class UserProfileExtractor:
             return None
             return None
 
 
         try:
         try:
-            logging.debug("try to extract profile from message: {}".format(dialogue_history))
+            logger.debug("try to extract profile from message: {}".format(dialogue_history))
             response = self.llm_client.chat.completions.create(
             response = self.llm_client.chat.completions.create(
                 model=self.model_name,
                 model=self.model_name,
                 messages=[
                 messages=[
@@ -95,7 +95,7 @@ class UserProfileExtractor:
 
 
             # 解析Function Call的参数
             # 解析Function Call的参数
             tool_calls = response.choices[0].message.tool_calls
             tool_calls = response.choices[0].message.tool_calls
-            logging.debug(response)
+            logger.debug(response)
             if tool_calls:
             if tool_calls:
                 function_call = tool_calls[0]
                 function_call = tool_calls[0]
                 if function_call.function.name == 'update_user_profile':
                 if function_call.function.name == 'update_user_profile':
@@ -103,11 +103,11 @@ class UserProfileExtractor:
                         profile_info = json.loads(function_call.function.arguments)
                         profile_info = json.loads(function_call.function.arguments)
                         return {k: v for k, v in profile_info.items() if v}
                         return {k: v for k, v in profile_info.items() if v}
                     except json.JSONDecodeError:
                     except json.JSONDecodeError:
-                        logging.error("无法解析提取的用户信息")
+                        logger.error("无法解析提取的用户信息")
                         return None
                         return None
 
 
         except Exception as e:
         except Exception as e:
-            logging.error(f"用户画像提取出错: {e}")
+            logger.error(f"用户画像提取出错: {e}")
             return None
             return None
 
 
     def merge_profile_info(self, existing_profile: Dict, new_info: Dict) -> Dict:
     def merge_profile_info(self, existing_profile: Dict, new_info: Dict) -> Dict: