Browse Source

Merge branch 'master' into dev-xym-add-test-task

xueyiming 2 days ago
parent
commit
07198b8cd8

+ 1 - 0
pqai_agent/agents/simple_chat_agent.py

@@ -23,6 +23,7 @@ class SimpleOpenAICompatibleChatAgent:
         self.generate_cfg = generate_cfg or {}
         self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
         self.tool_call_records = []
+        logger.debug(self.tool_map)
 
     def add_tool(self, tool: FunctionTool):
         """添加一个工具到Agent中"""

+ 6 - 3
pqai_agent/dialogue_manager.py

@@ -153,7 +153,9 @@ class DialogueManager:
             return TimeContext.NIGHT
 
     def is_valid(self):
-        if not self.staff_profile.get('name', None) and not self.staff_profile.get('agent_name', None):
+        if not self.staff_profile.get('name', None) \
+                and not self.staff_profile.get('agent_name', None) \
+                and not self.staff_profile.get('基础信息', {}).get('昵称', None):
             return False
         return True
 
@@ -358,7 +360,8 @@ class DialogueManager:
 
     def _send_alert(self, alert_type: str, reason: Optional[str] = None) -> None:
         time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        staff_info = f"{self.staff_profile.get('name', '未知')}[{self.staff_id}]"
+        name = self.staff_profile.get('name', None) or self.staff_profile.get('基础信息', {}).get('昵称', None)
+        staff_info = f"{name}[{self.staff_id}]"
         user_info = f"{self.user_profile.get('nickname', '未知')}[{self.user_id}]"
         alert_message = f"""
         {alert_type}告警
@@ -567,7 +570,7 @@ class DialogueManager:
             "if_first_interaction": True if self.previous_state == DialogueState.INITIALIZED else False,
             "if_active_greeting": False if user_message else True,
             "relation_stage": self.relation_stage,
-            "formatted_staff_profile": prompt_utils.format_agent_profile(self.staff_profile),
+            "formatted_staff_profile": prompt_utils.format_agent_profile_v2(self.staff_profile),
             "formatted_user_profile": prompt_utils.format_user_profile(self.user_profile),
             **self.user_profile,
             **legacy_staff_profile

+ 8 - 0
pqai_agent/push_service.py

@@ -81,6 +81,7 @@ class PushTaskWorkerPool:
                  mq_consumer: rocketmq.SimpleConsumer, mq_producer: rocketmq.Producer):
         self.agent_service = agent_service
         max_workers = configs.get()['system'].get('push_task_workers', 5)
+        self.max_push_workers = max_workers
         self.generate_executor = ThreadPoolExecutor(max_workers=max_workers)
         self.send_executors = {}
         self.rmq_topic = mq_topic
@@ -120,6 +121,13 @@ class PushTaskWorkerPool:
             msg_time = datetime.fromtimestamp(task['timestamp'] / 1000).strftime("%Y-%m-%d %H:%M:%S")
             logger.debug(f"recv message:{msg_time} - {task}")
             if task['task_type'] == TaskType.GENERATE.value:
+                # FIXME: 临时方案,避免消息在消费后等待超时并重复消费
+                if self.generate_executor._work_queue.qsize() > self.max_push_workers * 5:
+                    logger.warning("Too many generate tasks in queue, consume this task later")
+                    while self.generate_executor._work_queue.qsize() > self.max_push_workers * 5:
+                        time.sleep(10)
+                    # do not submit and ack this message
+                    continue
                 self.generate_executor.submit(self.handle_generate_task, task, msg)
             elif task['task_type'] == TaskType.SEND.value:
                 staff_id = task['staff_id']

+ 1 - 1
pqai_agent/user_manager.py

@@ -228,7 +228,7 @@ class MySQLUserManager(UserManager):
         return profile
 
     def get_staff_profile_v3(self, staff_id) -> Dict:
-        sql = f"SELECT agent_profile " \
+        sql = f"SELECT agent_profile_v2 " \
               f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
         data = self.db.select(sql)
         if not data:

+ 12 - 3
pqai_agent/utils/prompt_utils.py

@@ -1,5 +1,9 @@
+import json
+from io import StringIO
 from typing import Dict
 
+import yaml
+
 
 def format_agent_profile(profile: Dict) -> str:
     fields = [
@@ -16,12 +20,17 @@ def format_agent_profile(profile: Dict) -> str:
     ]
     strings_to_join = []
     for field in fields:
-        if not profile.get(field[0], None):
+        if profile.get(field[0], None) is None:
             continue
         cur_string = f"- {field[1]}:{profile[field[0]]}"
         strings_to_join.append(cur_string)
     return "\n".join(strings_to_join)
 
+def format_agent_profile_v2(profile: Dict) -> str:
+    str_stream = StringIO()
+    yaml.dump(profile, str_stream, indent=2, allow_unicode=True)
+    return str_stream.getvalue()
+
 def format_user_profile(profile: Dict) -> str:
     """
     :param profile:
@@ -54,8 +63,8 @@ def format_user_profile(profile: Dict) -> str:
     for field in fields:
         value = profile.get(field[0], None)
         if not value:
-            continue
-        if isinstance(value, list):
+            value = '未知'
+        elif isinstance(value, list):
             value = ','.join(value)
         elif isinstance(value, dict):
             value = ';'.join(f"{k}: {v}" for k, v in value.items())