Browse Source

Update: use module-level logger

StrayWarrior 2 weeks ago
parent
commit
df268c17f4

+ 16 - 16
agent_service.py

@@ -15,6 +15,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
 import chat_service
 import configs
 import logging_service
+from logging_service import logger
 from chat_service import CozeChat, ChatServiceType
 from dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
 from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
@@ -23,7 +24,6 @@ from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, Aliyu
 from user_profile_extractor import UserProfileExtractor
 import threading
 from message import MessageType, Message, MessageChannel
-from logging_service import ColoredFormatter
 
 
 class AgentService:
@@ -93,22 +93,22 @@ class AgentService:
                     self.process_single_message(message)
                     self.receive_queue.ack(message)
                 except Exception as e:
-                    logging.error("Error processing message: {}".format(e))
+                    logger.error("Error processing message: {}".format(e))
                     traceback.print_exc()
             time.sleep(1)
 
     def _update_user_profile(self, user_id, user_profile, message: str):
         profile_to_update = self.user_profile_extractor.extract_profile_info(user_profile, message)
         if not profile_to_update:
-            logging.debug("user_id: {}, no profile info extracted".format(user_id))
+            logger.debug("user_id: {}, no profile info extracted".format(user_id))
             return
-        logging.warning("update user profile: {}".format(profile_to_update))
+        logger.warning("update user profile: {}".format(profile_to_update))
         merged_profile = self.user_profile_extractor.merge_profile_info(user_profile, profile_to_update)
         self.user_manager.save_user_profile(user_id, merged_profile)
         return merged_profile
 
     def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
-        logging.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
+        logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
         message_ts = int((time.time() + delay_sec) * 1000)
         message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
         # 系统消息使用特定的msgId,无实际意义
@@ -126,9 +126,9 @@ class AgentService:
         agent = self._get_agent_instance(staff_id, user_id)
 
         # 更新对话状态
-        logging.debug("process message: {}".format(message))
+        logger.debug("process message: {}".format(message))
         need_response, message_text = agent.update_state(message)
-        logging.debug("user: {}, next state: {}".format(user_id, agent.current_state))
+        logger.debug("user: {}, next state: {}".format(user_id, agent.current_state))
 
         # 根据状态路由消息
         if agent.is_in_human_intervention():
@@ -136,7 +136,7 @@ class AgentService:
         elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
             if message.type != MessageType.AGGREGATION_TRIGGER:
                 # 产生一个触发器,但是不能由触发器递归产生
-                logging.debug("user: {}, waiting next message for aggregation".format(user_id))
+                logger.debug("user: {}, waiting next message for aggregation".format(user_id))
                 self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
             return
         elif need_response:
@@ -144,7 +144,7 @@ class AgentService:
             self._update_user_profile(user_id, user_profile, message_text)
             self._get_chat_response(user_id, agent, message_text)
         else:
-            logging.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
+            logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
 
     def _route_to_human_intervention(self, user_id: str, origin_message: Message):
         """路由到人工干预"""
@@ -166,33 +166,33 @@ class AgentService:
             should_initiate = agent.should_initiate_conversation()
 
             if should_initiate:
-                logging.warning("user: {}, initiate conversation".format(user_id))
+                logger.warning("user: {}, initiate conversation".format(user_id))
                 self._get_chat_response(user_id, agent, None)
             else:
-                logging.debug("user: {}, do not initiate conversation".format(user_id))
+                logger.debug("user: {}, do not initiate conversation".format(user_id))
 
     def _get_chat_response(self, user_id: str, agent: DialogueManager,
                            user_message: Optional[str]):
         """处理LLM响应"""
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
-        logging.debug(chat_config)
+        logger.debug(chat_config)
         # FIXME(zhoutian): 临时处理去除头尾的空格
         chat_response = self._call_chat_api(chat_config).strip()
 
         if response := agent.generate_response(chat_response):
-            logging.warning(f"staff[{agent.staff_id}] user[{user_id}]: response: {response}")
+            logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: response: {response}")
             current_ts = int(time.time() * 1000)
             # FIXME(zhoutian)
             # 测试期间临时逻辑,只发送特定的用户
             if agent.staff_id not in set(['1688854492669990']):
-                logging.warning(f"skip message from sender [{agent.staff_id}]")
+                logger.warning(f"skip message from sender [{agent.staff_id}]")
                 return
             self.send_queue.produce(
                 Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
                               agent.staff_id, user_id, response, current_ts)
             )
         else:
-            logging.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
+            logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
 
     def _call_chat_api(self, chat_config: Dict) -> str:
         if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
@@ -216,7 +216,7 @@ class AgentService:
 if __name__ == "__main__":
     config = configs.get()
     logging_service.setup_root_logger()
-    logging.warning("current env: {}".format(configs.get_env()))
+    logger.warning("current env: {}".format(configs.get_env()))
     scheduler_logger = logging.getLogger('apscheduler')
     scheduler_logger.setLevel(logging.WARNING)
 

+ 3 - 3
chat_service.py

@@ -7,7 +7,7 @@ import os
 import threading
 from typing import List, Dict, Optional
 from enum import Enum, auto
-import logging
+from logging_service import logger
 import cozepy
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 import time
@@ -57,9 +57,9 @@ class CozeChat:
         response = self.coze.chat.create_and_poll(
             bot_id=bot_id, user_id=user_id, additional_messages=messages,
             custom_variables=custom_variables)
-        logging.debug("Coze response size: {}".format(len(response.messages)))
+        logger.debug("Coze response size: {}".format(len(response.messages)))
         if response.chat.status != ChatStatus.COMPLETED:
-            logging.error("Coze chat not completed: {}".format(response.chat.status))
+            logger.error("Coze chat not completed: {}".format(response.chat.status))
             return None
         final_response = None
         for message in response.messages:

+ 1 - 1
configs/dev.yaml

@@ -42,7 +42,7 @@ chat_api:
     account_id: 649175100044793
 
 debug_flags:
-  disable_llm_api_call: False
+  disable_llm_api_call: True
   use_local_user_storage: False
   console_input: True
   disable_active_conversation: False

+ 2 - 2
database.py

@@ -5,7 +5,7 @@
 # Copyright © 2024 StrayWarrior <i@straywarrior.com>
 
 import pymysql
-import logging
+from logging_service import logger
 
 class MySQLManager:
     def __init__(self, config):
@@ -64,7 +64,7 @@ class MySQLManager:
                 conn.commit()
                 return rows
         except pymysql.MySQLError as e:
-            logging.error(f"Error in batch_insert: {e}")
+            logger.error(f"Error in batch_insert: {e}")
             conn.rollback()
             raise e
         conn.close()

+ 8 - 12
dialogue_manager.py

@@ -6,7 +6,7 @@ from enum import Enum, auto
 from typing import Dict, List, Optional, Tuple, Any
 from datetime import datetime
 import time
-import logging
+from logging_service import logger
 
 import pymysql.cursors
 
@@ -21,10 +21,6 @@ from message import MessageType, Message
 from user_manager import UserManager
 from prompt_templates import *
 
-# 配置日志
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
 class DummyVectorMemoryManager:
     def __init__(self, user_id):
         pass
@@ -68,7 +64,7 @@ class DialogueStateCache:
         query = f"SELECT current_state, previous_state FROM {self.table} WHERE staff_id=%s AND user_id=%s"
         data = self.db.select(query, pymysql.cursors.DictCursor, (staff_id, user_id))
         if not data:
-            logging.warning(f"staff[{staff_id}], user[{user_id}]: agent state not found")
+            logger.warning(f"staff[{staff_id}], user[{user_id}]: agent state not found")
             state = DialogueState.CHITCHAT
             previous_state = DialogueState.INITIALIZED
             self.set_state(staff_id, user_id, state, previous_state)
@@ -82,7 +78,7 @@ class DialogueStateCache:
                 f" VALUES (%s, %s, %s, %s) " \
                 f"ON DUPLICATE KEY UPDATE current_state=%s, previous_state=%s"
         rows = self.db.execute(query, (staff_id, user_id, state.value, previous_state.value, state.value, previous_state.value))
-        logging.debug("staff[{}], user[{}]: set state: {}, previous state: {}, rows affected: {}"
+        logger.debug("staff[{}], user[{}]: set state: {}, previous state: {}, rows affected: {}"
                       .format(staff_id, user_id, state, previous_state, rows))
 
 class DialogueManager:
@@ -122,7 +118,7 @@ class DialogueManager:
             # 默认设置为24小时前
             self.last_interaction_time = int(time.time() * 1000) - 24 * 3600 * 1000
         time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
-        logging.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
+        logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
 
     def persist_state(self):
         """持久化对话状态"""
@@ -165,7 +161,7 @@ class DialogueManager:
             # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,恢复之前状态,继续处理
             if message.type == MessageType.AGGREGATION_TRIGGER \
                     and message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
-                logging.debug("user_id: {}, last interaction time: {}".format(
+                logger.debug("user_id: {}, last interaction time: {}".format(
                     self.user_id, datetime.fromtimestamp(self.last_interaction_time / 1000)))
                 self.current_state = self.previous_state
             else:
@@ -459,7 +455,7 @@ class DialogueManager:
             消息列表
         """
         dialogue_history = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id)
-        logging.debug("staff[{}], user[{}], dialogue_history: {}".format(
+        logger.debug("staff[{}], user[{}], dialogue_history: {}".format(
             self.staff_id, self.user_id, dialogue_history
         ))
         messages = []
@@ -478,7 +474,7 @@ class DialogueManager:
         elif chat_service_type == ChatServiceType.COZE_CHAT:
             for entry in dialogue_history:
                 if not entry['content']:
-                    logging.warning("staff[{}], user[{}], role[{}]: empty content in dialogue history".format(
+                    logger.warning("staff[{}], user[{}], role[{}]: empty content in dialogue history".format(
                         self.staff_id, self.user_id, entry['role']
                     ))
                     continue
@@ -498,7 +494,7 @@ class DialogueManager:
                 messages.append(cozepy.Message.build_user_question_text('请开始对话'))
             #FIXME(zhoutian): 临时报警
             if user_message and not messages:
-                logging.error(f"staff[{self.staff_id}], user[{self.user_id}]: inconsistency in messages")
+                logger.error(f"staff[{self.staff_id}], user[{self.user_id}]: inconsistency in messages")
         config['messages'] = messages
 
         return config

+ 2 - 3
history_dialogue_service.py

@@ -3,8 +3,7 @@
 # vim:fenc=utf-8
 
 import requests
-import logging
-import json
+from logging_service import logger
 
 import configs
 
@@ -32,7 +31,7 @@ class HistoryDialogueService:
             elif sender == staff_id:
                 role = 'assistant'
             else:
-                logging.warning("Unknown sender in dialogue history: {}".format(sender))
+                logger.warning("Unknown sender in dialogue history: {}".format(sender))
                 continue
             ret.append({
                 'role': role,

+ 8 - 4
logging_service.py

@@ -22,14 +22,18 @@ class ColoredFormatter(logging.Formatter):
             message = f"{COLORS[record.levelname]}{message}{COLORS['RESET']}"
         return message
 
-def setup_root_logger():
-    logging.getLogger().setLevel(logging.DEBUG)
+def setup_root_logger(level=logging.DEBUG):
     console_handler = logging.StreamHandler()
     console_handler.setLevel(logging.DEBUG)
     formatter = ColoredFormatter(
-        '%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
+        '%(asctime)s - %(name)s %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
     )
     console_handler.setFormatter(formatter)
     root_logger = logging.getLogger()
     root_logger.handlers.clear()
-    root_logger.addHandler(console_handler)
+    root_logger.addHandler(console_handler)
+
+    agent_logger = logging.getLogger('agent')
+    agent_logger.setLevel(level)
+
+logger = logging.getLogger('agent')

+ 4 - 4
message_queue_backend.py

@@ -4,7 +4,7 @@
 
 import abc
 import time
-import logging
+from logging_service import logger
 from typing import Dict, Any, Optional
 import configs
 
@@ -84,12 +84,12 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
             return None
         rmq_message = messages[0]
         body = rmq_message.body.decode('utf-8')
-        logging.debug("recv message body: {}".format(body))
+        logger.debug("recv message body: {}".format(body))
         try:
             message = Message.from_json(body)
             message._rmq_message = rmq_message
         except Exception as e:
-            logging.error("Invalid message: {}. Parsing error: {}".format(body, e))
+            logger.error("Invalid message: {}. Parsing error: {}".format(body, e))
             # 如果消息非法,直接ACK,避免死信
             self.consumer.ack(rmq_message)
             return None
@@ -98,7 +98,7 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
     def ack(self, message: Message):
         if not message._rmq_message:
             raise ValueError("Message not set with _rmq_message.")
-        logging.debug("ack message: {}".format(message))
+        logger.debug("ack message: {}".format(message))
         self.consumer.ack(message._rmq_message)
 
     def produce(self, message: Message) -> None:

+ 9 - 8
user_manager.py

@@ -1,7 +1,8 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
-import logging
+
+from logging_service import logger
 from typing import Dict, Optional, Tuple, Any, List
 import json
 import time
@@ -114,11 +115,11 @@ class MySQLUserManager(UserManager):
         sql = f"SELECT name, wxid, profile_data_v1 FROM {self.table_name} WHERE third_party_user_id = {user_id}"
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         if not data:
-            logging.error(f"user[{user_id}] not found")
+            logger.error(f"user[{user_id}] not found")
             return {}
         data = data[0]
         if not data['profile_data_v1']:
-            logging.warning(f"user[{user_id}] profile not found, create a default one")
+            logger.warning(f"user[{user_id}] profile not found, create a default one")
             default_profile = self.get_default_profile(nickname=data['name'])
             self.save_user_profile(user_id, default_profile)
             return default_profile
@@ -143,7 +144,7 @@ class MySQLUserManager(UserManager):
               f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         if not data:
-            logging.error(f"staff[{staff_id}] not found")
+            logger.error(f"staff[{staff_id}] not found")
             return {}
         profile = data[0]
         return profile
@@ -180,19 +181,19 @@ class MySQLUserRelationManager(UserRelationManager):
             sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = '{wxid}'"
             staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             if not staff_data:
-                logging.error(f"staff[{wxid}] not found in wecom database")
+                logger.error(f"staff[{wxid}] not found in wecom database")
                 continue
             staff_id = staff_data[0]['id']
             sql = f"SELECT user_id FROM {self.relation_table} WHERE staff_id = '{staff_id}' AND is_delete = 0"
             user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             if not user_data:
-                logging.warning(f"staff[{wxid}] has no user")
+                logger.warning(f"staff[{wxid}] has no user")
                 continue
             user_ids = tuple(user['user_id'] for user in user_data)
             sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)} AND union_id is not null"
             user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             if not user_data:
-                logging.warning(f"staff[{wxid}] users not found in wecom database")
+                logger.warning(f"staff[{wxid}] users not found in wecom database")
                 continue
             user_union_ids = tuple(user['union_id'] for user in user_data)
             batch_size = 100
@@ -205,7 +206,7 @@ class MySQLUserRelationManager(UserRelationManager):
                 sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
                 batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
                 if len(agent_user_data) != len(batch_union_ids):
-                    # logging.debug(f"staff[{wxid}] some users not found in agent database")
+                    # logger.debug(f"staff[{wxid}] some users not found in agent database")
                     pass
                 agent_user_data.extend(batch_agent_user_data)
             staff_user_pairs = [

+ 5 - 5
user_profile_extractor.py

@@ -9,7 +9,7 @@ import chat_service
 import configs
 from prompt_templates import USER_PROFILE_EXTRACT_PROMPT
 from openai import OpenAI
-import logging
+from logging_service import logger
 
 import global_flags
 
@@ -82,7 +82,7 @@ class UserProfileExtractor:
             return None
 
         try:
-            logging.debug("try to extract profile from message: {}".format(dialogue_history))
+            logger.debug("try to extract profile from message: {}".format(dialogue_history))
             response = self.llm_client.chat.completions.create(
                 model=self.model_name,
                 messages=[
@@ -95,7 +95,7 @@ class UserProfileExtractor:
 
             # 解析Function Call的参数
             tool_calls = response.choices[0].message.tool_calls
-            logging.debug(response)
+            logger.debug(response)
             if tool_calls:
                 function_call = tool_calls[0]
                 if function_call.function.name == 'update_user_profile':
@@ -103,11 +103,11 @@ class UserProfileExtractor:
                         profile_info = json.loads(function_call.function.arguments)
                         return {k: v for k, v in profile_info.items() if v}
                     except json.JSONDecodeError:
-                        logging.error("无法解析提取的用户信息")
+                        logger.error("无法解析提取的用户信息")
                         return None
 
         except Exception as e:
-            logging.error(f"用户画像提取出错: {e}")
+            logger.error(f"用户画像提取出错: {e}")
             return None
 
     def merge_profile_info(self, existing_profile: Dict, new_info: Dict) -> Dict: