
Merge branch 'master' into dev-xym-add-test-task

# Conflicts:
#	pqai_agent_server/api_server.py
xueyiming 8 hours ago
parent
commit
b4b7f6f14d
40 files changed, 582 additions and 210 deletions
  1. pqai_agent/abtest/client.py (+2, -2)
  2. pqai_agent/abtest/models.py (+1, -1)
  3. pqai_agent/agent_config_manager.py (+1, -1)
  4. pqai_agent/agent_service.py (+31, -12)
  5. pqai_agent/agents/message_push_agent.py (+1, -1)
  6. pqai_agent/agents/message_reply_agent.py (+1, -1)
  7. pqai_agent/agents/multimodal_chat_agent.py (+6, -2)
  8. pqai_agent/agents/simple_chat_agent.py (+22, -5)
  9. pqai_agent/chat_service.py (+96, -24)
  10. pqai_agent/clients/relation_stage_client.py (+1, -1)
  11. pqai_agent/configs/dev.yaml (+11, -6)
  12. pqai_agent/configs/prod.yaml (+7, -4)
  13. pqai_agent/data_models/agent_configuration.py (+2, -2)
  14. pqai_agent/data_models/agent_push_record.py (+2, -2)
  15. pqai_agent/data_models/service_module.py (+1, -1)
  16. pqai_agent/database.py (+1, -1)
  17. pqai_agent/dialogue_manager.py (+1, -1)
  18. pqai_agent/history_dialogue_service.py (+1, -1)
  19. pqai_agent/logging.py (+9, -7)
  20. pqai_agent/message_queue_backend.py (+2, -3)
  21. pqai_agent/prompt_templates.py (+1, -0)
  22. pqai_agent/push_service.py (+64, -45)
  23. pqai_agent/rate_limiter.py (+1, -1)
  24. pqai_agent/response_type_detector.py (+2, -2)
  25. pqai_agent/service_module_manager.py (+1, -1)
  26. pqai_agent/toolkit/__init__.py (+1, -1)
  27. pqai_agent/toolkit/function_tool.py (+1, -1)
  28. pqai_agent/toolkit/image_describer.py (+1, -1)
  29. pqai_agent/toolkit/lark_sheet_record_for_human_intervention.py (+1, -1)
  30. pqai_agent/toolkit/message_notifier.py (+1, -1)
  31. pqai_agent/user_manager.py (+51, -3)
  32. pqai_agent/user_profile_extractor.py (+3, -3)
  33. pqai_agent/utils/__init__.py (+8, -0)
  34. pqai_agent_server/agent_server.py (+7, -4)
  35. pqai_agent_server/api_server.py (+140, -57)
  36. pqai_agent_server/utils/__init__.py (+0, -1)
  37. pqai_agent_server/utils/common.py (+14, -3)
  38. pqai_agent_server/utils/prompt_util.py (+3, -4)
  39. scripts/extract_push_action_logs.py (+43, -0)
  40. tests/unit_test.py (+40, -3)

+ 2 - 2
pqai_agent/abtest/client.py

@@ -6,7 +6,7 @@ from pqai_agent.abtest.models import Project, Domain, Layer, Experiment, Experim
     ExperimentContext, ExperimentResult
 from alibabacloud_paiabtest20240119.models import ListProjectsRequest, ListProjectsResponseBodyProjects, \
     ListDomainsRequest, ListFeaturesRequest, ListLayersRequest, ListExperimentsRequest, ListExperimentVersionsRequest
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 class ExperimentClient:
     def __init__(self, client: Client):
@@ -267,7 +267,7 @@ def get_client():
     return g_client
 
 if __name__ == '__main__':
-    from pqai_agent.logging_service import setup_root_logger
+    from pqai_agent.logging import setup_root_logger
     setup_root_logger(level='DEBUG')
     experiment_client = get_client()
 

+ 1 - 1
pqai_agent/abtest/models.py

@@ -3,7 +3,7 @@ import json
 from dataclasses import dataclass, field
 import hashlib
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 
 class FNV:

+ 1 - 1
pqai_agent/agent_config_manager.py

@@ -1,7 +1,7 @@
 from typing import Dict, Optional
 
 from pqai_agent.data_models.agent_configuration import AgentConfiguration
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 class AgentConfigManager:
     def __init__(self, session_maker):

+ 31 - 12
pqai_agent/agent_service.py

@@ -1,7 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
-
+import json
 import re
 import signal
 import sys
@@ -15,15 +15,16 @@ import traceback
 import apscheduler.triggers.cron
 import rocketmq
 from apscheduler.schedulers.background import BackgroundScheduler
+from rocketmq import FilterExpression
 from sqlalchemy.orm import sessionmaker
 
-from pqai_agent import configs
+from pqai_agent import configs, push_service
 from pqai_agent.abtest.utils import get_abtest_info
 from pqai_agent.agent_config_manager import AgentConfigManager
 from pqai_agent.agents.message_reply_agent import MessageReplyAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.exceptions import NoRetryException
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent import chat_service
 from pqai_agent.chat_service import CozeChat, ChatServiceType
 from pqai_agent.dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
@@ -96,7 +97,8 @@ class AgentService:
 
         # Push相关
         self.push_task_producer = None
-        self.push_task_consumer = None
+        self.push_generate_task_consumer = None
+        self.push_send_task_consumer = None
         self._init_push_task_queue()
         self.next_push_disabled = True
         self._resume_unfinished_push_task()
@@ -344,7 +346,7 @@ class AgentService:
             logger.debug(f"staff[{staff_id}], user[{user_id}]: no messages to send")
 
     def can_send_to_user(self, staff_id, user_id) -> bool:
-        user_tags = self.user_relation_manager.get_user_tags(user_id)
+        user_tags = self.user_manager.get_user_tags([user_id]).get(user_id, [])
         white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags", []))
         hit_white_list_tags = len(set(user_tags).intersection(white_list_tags)) > 0
         staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs", []))
@@ -384,20 +386,31 @@ class AgentService:
         mq_conf = configs.get()['mq']
         rmq_client_conf = rocketmq.ClientConfiguration(mq_conf['endpoints'], credentials, mq_conf['instance_id'])
         rmq_topic = mq_conf['push_tasks_topic']
-        rmq_group = mq_conf['push_tasks_group']
+        rmq_group_generate = mq_conf['push_generate_task_group']
+        rmq_group_send = mq_conf['push_send_task_group']
         self.push_task_rmq_topic = rmq_topic
         self.push_task_producer = rocketmq.Producer(rmq_client_conf, (rmq_topic,))
         self.push_task_producer.startup()
-        self.push_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group, await_duration=5)
-        self.push_task_consumer.startup()
-        self.push_task_consumer.subscribe(rmq_topic)
+        # FIXME: 不应该暴露到agent service中
+        self.push_generate_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group_generate, await_duration=5)
+        self.push_generate_task_consumer.startup()
+        self.push_generate_task_consumer.subscribe(
+            rmq_topic, filter_expression=FilterExpression(push_service.TaskType.GENERATE.value)
+        )
+        self.push_send_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group_send, await_duration=5)
+        self.push_send_task_consumer.startup()
+        self.push_send_task_consumer.subscribe(
+            rmq_topic, filter_expression=FilterExpression(push_service.TaskType.SEND.value)
+        )
 
 
     def _resume_unfinished_push_task(self):
         def run_unfinished_push_task():
             logger.info("start to resume unfinished push task")
             push_task_worker_pool = PushTaskWorkerPool(
-                self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
+                self, self.push_task_rmq_topic, self.push_generate_task_consumer,
+                self.push_send_task_consumer, self.push_task_producer
+            )
             push_task_worker_pool.start()
             push_task_worker_pool.wait_to_finish()
             self.next_push_disabled = False
@@ -427,7 +440,8 @@ class AgentService:
             push_scan_threads.append(scan_thread)
 
         push_task_worker_pool = PushTaskWorkerPool(
-            self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
+            self, self.push_task_rmq_topic,
+            self.push_generate_task_consumer, self.push_send_task_consumer, self.push_task_producer)
         push_task_worker_pool.start()
         for thread in push_scan_threads:
             thread.join()
@@ -461,9 +475,14 @@ class AgentService:
         agent_config = get_agent_abtest_config('chat', main_agent.user_id,
                                                self.service_module_manager, self.agent_config_manager)
         if agent_config:
+            try:
+                tool_names = json.loads(agent_config.tools)
+            except json.JSONDecodeError:
+                logger.error(f"Invalid JSON in agent tools: {agent_config.tools}")
+                tool_names = []
             chat_agent = MessageReplyAgent(model=agent_config.execution_model,
                                            system_prompt=agent_config.system_prompt,
-                                           tools=get_tools(agent_config.tools))
+                                           tools=get_tools(tool_names))
         else:
             chat_agent = MessageReplyAgent()
         chat_responses = chat_agent.generate_message(
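
Note on the agent_service.py change above: agent_config.tools is now treated as a JSON-encoded array of tool names, decoded with json.loads before being handed to get_tools, with a fallback to an empty list on malformed JSON. A minimal sketch of that parsing rule (the helper function and the example strings are hypothetical illustrations, not values from this commit):

    import json

    def parse_tool_names(tools_field: str) -> list:
        # Same fallback as the new try/except in agent_service.py: bad JSON means no extra tools.
        try:
            return json.loads(tools_field)
        except json.JSONDecodeError:
            return []

    print(parse_tool_names('["image_describer", "message_notifier"]'))  # ['image_describer', 'message_notifier']
    print(parse_tool_names('not json'))                                  # []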

+ 1 - 1
pqai_agent/agents/message_push_agent.py

@@ -2,7 +2,7 @@ from typing import Optional, List, Dict
 
 from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier

+ 1 - 1
pqai_agent/agents/message_reply_agent.py

@@ -2,7 +2,7 @@ from typing import Optional, List, Dict
 
 from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier

+ 6 - 2
pqai_agent/agents/multimodal_chat_agent.py

@@ -2,8 +2,9 @@ import datetime
 from abc import abstractmethod
 from typing import Optional, List, Dict
 
+from pqai_agent import configs
 from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit import get_tool
 from pqai_agent.toolkit.function_tool import FunctionTool
@@ -28,7 +29,10 @@ class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
         pass
 
     def _generate_message(self, context: Dict, dialogue_history: List[Dict],
-                         query_prompt_template: str) -> List[Dict]:
+                          query_prompt_template: str) -> List[Dict]:
+        if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
+            return [{'type': 'text', 'content': '测试消息 -> {nickname}'.format(**context)}]
+
         formatted_dialogue = MultiModalChatAgent.compose_dialogue(dialogue_history)
         query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
         self.run(query)

+ 22 - 5
pqai_agent/agents/simple_chat_agent.py

@@ -1,9 +1,10 @@
 import json
 from typing import List, Optional
 
+import pqai_agent.utils
 from pqai_agent.agent import DEFAULT_MAX_RUN_STEPS
 from pqai_agent.chat_service import OpenAICompatible
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.function_tool import FunctionTool
 
 
@@ -23,6 +24,8 @@ class SimpleOpenAICompatibleChatAgent:
         self.generate_cfg = generate_cfg or {}
         self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
         self.tool_call_records = []
+        self.total_input_tokens = 0
+        self.total_output_tokens = 0
         logger.debug(self.tool_map)
 
     def add_tool(self, tool: FunctionTool):
@@ -33,23 +36,26 @@ class SimpleOpenAICompatibleChatAgent:
         self.tool_map[tool.name] = tool
 
     def run(self, user_input: str) -> str:
+        run_id = pqai_agent.utils.random_str()[:12]
         messages = [{"role": "system", "content": self.system_prompt}]
         tools = [tool.get_openai_tool_schema() for tool in self.tools]
         messages.append({"role": "user", "content": user_input})
 
         n_steps = 0
-        logger.debug(f"start agent loop. messages: {messages}")
+        logger.debug(f"run_id[{run_id}] start agent loop. messages: {messages}")
         while n_steps < self.max_run_step:
             response = self.llm_client.chat.completions.create(model=self.model, messages=messages, tools=tools, **self.generate_cfg)
             message = response.choices[0].message
+            self.total_input_tokens += response.usage.prompt_tokens
+            self.total_output_tokens += response.usage.completion_tokens
             messages.append(message)
-            logger.debug(f"current step content: {message.content}")
+            logger.debug(f"run_id[{run_id}] current step content: {message.content}")
 
             if message.tool_calls:
                 for tool_call in message.tool_calls:
                     function_name = tool_call.function.name
                     arguments = json.loads(tool_call.function.arguments)
-                    logger.debug(f"call function[{function_name}], parameter: {arguments}")
+                    logger.debug(f"run_id[{run_id}] call function[{function_name}], parameter: {arguments}")
 
                     if function_name in self.tool_map:
                         result = self.tool_map[function_name](**arguments)
@@ -64,10 +70,21 @@ class SimpleOpenAICompatibleChatAgent:
                             "result": result
                         })
                     else:
-                        logger.error(f"Function {function_name} not found in tool map.")
+                        logger.error(f"run_id[{run_id}] Function {function_name} not found in tool map.")
                         raise Exception(f"Function {function_name} not found in tool map.")
             else:
                 return message.content
             n_steps += 1
 
         raise Exception("Max run steps exceeded")
+
+    def get_total_input_tokens(self) -> int:
+        """获取总输入token数"""
+        return self.total_input_tokens
+
+    def get_total_output_tokens(self) -> int:
+        """获取总输出token数"""
+        return self.total_output_tokens
+
+    def get_total_cost(self) -> float:
+        return OpenAICompatible.calculate_cost(self.model, self.total_input_tokens, self.total_output_tokens)
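
The simple_chat_agent.py hunks above add per-run token accounting: run() now accumulates response.usage.prompt_tokens and completion_tokens, and the new getters expose the totals plus a cost estimate via OpenAICompatible.calculate_cost. A hedged usage sketch follows; the constructor arguments are illustrative assumptions, since the full __init__ signature is not shown in this diff:

    from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
    from pqai_agent.chat_service import VOLCENGINE_MODEL_DOUBAO_PRO_32K

    # Hypothetical construction; real call sites pass model/system_prompt/tools from agent config.
    agent = SimpleOpenAICompatibleChatAgent(model=VOLCENGINE_MODEL_DOUBAO_PRO_32K,
                                            system_prompt="你是一个助手", tools=[])
    reply = agent.run("你好")
    print(agent.get_total_input_tokens(), agent.get_total_output_tokens())
    print(f"estimated cost: {agent.get_total_cost():.4f} CNY")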

+ 96 - 24
pqai_agent/chat_service.py

@@ -11,7 +11,7 @@ from enum import Enum, auto
 import httpx
 
 from pqai_agent import configs
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 import cozepy
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 import time
@@ -22,9 +22,9 @@ COZE_CN_BASE_URL = 'https://api.coze.cn'
 VOLCENGINE_API_TOKEN = '5e275c38-44fd-415f-abcf-4b59f6377f72'
 VOLCENGINE_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
 VOLCENGINE_MODEL_DEEPSEEK_V3 = "deepseek-v3-250324"
-VOLCENGINE_MODEL_DOUBAO_PRO_1_5 = 'ep-20250307150409-4blz9'
-VOLCENGINE_MODEL_DOUBAO_PRO_32K = 'ep-20250414202859-6nkz5'
-VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO = 'ep-20250421193334-nz5wd'
+VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K = 'doubao-1-5-pro-32k-250115'
+VOLCENGINE_MODEL_DOUBAO_PRO_32K = 'doubao-pro-32k-241215'
+VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO = 'doubao-1-5-vision-pro-32k-250115'
 DEEPSEEK_API_TOKEN = 'sk-67daad8f424f4854bda7f1fed7ef220b'
 DEEPSEEK_BASE_URL = 'https://api.deepseek.com/'
 DEEPSEEK_CHAT_MODEL = 'deepseek-chat'
@@ -37,35 +37,81 @@ OPENAI_MODEL_GPT_4o_mini = 'gpt-4o-mini'
 OPENROUTER_API_TOKEN = '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'
 OPENROUTER_BASE_URL = 'https://openrouter.ai/api/v1/'
 OPENROUTER_MODEL_CLAUDE_3_7_SONNET = 'anthropic/claude-3.7-sonnet'
+ALIYUN_API_TOKEN = 'sk-47381479425f4485af7673d3d2fd92b6'
+ALIYUN_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
+
 
 class ChatServiceType(Enum):
     OPENAI_COMPATIBLE = auto()
     COZE_CHAT = auto()
 
+class ModelPrice:
+    EXCHANGE_RATE_TO_CNY = {
+        "USD": 7.2,  # Example conversion rate, adjust as needed
+    }
+
+    def __init__(self, input_price: float, output_price: float, currency: str = 'CNY'):
+        """
+        :param input_price: input price for per million tokens
+        :param output_price: output price for per million tokens
+        """
+        self.input_price = input_price
+        self.output_price = output_price
+        self.currency = currency
+
+    def get_total_cost(self, input_tokens: int, output_tokens: int, convert_to_cny: bool = True) -> float:
+        """
+        Calculate the total cost based on input and output tokens.
+        :param input_tokens: Number of input tokens
+        :param output_tokens: Number of output tokens
+        :param convert_to_cny: Whether to convert the cost to CNY (default is True)
+        :return: Total cost in the specified currency
+        """
+        total_cost = (self.input_price * input_tokens / 1_000_000) + (self.output_price * output_tokens / 1_000_000)
+        if convert_to_cny and self.currency != 'CNY':
+            conversion_rate = self.EXCHANGE_RATE_TO_CNY.get(self.currency, 1.0)
+            total_cost *= conversion_rate
+        return total_cost
+
+    def __repr__(self):
+        return f"ModelPrice(input_price={self.input_price}, output_price={self.output_price}, currency={self.currency})"
+
 class OpenAICompatible:
+    volcengine_models = [
+        VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+        VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
+        VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
+        VOLCENGINE_MODEL_DEEPSEEK_V3
+    ]
+    deepseek_models = [
+        DEEPSEEK_CHAT_MODEL,
+    ]
+    openai_models = [
+        OPENAI_MODEL_GPT_4o_mini,
+        OPENAI_MODEL_GPT_4o
+    ]
+    openrouter_models = [
+        OPENROUTER_MODEL_CLAUDE_3_7_SONNET,
+    ]
+
+    model_prices = {
+        VOLCENGINE_MODEL_DEEPSEEK_V3: ModelPrice(input_price=2, output_price=8),
+        VOLCENGINE_MODEL_DOUBAO_PRO_32K: ModelPrice(input_price=0.8, output_price=2),
+        VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K: ModelPrice(input_price=0.8, output_price=2),
+        VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO: ModelPrice(input_price=3, output_price=9),
+        DEEPSEEK_CHAT_MODEL: ModelPrice(input_price=2, output_price=8),
+        OPENAI_MODEL_GPT_4o: ModelPrice(input_price=2.5, output_price=10, currency='USD'),
+        OPENAI_MODEL_GPT_4o_mini: ModelPrice(input_price=0.15, output_price=0.6, currency='USD'),
+        OPENROUTER_MODEL_CLAUDE_3_7_SONNET: ModelPrice(input_price=3, output_price=15, currency='USD'),
+    }
+
     @staticmethod
     def create_client(model_name, **kwargs) -> OpenAI:
-        volcengine_models = [
-            VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-            VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
-            VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
-            VOLCENGINE_MODEL_DEEPSEEK_V3
-        ]
-        deepseek_models = [
-            DEEPSEEK_CHAT_MODEL,
-        ]
-        openai_models = [
-            OPENAI_MODEL_GPT_4o_mini,
-            OPENAI_MODEL_GPT_4o
-        ]
-        openrouter_models = [
-            OPENROUTER_MODEL_CLAUDE_3_7_SONNET,
-        ]
-        if model_name in volcengine_models:
+        if model_name in OpenAICompatible.volcengine_models:
             llm_client = OpenAI(api_key=VOLCENGINE_API_TOKEN, base_url=VOLCENGINE_BASE_URL, **kwargs)
-        elif model_name in deepseek_models:
+        elif model_name in OpenAICompatible.deepseek_models:
             llm_client = OpenAI(api_key=DEEPSEEK_API_TOKEN, base_url=DEEPSEEK_BASE_URL, **kwargs)
-        elif model_name in openai_models:
+        elif model_name in OpenAICompatible.openai_models:
             socks_conf = configs.get().get('system', {}).get('outside_proxy', {}).get('socks5', {})
             if socks_conf:
                 http_client = httpx.Client(
@@ -74,12 +120,38 @@ class OpenAICompatible:
                 )
                 kwargs['http_client'] = http_client
             llm_client = OpenAI(api_key=OPENAI_API_TOKEN, base_url=OPENAI_BASE_URL, **kwargs)
-        elif model_name in openrouter_models:
+        elif model_name in OpenAICompatible.openrouter_models:
             llm_client = OpenAI(api_key=OPENROUTER_API_TOKEN, base_url=OPENROUTER_BASE_URL, **kwargs)
         else:
             raise Exception("Unsupported model: %s" % model_name)
         return llm_client
 
+    @staticmethod
+    def get_price(model_name: str) -> ModelPrice:
+        """
+        Get the price for a given model.
+        :param model_name: Name of the model
+        :return: ModelPrice object containing input and output prices
+        """
+        if model_name not in OpenAICompatible.model_prices:
+            raise ValueError(f"Model {model_name} not found in price list.")
+        return OpenAICompatible.model_prices[model_name]
+
+    @staticmethod
+    def calculate_cost(model_name: str, input_tokens: int, output_tokens: int, convert_to_cny: bool = True) -> float:
+        """
+        Calculate the cost for a given model based on input and output tokens.
+        :param model_name: Name of the model
+        :param input_tokens: Number of input tokens
+        :param output_tokens: Number of output tokens
+        :param convert_to_cny: Whether to convert the cost to CNY (default is True)
+        :return: Total cost in the model's currency
+        """
+        if model_name not in OpenAICompatible.model_prices:
+            raise ValueError(f"Model {model_name} not found in price list.")
+        price = OpenAICompatible.model_prices[model_name]
+        return price.get_total_cost(input_tokens, output_tokens, convert_to_cny)
+
 class CrossAccountJWTOAuthApp(JWTOAuthApp):
     def __init__(self, account_id: str, client_id: str, private_key: str, public_key_id: str, base_url):
         self.account_id = account_id
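
The new ModelPrice helpers added to chat_service.py above price tokens per million and optionally convert USD prices to CNY at the hard-coded 7.2 rate. A worked example using only the table added in this diff (the token counts are arbitrary):

    from pqai_agent.chat_service import OpenAICompatible, VOLCENGINE_MODEL_DEEPSEEK_V3, OPENAI_MODEL_GPT_4o

    # deepseek-v3: 2 CNY per 1M input tokens, 8 CNY per 1M output tokens
    cost_cny = OpenAICompatible.calculate_cost(VOLCENGINE_MODEL_DEEPSEEK_V3,
                                               input_tokens=120_000, output_tokens=30_000)
    print(cost_cny)  # (2 * 0.12) + (8 * 0.03) = 0.48 CNY

    # gpt-4o is priced in USD (2.5 / 10 per 1M tokens); convert_to_cny=True multiplies by 7.2
    print(OpenAICompatible.calculate_cost(OPENAI_MODEL_GPT_4o, 10_000, 2_000))
    # (2.5 * 0.01 + 10 * 0.002) * 7.2 = 0.324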

+ 1 - 1
pqai_agent/clients/relation_stage_client.py

@@ -2,7 +2,7 @@ from typing import Optional
 
 import requests
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 class RelationStageClient:
     UNKNOWN_RELATION_STAGE = '未知'

+ 11 - 6
pqai_agent/configs/dev.yaml

@@ -23,6 +23,9 @@ storage:
   staff:
     database: ai_agent
     table: qywx_employee
+  agent_user_relation:
+    database: ai_agent
+    table: qywx_employee_customer
   user_relation:
     database: growth
     table:
@@ -51,8 +54,8 @@ chat_api:
     private_key_path: oauth/coze_privkey.pem
     account_id: 649175100044793
   openai_compatible:
-    text_model: ep-20250414202859-6nkz5
-    multimodal_model: ep-20250421193334-nz5wd
+    text_model: doubao-pro-32k-241215
+    multimodal_model: doubao-1-5-vision-pro-32k-250115
 
 system:
   outside_proxy:
@@ -62,11 +65,12 @@ system:
   scheduler_mode: local
   human_intervention_alert_url: https://open.feishu.cn/open-apis/bot/v2/hook/379fcd1a-0fed-4e58-8cd0-40b6d1895721
   max_reply_workers: 2
-  max_push_workers: 1
-  chat_agent_version: 1
+  push_task_workers: 1
+  chat_agent_version: 2
+  log_dir: .
 
 debug_flags:
-  disable_llm_api_call: True
+  disable_llm_api_call: False
   use_local_user_storage: True
   console_input: True
   disable_active_conversation: True
@@ -82,4 +86,5 @@ mq:
   scheduler_topic: agent_scheduler_event_dev
   scheduler_group: agent_scheduler_event_dev
   push_tasks_topic: agent_push_tasks_dev
-  push_tasks_group: agent_push_tasks_dev
+  push_send_task_group: agent_push_tasks_dev
+  push_generate_task_group: agent_push_generate_task_dev

+ 7 - 4
pqai_agent/configs/prod.yaml

@@ -37,7 +37,7 @@ storage:
     table: qywx_chat_history
   push_record:
     database: ai_agent
-    table: agent_push_record_dev
+    table: agent_push_record
 
 chat_api:
   coze:
@@ -46,8 +46,8 @@ chat_api:
     private_key_path: oauth/coze_privkey.pem
     account_id: 649175100044793
   openai_compatible:
-    text_model: ep-20250414202859-6nkz5
-    multimodal_model: ep-20250421193334-nz5wd
+    text_model: doubao-pro-32k-241215
+    multimodal_model: doubao-1-5-vision-pro-32k-250115
 
 system:
   outside_proxy:
@@ -57,6 +57,8 @@ system:
   scheduler_mode: mq
   human_intervention_alert_url: https://open.feishu.cn/open-apis/bot/v2/hook/c316b559-1c6a-4c4e-97c9-50b44e4c2a9d
   max_reply_workers: 5
+  push_task_workers: 5
+  log_dir: /var/log/agent_service
 
 agent_behavior:
   message_aggregation_sec: 20
@@ -79,4 +81,5 @@ mq:
   scheduler_topic: agent_scheduler_event
   scheduler_group: agent_scheduler_event
   push_tasks_topic: agent_push_tasks
-  push_tasks_group: agent_push_tasks
+  push_send_task_group: agent_push_tasks
+  push_generate_task_group: agent_push_generate_task

+ 2 - 2
pqai_agent/data_models/agent_configuration.py

@@ -1,7 +1,7 @@
 from enum import Enum
 
 from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base
 
 Base = declarative_base()
 
@@ -27,4 +27,4 @@ class AgentConfiguration(Base):
     create_user = Column(String(32), nullable=True, comment="创建用户")
     update_user = Column(String(32), nullable=True, comment="更新用户")
     create_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", comment="创建时间")
-    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP", comment="更新时间")
+    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", server_onupdate="CURRENT_TIMESTAMP", comment="更新时间")

+ 2 - 2
pqai_agent/data_models/agent_push_record.py

@@ -1,12 +1,12 @@
 from sqlalchemy import Column, Integer, Text, BigInteger
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base
 
 from pqai_agent import configs
 
 Base = declarative_base()
 
 class AgentPushRecord(Base):
-    __tablename__ = configs.get().get('storage', {}).get('push_record', {}).get('table_name', 'agent_push_record_dev')
+    __tablename__ = configs.get().get('storage', {}).get('push_record', {}).get('table', 'agent_push_record_dev')
     id = Column(Integer, primary_key=True)
     staff_id = Column(Integer)
     user_id = Column(Integer)

+ 1 - 1
pqai_agent/data_models/service_module.py

@@ -1,7 +1,7 @@
 from enum import Enum
 
 from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base
 
 Base = declarative_base()
 

+ 1 - 1
pqai_agent/database.py

@@ -5,7 +5,7 @@
 # Copyright © 2024 StrayWarrior <i@straywarrior.com>
 
 import pymysql
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 class MySQLManager:
     def __init__(self, config):

+ 1 - 1
pqai_agent/dialogue_manager.py

@@ -16,7 +16,7 @@ from sqlalchemy.orm import sessionmaker, Session
 from pqai_agent import configs
 from pqai_agent.clients.relation_stage_client import RelationStageClient
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.database import MySQLManager
 from pqai_agent import chat_service, prompt_templates
 from pqai_agent.history_dialogue_service import HistoryDialogueService

+ 1 - 1
pqai_agent/history_dialogue_service.py

@@ -7,7 +7,7 @@ import requests
 from pymysql.cursors import DictCursor
 
 from pqai_agent.database import MySQLManager
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 import time
 
 from pqai_agent import configs

+ 9 - 7
pqai_agent/logging_service.py → pqai_agent/logging.py

@@ -26,25 +26,27 @@ class ColoredFormatter(logging.Formatter):
         return message
 
 def setup_root_logger(level=logging.DEBUG, logfile_name='service.log'):
-    formatter = ColoredFormatter(
-        '%(asctime)s - %(name)s %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
-    )
+    logging_format = '%(asctime)s - %(name)s %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
+    plain_formatter = logging.Formatter(logging_format)
+    color_formatter = ColoredFormatter(logging_format)
     console_handler = logging.StreamHandler()
     console_handler.setLevel(logging.DEBUG)
-    console_handler.setFormatter(formatter)
+    console_handler.setFormatter(color_formatter)
 
     root_logger = logging.getLogger()
     root_logger.handlers.clear()
     root_logger.addHandler(console_handler)
-    if configs.get_env() == 'prod':
+
+    log_dir = configs.get().get('system', {}).get('log_dir', '')
+    if log_dir:
         file_handler = RotatingFileHandler(
-            f'/var/log/agent_service/{logfile_name}',
+            f'{log_dir}/{logfile_name}',
             maxBytes=64 * 1024 * 1024,
             backupCount=5,
             encoding='utf-8'
         )
         file_handler.setLevel(logging.DEBUG)
-        file_handler.setFormatter(formatter)
+        file_handler.setFormatter(plain_formatter)
         root_logger.addHandler(file_handler)
 
     agent_logger = logging.getLogger('agent')
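
With the rename to pqai_agent/logging.py, file logging is now driven by the system.log_dir config entry (dev.yaml sets it to ".", prod.yaml to /var/log/agent_service) rather than by the environment name, and the file handler gets the plain formatter while the console keeps the colored one. A small usage sketch, mirroring the __main__ block in abtest/client.py:

    from pqai_agent.logging import setup_root_logger, logger

    # Console handler always; a rotating file handler only if configs' system.log_dir is non-empty.
    setup_root_logger(level='DEBUG', logfile_name='service.log')
    logger.info("logging initialised")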

+ 2 - 3
pqai_agent/message_queue_backend.py

@@ -9,8 +9,7 @@ import rocketmq
 from rocketmq import ClientConfiguration, Credentials, SimpleConsumer
 
 from pqai_agent import configs
-from pqai_agent import logging_service
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.mq_message import MqMessage, MessageType, MessageChannel
 
 
@@ -87,7 +86,7 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
             return None
         rmq_message = messages[0]
         body = rmq_message.body.decode('utf-8')
-        logger.debug("[{}]recv message body: {}".format(self.topic, body))
+        logger.debug(f"[{self.topic}]recv message, group[{rmq_message.message_group}], body: {body}")
         try:
             message = MqMessage.from_json(body)
             message._rmq_message = rmq_message

+ 1 - 0
pqai_agent/prompt_templates.py

@@ -279,6 +279,7 @@ RESPONSE_TYPE_DETECT_PROMPT = """
 * 默认使用文本
 * 如果用户明确提到使用语音形式,尽量选择语音
 * 用户自身偏向于使用语音形式沟通时,可选择语音
+* 如果用户不认字或有阅读障碍,且内容适合语音朗读,可选择语音
 * 注意分析即将发送的消息内容,如果有不适合使用语音朗读的内容,不要选择使用语音
 * 注意对话中包含的时间!注意时间流逝和情境切换!判断合适的回复方式!
 

+ 64 - 45
pqai_agent/push_service.py

@@ -16,7 +16,7 @@ from pqai_agent.abtest.utils import get_abtest_info
 from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit import get_tools
 from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
@@ -54,7 +54,10 @@ class PushScanThread:
         first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
         # 合并白名单,减少配置成本
         white_list_tags.update(first_initiate_tags)
-        for staff_user in self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id):
+        all_staff_users = self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id)
+        all_users = list({pair['user_id'] for pair in all_staff_users})
+        all_user_tags = self.service.user_manager.get_user_tags(all_users)
+        for staff_user in all_staff_users:
             staff_id = staff_user['staff_id']
             user_id = staff_user['user_id']
             # 通过AB实验配置控制用户组是否启用push
@@ -62,8 +65,8 @@ class PushScanThread:
             # if abtest_params.get('agent_push_enabled', 'false').lower() != 'true':
             #     logger.debug(f"User {user_id} not enabled agent push, skipping.")
             #     continue
-            user_tags = self.service.user_relation_manager.get_user_tags(user_id)
-            if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
+            user_tags = all_user_tags.get(user_id, list())
+            if not white_list_tags.intersection(user_tags):
                 should_initiate = False
             else:
                 agent = self.service.get_agent_instance(staff_id, user_id)
@@ -78,66 +81,79 @@ class PushScanThread:
 
 class PushTaskWorkerPool:
     def __init__(self, agent_service: 'AgentService', mq_topic: str,
-                 mq_consumer: rocketmq.SimpleConsumer, mq_producer: rocketmq.Producer):
+                 mq_consumer_generate: rocketmq.SimpleConsumer,
+                 mq_consumer_send: rocketmq.SimpleConsumer,
+                 mq_producer: rocketmq.Producer):
         self.agent_service = agent_service
         max_workers = configs.get()['system'].get('push_task_workers', 5)
         self.max_push_workers = max_workers
         self.generate_executor = ThreadPoolExecutor(max_workers=max_workers)
         self.send_executors = {}
         self.rmq_topic = mq_topic
-        self.consumer = mq_consumer
+        self.generate_consumer = mq_consumer_generate
+        self.send_consumer = mq_consumer_send
         self.producer = mq_producer
-        self.loop_thread = None
+        self.generate_loop_thread = None
+        self.send_loop_thread = None
         self.is_generator_running = True
-        self.generate_send_done = False # set by wait_to_finish
-        self.no_more_generate_task = False # set by self
+        self.generate_send_done = False # set by wait_to_finish,表示所有生成任务均已进入队列
+        self.no_more_generate_task = False # generate_send_done被设置之后队列中未再收到生成任务时设置
 
     def start(self):
-        self.loop_thread = Thread(target=self.process_push_tasks)
-        self.loop_thread.start()
+        self.send_loop_thread = Thread(target=self.process_send_tasks)
+        self.send_loop_thread.start()
+        self.generate_loop_thread = Thread(target=self.process_generate_tasks)
+        self.generate_loop_thread.start()
 
-    def process_push_tasks(self):
-        # RMQ consumer疑似有bug,创建后立即消费可能报NPE
+    def process_send_tasks(self):
         time.sleep(1)
         while True:
-            msgs = self.consumer.receive(1, 300)
+            msgs = self.send_consumer.receive(1, 60)
             if not msgs:
                 # 没有生成任务在执行且没有消息,才可退出
-                if self.generate_send_done:
-                    if not self.no_more_generate_task:
-                        logger.debug("no message received, there should be no more generate task")
-                        self.no_more_generate_task = True
-                        continue
-                    else:
-                        if self.is_generator_running:
-                            logger.debug("Waiting for generator threads to finish")
-                            continue
-                        else:
-                            break
+                if self.no_more_generate_task and not self.is_generator_running:
+                    break
                 else:
                     continue
             msg = msgs[0]
             task = json.loads(msg.body.decode('utf-8'))
             msg_time = datetime.fromtimestamp(task['timestamp'] / 1000).strftime("%Y-%m-%d %H:%M:%S")
             logger.debug(f"recv message:{msg_time} - {task}")
-            if task['task_type'] == TaskType.GENERATE.value:
-                # FIXME: 临时方案,避免消息在消费后等待超时并重复消费
-                if self.generate_executor._work_queue.qsize() > self.max_push_workers * 5:
-                    logger.warning("Too many generate tasks in queue, consume this task later")
-                    while self.generate_executor._work_queue.qsize() > self.max_push_workers * 5:
-                        time.sleep(10)
-                    # do not submit and ack this message
-                    continue
-                self.generate_executor.submit(self.handle_generate_task, task, msg)
-            elif task['task_type'] == TaskType.SEND.value:
+            if task['task_type'] == TaskType.SEND.value:
                 staff_id = task['staff_id']
                 if staff_id not in self.send_executors:
                     self.send_executors[staff_id] = ThreadPoolExecutor(max_workers=1)
                 self.send_executors[staff_id].submit(self.handle_send_task, task, msg)
             else:
                 logger.error(f"Unknown task type: {task['task_type']}")
-                self.consumer.ack(msg)
-        logger.info("PushGenerateWorkerPool stopped")
+                self.send_consumer.ack(msg)
+        logger.info("PushGenerateWorkerPool send thread stopped")
+
+    def process_generate_tasks(self):
+        time.sleep(1)
+        while True:
+            if self.generate_executor._work_queue.qsize() > self.max_push_workers * 2:
+                logger.warning("Too many generate tasks in queue, consume later")
+                time.sleep(10)
+                continue
+            msgs = self.generate_consumer.receive(1, 300)
+            if not msgs:
+                # 生成任务已经创建完成 且 未收到新任务,才可退出
+                if self.generate_send_done:
+                    logger.debug("no message received, there should be no more generate task")
+                    self.no_more_generate_task = True
+                    break
+                else:
+                    continue
+            msg = msgs[0]
+            task = json.loads(msg.body.decode('utf-8'))
+            msg_time = datetime.fromtimestamp(task['timestamp'] / 1000).strftime("%Y-%m-%d %H:%M:%S")
+            logger.debug(f"recv message:{msg_time} - {task}")
+            if task['task_type'] == TaskType.GENERATE.value:
+                self.generate_executor.submit(self.handle_generate_task, task, msg)
+            else:
+                self.generate_consumer.ack(msg)
+        logger.info("PushGenerateWorkerPool generator thread stopped")
 
     def wait_to_finish(self):
         self.generate_send_done = True
@@ -146,7 +162,8 @@ class PushTaskWorkerPool:
             time.sleep(1)
         self.generate_executor.shutdown(wait=True)
         self.is_generator_running = False
-        self.loop_thread.join()
+        self.generate_loop_thread.join()
+        self.send_loop_thread.join()
 
     def handle_send_task(self, task: Dict, msg: rocketmq.Message):
         try:
@@ -155,13 +172,13 @@ class PushTaskWorkerPool:
             agent = self.agent_service.get_agent_instance(staff_id, user_id)
             # 二次校验是否需要发送
             if not agent.should_initiate_conversation():
-                logger.debug(f"user[{user_id}], do not initiate conversation")
-                self.consumer.ack(msg)
+                logger.debug(f"user[{user_id}], should not initiate, skip sending task")
+                self.send_consumer.ack(msg)
                 return
             contents: List[Dict] = json.loads(task['content'])
             if not contents:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: empty content, do not send")
-                self.consumer.ack(msg)
+                self.send_consumer.ack(msg)
                 return
             recent_dialogue = agent.dialogue_history[-10:]
             agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
@@ -197,11 +214,11 @@ class PushTaskWorkerPool:
                 agent.update_last_active_interaction_time(current_ts)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: generate empty response")
-            self.consumer.ack(msg)
+            self.send_consumer.ack(msg)
         except Exception as e:
             fmt_exc = traceback.format_exc()
             logger.error(f"Error processing message sending: {e}, {fmt_exc}")
-            self.consumer.ack(msg)
+            self.send_consumer.ack(msg)
 
     def handle_generate_task(self, task: Dict, msg: rocketmq.Message):
         try:
@@ -231,15 +248,17 @@ class PushTaskWorkerPool:
                 ),
                 query_prompt_template=query_prompt_template
             )
+            cost = push_agent.get_total_cost()
+            logger.debug(f"staff[{staff_id}], user[{user_id}]: push message generation cost: {cost}")
             if message_to_user:
                 rmq_message = generate_task_rmq_message(
                     self.rmq_topic, staff_id, user_id, TaskType.SEND, json.dumps(message_to_user))
                 self.producer.send(rmq_message)
             else:
                 logger.info(f"staff[{staff_id}], user[{user_id}]: no push message generated")
-            self.consumer.ack(msg)
+            self.generate_consumer.ack(msg)
         except Exception as e:
             fmt_exc = traceback.format_exc()
             logger.error(f"Error processing message generation: {e}, {fmt_exc}")
             # FIXME: 是否需要ACK
-            self.consumer.ack(msg)
+            self.generate_consumer.ack(msg)
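The worker pool now uses one RocketMQ consumer per task type, so each handler must ack on the consumer it consumed from. A minimal sketch of the assumed wiring (illustrative only; the real constructor is not shown in this hunk):

class PushTaskWorkerPool:
    def __init__(self, rmq_topic, producer, generate_consumer, send_consumer, agent_service):
        # Illustrative wiring: GENERATE and SEND tasks are consumed separately,
        # which is why the acks above moved from self.consumer to the task-specific consumers.
        self.rmq_topic = rmq_topic
        self.producer = producer
        self.generate_consumer = generate_consumer  # consumes TaskType.GENERATE messages
        self.send_consumer = send_consumer          # consumes TaskType.SEND messages
        self.agent_service = agent_service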

+ 1 - 1
pqai_agent/rate_limiter.py

@@ -5,7 +5,7 @@
 import time
 import time
 from typing import Optional, Union, Dict
 from typing import Optional, Union, Dict
 
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.mq_message import MessageType
 from pqai_agent.mq_message import MessageType
 
 
 
 

+ 2 - 2
pqai_agent/response_type_detector.py

@@ -12,7 +12,7 @@ from pqai_agent import chat_service
 from pqai_agent import configs
 from pqai_agent import configs
 from pqai_agent import prompt_templates
 from pqai_agent import prompt_templates
 from pqai_agent.dialogue_manager import DialogueManager
 from pqai_agent.dialogue_manager import DialogueManager
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.mq_message import MessageType
 from pqai_agent.mq_message import MessageType
 
 
 
 
@@ -36,7 +36,7 @@ class ResponseTypeDetector:
             api_key=chat_service.VOLCENGINE_API_TOKEN,
             api_key=chat_service.VOLCENGINE_API_TOKEN,
             base_url=chat_service.VOLCENGINE_BASE_URL
             base_url=chat_service.VOLCENGINE_BASE_URL
         )
         )
-        self.model_name = chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
+        self.model_name = chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K
 
 
     def detect_type(self, dialogue_history: List[Dict], next_message: Dict, enable_random=False,
     def detect_type(self, dialogue_history: List[Dict], next_message: Dict, enable_random=False,
                     random_rate=0.25):
                     random_rate=0.25):

+ 1 - 1
pqai_agent/service_module_manager.py

@@ -1,5 +1,5 @@
 from pqai_agent.data_models.service_module import ServiceModule, ModuleAgentType
 from pqai_agent.data_models.service_module import ServiceModule, ModuleAgentType
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 
 class ServiceModuleManager:
 class ServiceModuleManager:
     def __init__(self, session_maker):
     def __init__(self, session_maker):

+ 1 - 1
pqai_agent/toolkit/__init__.py

@@ -1,7 +1,7 @@
 # 必须要在这里导入模块,以便对应的模块执行register_toolkit
 # 必须要在这里导入模块,以便对应的模块执行register_toolkit
 from typing import Sequence, List
 from typing import Sequence, List
 
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.tool_registry import ToolRegistry
 from pqai_agent.toolkit.tool_registry import ToolRegistry
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
 from pqai_agent.toolkit.message_notifier import MessageNotifier

+ 1 - 1
pqai_agent/toolkit/function_tool.py

@@ -8,7 +8,7 @@ from pydantic import BaseModel, create_model
 from pydantic.fields import FieldInfo
 from pydantic.fields import FieldInfo
 from jsonschema.validators import Draft202012Validator as JSONValidator
 from jsonschema.validators import Draft202012Validator as JSONValidator
 import re
 import re
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 
 
 
 def to_pascal(snake: str) -> str:
 def to_pascal(snake: str) -> str:

+ 1 - 1
pqai_agent/toolkit/image_describer.py

@@ -3,7 +3,7 @@ import threading
 
 
 from pqai_agent import chat_service
 from pqai_agent import chat_service
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.tool_registry import register_toolkit
 from pqai_agent.toolkit.tool_registry import register_toolkit

+ 1 - 1
pqai_agent/toolkit/lark_sheet_record_for_human_intervention.py

@@ -4,7 +4,7 @@ from typing import List
 import requests
 import requests
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.function_tool import FunctionTool
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 
 class LarkSheetRecordForHumanIntervention(BaseToolkit):
 class LarkSheetRecordForHumanIntervention(BaseToolkit):
     r"""A toolkit for recording human intervention events into a Feishu spreadsheet."""
     r"""A toolkit for recording human intervention events into a Feishu spreadsheet."""

+ 1 - 1
pqai_agent/toolkit/message_notifier.py

@@ -1,6 +1,6 @@
 from typing import List, Dict
 from typing import List, Dict
 
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.tool_registry import register_toolkit
 from pqai_agent.toolkit.tool_registry import register_toolkit

+ 51 - 3
pqai_agent/user_manager.py

@@ -1,8 +1,9 @@
 #! /usr/bin/env python
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
 # vim:fenc=utf-8
+from abc import abstractmethod
 
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from typing import Dict, Optional, List
 from typing import Dict, Optional, List
 import json
 import json
 import time
 import time
@@ -33,6 +34,10 @@ class UserManager(abc.ABC):
         #FIXME(zhoutian): 重新设计用户和员工数据管理模型
         #FIXME(zhoutian): 重新设计用户和员工数据管理模型
         pass
         pass
 
 
+    @abstractmethod
+    def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
+        pass
+
     @staticmethod
     @staticmethod
     def get_default_profile(**kwargs) -> Dict:
     def get_default_profile(**kwargs) -> Dict:
         default_profile = {
         default_profile = {
@@ -133,6 +138,9 @@ class LocalUserManager(UserManager):
             logger.error("staff profile not found: {}".format(e))
             logger.error("staff profile not found: {}".format(e))
             return {}
             return {}
 
 
+    def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
+        return {}
+
     def list_users(self, **kwargs) -> List[Dict]:
     def list_users(self, **kwargs) -> List[Dict]:
         pass
         pass
 
 
@@ -249,6 +257,37 @@ class MySQLUserManager(UserManager):
         sql = f"UPDATE {self.staff_table} SET agent_profile = %s WHERE third_party_user_id = '{staff_id}'"
         sql = f"UPDATE {self.staff_table} SET agent_profile = %s WHERE third_party_user_id = '{staff_id}'"
         self.db.execute(sql, (json.dumps(profile),))
         self.db.execute(sql, (json.dumps(profile),))
 
 
+    def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
+        """
+        获取用户的标签列表
+        :param user_ids: 用户ID
+        :param batch_size: 批量查询的大小
+        :return: 标签名称列表
+        """
+        batches = (len(user_ids) + batch_size - 1) // batch_size
+        ret = {}
+        for i in range(batches):
+            idx_begin = i * batch_size
+            idx_end = min((i + 1) * batch_size, len(user_ids))
+            batch_user_ids = user_ids[idx_begin:idx_end]
+            sql = f"""
+                SELECT a.third_party_user_id, b.tag_id, b.name FROM qywx_user_tag a
+                    JOIN qywx_tag b ON a.tag_id = b.tag_id
+                """
+            if len(batch_user_ids) == 1:
+                sql += f" AND a.third_party_user_id = '{batch_user_ids[0]}'"""
+            else:
+                sql += f" AND a.third_party_user_id IN {str(tuple(batch_user_ids))}"
+            rows = self.db.select(sql, pymysql.cursors.DictCursor)
+            # group by user_id
+            for row in rows:
+                user_id = row['third_party_user_id']
+                tag_name = row['name']
+                if user_id not in ret:
+                    ret[user_id] = []
+                ret[user_id].append(tag_name)
+        return ret
+
     def list_users(self, **kwargs) -> List[Dict]:
     def list_users(self, **kwargs) -> List[Dict]:
         user_union_id = kwargs.get('user_union_id', None)
         user_union_id = kwargs.get('user_union_id', None)
         user_name = kwargs.get('user_name', None)
         user_name = kwargs.get('user_name', None)
@@ -333,6 +372,7 @@ class MySQLUserRelationManager(UserRelationManager):
         self.relation_table = relation_table
         self.relation_table = relation_table
         self.agent_user_table = agent_user_table
         self.agent_user_table = agent_user_table
         self.user_table = user_table
         self.user_table = user_table
+        self.agent_user_relation_table = 'qywx_employee_customer'
 
 
     def list_staffs(self):
     def list_staffs(self):
         sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
         sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
@@ -342,7 +382,14 @@ class MySQLUserRelationManager(UserRelationManager):
     def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
     def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
         return []
         return []
 
 
-    def list_staff_users(self, staff_id: str = None, tag_id: int = None):
+    def list_staff_users(self, staff_id: str = None, tag_id: int = None) -> List[Dict]:
+        sql = f"SELECT employee_id as staff_id, customer_id as user_id FROM {self.agent_user_relation_table} WHERE 1 = 1"
+        if staff_id:
+            sql += f" AND employee_id = '{staff_id}'"
+        agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
+        return agent_staff_data
+
+    def list_staff_users_v1(self, staff_id: str = None, tag_id: int = None):
         sql = f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
         sql = f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
         if staff_id:
         if staff_id:
             sql += f" AND third_party_user_id = '{staff_id}'"
             sql += f" AND third_party_user_id = '{staff_id}'"
@@ -382,7 +429,8 @@ class MySQLUserRelationManager(UserRelationManager):
                 sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
                 sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
                 batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
                 batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
                 if len(agent_user_data) != len(batch_union_ids):
                 if len(agent_user_data) != len(batch_union_ids):
-                    # logger.debug(f"staff[{wxid}] some users not found in agent database")
+                    diff_num = len(batch_union_ids) - len(batch_agent_user_data)
+                    logger.debug(f"staff[{staff_id}] {diff_num} users not found in agent database")
                     pass
                     pass
                 agent_user_data.extend(batch_agent_user_data)
                 agent_user_data.extend(batch_agent_user_data)
             staff_user_pairs = [
             staff_user_pairs = [
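A minimal usage sketch for the new get_user_tags helper (illustrative; the constructor arguments mirror the api_server wiring later in this diff, and the config values and IDs are placeholders):

# Illustrative only — db config, table names and user IDs are placeholders.
user_manager = MySQLUserManager(agent_db_config, user_table, staff_table)
tags_by_user = user_manager.get_user_tags(['user_id_1', 'user_id_2'], batch_size=500)
for user_id, tag_names in tags_by_user.items():
    # tag_names are qywx_tag.name values joined through qywx_user_tag
    print(user_id, tag_names)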

+ 3 - 3
pqai_agent/user_profile_extractor.py

@@ -8,7 +8,7 @@ from typing import Dict, Optional, List
 from pqai_agent import chat_service, configs
 from pqai_agent import chat_service, configs
 from pqai_agent.prompt_templates import USER_PROFILE_EXTRACT_PROMPT, USER_PROFILE_EXTRACT_PROMPT_V2
 from pqai_agent.prompt_templates import USER_PROFILE_EXTRACT_PROMPT, USER_PROFILE_EXTRACT_PROMPT_V2
 from openai import OpenAI
 from openai import OpenAI
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.utils import prompt_utils
 from pqai_agent.utils import prompt_utils
 
 
 
 
@@ -198,8 +198,8 @@ class UserProfileExtractor:
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
     from pqai_agent import configs
     from pqai_agent import configs
-    from pqai_agent import logging_service
-    logging_service.setup_root_logger()
+    from pqai_agent import logging
+    logging.setup_root_logger()
     config = configs.get()
     config = configs.get()
     config['debug_flags']['disable_llm_api_call'] = False
     config['debug_flags']['disable_llm_api_call'] = False
     extractor = UserProfileExtractor()
     extractor = UserProfileExtractor()

+ 8 - 0
pqai_agent/utils/__init__.py

@@ -0,0 +1,8 @@
+import hashlib
+import random
+import time
+
+def random_str() -> str:
+    """生成一个随机的MD5字符串"""
+    random_string = str(random.randint(0, 1000000)) + str(time.time())
+    return hashlib.md5(random_string.encode('utf-8')).hexdigest()
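A tiny usage sketch for the new helper (the request-id use case is an assumption, not taken from this commit):

from pqai_agent.utils import random_str

request_id = random_str()  # 32-character hex MD5 digest of a random number plus the current time
assert len(request_id) == 32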

+ 7 - 4
pqai_agent_server/agent_server.py

@@ -2,10 +2,10 @@ import logging
 import sys
 import sys
 import time
 import time
 
 
-from pqai_agent import configs, logging_service
+from pqai_agent import configs
 from pqai_agent.agent_service import AgentService
 from pqai_agent.agent_service import AgentService
 from pqai_agent.chat_service import ChatServiceType
 from pqai_agent.chat_service import ChatServiceType
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger, setup_root_logger
 from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend, MemoryQueueBackend
 from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend, MemoryQueueBackend
 from pqai_agent.push_service import PushTaskWorkerPool, PushScanThread
 from pqai_agent.push_service import PushTaskWorkerPool, PushScanThread
@@ -14,7 +14,7 @@ from pqai_agent.user_manager import LocalUserManager, LocalUserRelationManager,
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     config = configs.get()
     config = configs.get()
-    logging_service.setup_root_logger()
+    setup_root_logger()
     logger.warning("current env: {}".format(configs.get_env()))
     logger.warning("current env: {}".format(configs.get_env()))
     scheduler_logger = logging.getLogger('apscheduler')
     scheduler_logger = logging.getLogger('apscheduler')
     scheduler_logger.setLevel(logging.WARNING)
     scheduler_logger.setLevel(logging.WARNING)
@@ -92,13 +92,16 @@ if __name__ == "__main__":
             continue
             continue
         message_id += 1
         message_id += 1
         sender = '7881301903997433'
         sender = '7881301903997433'
-        receiver = '1688855931724582'
+        receiver = '1688854974625870'
         if text in (MessageType.AGGREGATION_TRIGGER.name,
         if text in (MessageType.AGGREGATION_TRIGGER.name,
                     MessageType.HUMAN_INTERVENTION_END.name):
                     MessageType.HUMAN_INTERVENTION_END.name):
             message = MqMessage.build(
             message = MqMessage.build(
                 MessageType.__members__.get(text),
                 MessageType.__members__.get(text),
                 MessageChannel.CORP_WECHAT,
                 MessageChannel.CORP_WECHAT,
                 sender, receiver, None, int(time.time() * 1000))
                 sender, receiver, None, int(time.time() * 1000))
+        elif text == 'S_PUSH':
+            service._check_initiative_conversations()
+            continue
         else:
         else:
             message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
             message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
                                       sender, receiver, text, int(time.time() * 1000)
                                       sender, receiver, text, int(time.time() * 1000)

+ 140 - 57
pqai_agent_server/api_server.py

@@ -1,6 +1,8 @@
 #! /usr/bin/env python
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
 # vim:fenc=utf-8
+import json
+import time
 import logging
 import logging
 from argparse import ArgumentParser
 from argparse import ArgumentParser
 
 
@@ -9,6 +11,10 @@ from flask import Flask, request, jsonify
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm import sessionmaker
 
 
 from pqai_agent import configs
 from pqai_agent import configs
+
+from pqai_agent import chat_service, prompt_templates
+from pqai_agent.logging import logger, setup_root_logger
+from pqai_agent.toolkit import global_tool_map
 from pqai_agent import logging_service, chat_service, prompt_templates
 from pqai_agent import logging_service, chat_service, prompt_templates
 from pqai_agent.data_models.agent_configuration import AgentConfiguration
 from pqai_agent.data_models.agent_configuration import AgentConfiguration
 from pqai_agent.data_models.service_module import ServiceModule
 from pqai_agent.data_models.service_module import ServiceModule
@@ -21,6 +27,8 @@ from pqai_agent_server.const.status_enum import TestTaskStatus
 from pqai_agent_server.const.type_enum import EvaluateType
 from pqai_agent_server.const.type_enum import EvaluateType
 from pqai_agent_server.dataset_service import DatasetService
 from pqai_agent_server.dataset_service import DatasetService
 from pqai_agent_server.models import MySQLSessionManager
 from pqai_agent_server.models import MySQLSessionManager
+import pqai_agent_server.utils
+from pqai_agent_server.utils import wrap_response
 from pqai_agent_server.task_server import TaskManager
 from pqai_agent_server.task_server import TaskManager
 from pqai_agent_server.utils import (
 from pqai_agent_server.utils import (
     run_extractor_prompt,
     run_extractor_prompt,
@@ -30,7 +38,6 @@ from pqai_agent_server.utils import (
 from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
 from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
 
 
 app = Flask('agent_api_server')
 app = Flask('agent_api_server')
-logger = logging_service.logger
 const = AgentApiConst()
 const = AgentApiConst()
 
 
 
 
@@ -93,34 +100,23 @@ def get_dialogue_history():
 
 
 @app.route('/api/listModels', methods=['GET'])
 @app.route('/api/listModels', methods=['GET'])
 def list_models():
 def list_models():
-    models = [
-        {
-            'model_type': 'openai_compatible',
-            'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
-            'display_name': 'DeepSeek V3 on 火山'
-        },
-        {
-            'model_type': 'openai_compatible',
-            'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-            'display_name': '豆包Pro 32K'
-        },
-        {
-            'model_type': 'openai_compatible',
-            'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
-            'display_name': '豆包Pro 1.5'
-        },
-        {
-            'model_type': 'openai_compatible',
-            'model_name': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
-            'display_name': 'DeepSeek V3联网 on 火山'
-        },
+    models = {
+        "deepseek-chat": chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
+        "gpt-4o": chat_service.OPENAI_MODEL_GPT_4o,
+        "gpt-4o-mini": chat_service.OPENAI_MODEL_GPT_4o_mini,
+        "doubao-pro-32k": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+        "doubao-pro-1.5": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
+        "doubao-1.5-vision-pro": chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
+    }
+    ret_data = [
         {
         {
             'model_type': 'openai_compatible',
             'model_type': 'openai_compatible',
-            'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
-            'display_name': '豆包1.5视觉理解Pro'
-        },
+            'model_name': model_name,
+            'display_name': model_display_name
+        }
+        for model_display_name, model_name in models.items()
     ]
     ]
-    return wrap_response(200, data=models)
+    return wrap_response(200, data=ret_data)
 
 
 
 
 @app.route('/api/listScenes', methods=['GET'])
 @app.route('/api/listScenes', methods=['GET'])
@@ -148,8 +144,8 @@ def get_base_prompt():
     model_map = {
     model_map = {
         'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
         'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
         'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
         'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-        'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
-        'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+        'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
+        'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
         'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
         'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
     }
     }
     if scene not in prompt_map:
     if scene not in prompt_map:
@@ -180,7 +176,6 @@ def run_prompt():
         logger.error(e)
         logger.error(e)
         return wrap_response(500, msg='Error: {}'.format(e))
         return wrap_response(500, msg='Error: {}'.format(e))
 
 
-
 @app.route('/api/formatForPrompt', methods=['POST'])
 @app.route('/api/formatForPrompt', methods=['POST'])
 def format_data_for_prompt():
 def format_data_for_prompt():
     try:
     try:
@@ -300,8 +295,8 @@ def send_message():
     return wrap_response(200, msg="暂不实现功能")
     return wrap_response(200, msg="暂不实现功能")
 
 
 
 
-@app.route("/api/quitHumanInterventionStatus", methods=["POST"])
-def quit_human_interventions_status():
+@app.route("/api/quitHumanIntervention", methods=["POST"])
+def quit_human_intervention():
     """
     """
     退出人工介入状态
     退出人工介入状态
     :return:
     :return:
@@ -311,10 +306,27 @@ def quit_human_interventions_status():
     user_id = req_data["user_id"]
     user_id = req_data["user_id"]
     if not user_id or not staff_id:
     if not user_id or not staff_id:
         return wrap_response(404, msg="user_id and staff_id are required")
         return wrap_response(404, msg="user_id and staff_id are required")
-    response = quit_human_intervention_status(user_id, staff_id)
+    if pqai_agent_server.utils.common.quit_human_intervention(user_id, staff_id):
+        return wrap_response(200, msg="success")
+    else:
+        return wrap_response(500, msg="error")
 
 
-    return wrap_response(200, data=response)
 
 
+@app.route("/api/enterHumanIntervention", methods=["POST"])
+def enter_human_intervention():
+    """
+    进入人工介入状态
+    :return:
+    """
+    req_data = request.json
+    staff_id = req_data["staff_id"]
+    user_id = req_data["user_id"]
+    if not user_id or not staff_id:
+        return wrap_response(404, msg="user_id and staff_id are required")
+    if pqai_agent_server.utils.common.enter_human_intervention(user_id, staff_id):
+        return wrap_response(200, msg="success")
+    else:
+        return wrap_response(500, msg="error")
 
 
 ## Agent管理接口
 ## Agent管理接口
 @app.route("/api/getNativeAgentList", methods=["GET"])
 @app.route("/api/getNativeAgentList", methods=["GET"])
@@ -336,23 +348,28 @@ def get_native_agent_list():
             query = query.filter(AgentConfiguration.create_user == create_user)
             query = query.filter(AgentConfiguration.create_user == create_user)
         if update_user:
         if update_user:
             query = query.filter(AgentConfiguration.update_user == update_user)
             query = query.filter(AgentConfiguration.update_user == update_user)
+        total = query.count()
         query = query.offset(offset).limit(int(page_size))
         query = query.offset(offset).limit(int(page_size))
         data = query.all()
         data = query.all()
-    ret_data = [
-        {
-            'id': agent.id,
-            'name': agent.name,
-            'display_name': agent.display_name,
-            'type': agent.type,
-            'execution_model': agent.execution_model,
-            'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
-            'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
-        }
-        for agent in data
-    ]
+    ret_data = {
+        'total': total,
+        'agent_list': [
+            {
+                'id': agent.id,
+                'name': agent.name,
+                'display_name': agent.display_name,
+                'type': agent.type,
+                'execution_model': agent.execution_model,
+                'create_user': agent.create_user,
+                'update_user': agent.update_user,
+                'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+                'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
+            }
+            for agent in data
+        ]
+    }
     return wrap_response(200, data=ret_data)
     return wrap_response(200, data=ret_data)
 
 
-
 @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
 @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
 def get_native_agent_configuration():
 def get_native_agent_configuration():
     """
     """
@@ -376,15 +393,14 @@ def get_native_agent_configuration():
             'execution_model': agent.execution_model,
             'execution_model': agent.execution_model,
             'system_prompt': agent.system_prompt,
             'system_prompt': agent.system_prompt,
             'task_prompt': agent.task_prompt,
             'task_prompt': agent.task_prompt,
-            'tools': agent.tools,
-            'sub_agents': agent.sub_agents,
-            'extra_params': agent.extra_params,
+            'tools': json.loads(agent.tools),
+            'sub_agents': json.loads(agent.sub_agents),
+            'extra_params': json.loads(agent.extra_params),
             'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
             'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
             'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
             'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
         }
         }
         return wrap_response(200, data=data)
         return wrap_response(200, data=data)
 
 
-
 @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
 @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
 def save_native_agent_configuration():
 def save_native_agent_configuration():
     """
     """
@@ -399,9 +415,19 @@ def save_native_agent_configuration():
     execution_model = req_data.get('execution_model', None)
     execution_model = req_data.get('execution_model', None)
     system_prompt = req_data.get('system_prompt', None)
     system_prompt = req_data.get('system_prompt', None)
     task_prompt = req_data.get('task_prompt', None)
     task_prompt = req_data.get('task_prompt', None)
-    tools = req_data.get('tools', [])
-    sub_agents = req_data.get('sub_agents', [])
+    tools = json.dumps(req_data.get('tools', []))
+    sub_agents = json.dumps(req_data.get('sub_agents', []))
     extra_params = req_data.get('extra_params', {})
     extra_params = req_data.get('extra_params', {})
+    operate_user = req_data.get('user', None)
+    if isinstance(extra_params, dict):
+        extra_params = json.dumps(extra_params)
+    elif isinstance(extra_params, str):
+        try:
+            json.loads(extra_params)
+        except json.JSONDecodeError:
+            return wrap_response(400, msg='extra_params should be a valid JSON object or string')
+    if not extra_params:
+        extra_params = '{}'
 
 
     if not name:
     if not name:
         return wrap_response(400, msg='name is required')
         return wrap_response(400, msg='name is required')
@@ -421,6 +447,7 @@ def save_native_agent_configuration():
             agent.tools = tools
             agent.tools = tools
             agent.sub_agents = sub_agents
             agent.sub_agents = sub_agents
             agent.extra_params = extra_params
             agent.extra_params = extra_params
+            agent.update_user = operate_user
         else:
         else:
             agent = AgentConfiguration(
             agent = AgentConfiguration(
                 name=name,
                 name=name,
@@ -431,7 +458,9 @@ def save_native_agent_configuration():
                 task_prompt=task_prompt,
                 task_prompt=task_prompt,
                 tools=tools,
                 tools=tools,
                 sub_agents=sub_agents,
                 sub_agents=sub_agents,
-                extra_params=extra_params
+                extra_params=extra_params,
+                create_user=operate_user,
+                update_user=operate_user
             )
             )
             session.add(agent)
             session.add(agent)
 
 
@@ -439,6 +468,35 @@ def save_native_agent_configuration():
         return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
         return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
 
 
 
 
+
+@app.route("/api/deleteNativeAgentConfiguration", methods=["POST"])
+def delete_native_agent_configuration():
+    """
+    删除指定Agent配置(软删除,设置is_delete=1)
+    :return:
+    """
+    req_data = request.json
+    agent_id = req_data.get('agent_id', None)
+    if not agent_id:
+        return wrap_response(400, msg='agent_id is required')
+    try:
+        agent_id = int(agent_id)
+    except ValueError:
+        return wrap_response(400, msg='agent_id must be an integer')
+
+    with app.session_maker() as session:
+        agent = session.query(AgentConfiguration).filter(
+            AgentConfiguration.id == agent_id,
+            AgentConfiguration.is_delete == 0
+        ).first()
+        if not agent:
+            return wrap_response(404, msg='Agent not found')
+        agent.is_delete = 1
+        session.commit()
+        return wrap_response(200, msg='Agent configuration deleted successfully')
+
+
+
 @app.route("/api/getModuleList", methods=["GET"])
 @app.route("/api/getModuleList", methods=["GET"])
 def get_module_list():
 def get_module_list():
     """
     """
@@ -463,7 +521,6 @@ def get_module_list():
     ]
     ]
     return wrap_response(200, data=ret_data)
     return wrap_response(200, data=ret_data)
 
 
-
 @app.route("/api/getModuleConfiguration", methods=["GET"])
 @app.route("/api/getModuleConfiguration", methods=["GET"])
 def get_module_configuration():
 def get_module_configuration():
     """
     """
@@ -490,7 +547,6 @@ def get_module_configuration():
         }
         }
         return wrap_response(200, data=data)
         return wrap_response(200, data=data)
 
 
-
 @app.route("/api/saveModuleConfiguration", methods=["POST"])
 @app.route("/api/saveModuleConfiguration", methods=["POST"])
 def save_module_configuration():
 def save_module_configuration():
     """
     """
@@ -673,6 +729,33 @@ def get_conversation_data_list():
     return wrap_response(200, data=response)
     return wrap_response(200, data=response)
 
 
 
 
+@app.route("/api/getToolList", methods=["GET"])
+def get_tool_list():
+    """
+    获取所有的工具列表
+    :return:
+    """
+    tools = []
+    for tool_name, tool in global_tool_map.items():
+        tools.append({
+            'name': tool_name,
+            'description': tool.get_function_description(),
+            'parameters': tool.parameters if hasattr(tool, 'parameters') else {}
+        })
+    return wrap_response(200, data=tools)
+
+@app.route("/api/getModuleAgentTypes", methods=["GET"])
+def get_agent_types():
+    """
+    获取所有的Agent类型
+    :return:
+    """
+    agent_types = [
+        {'type': 0, 'display_name': '原生'},
+        {'type': 1, 'display_name': 'Coze'}
+    ]
+    return wrap_response(200, data=agent_types)
+
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 def handle_bad_request(e):
 def handle_bad_request(e):
     logger.error(e)
     logger.error(e)
@@ -689,7 +772,7 @@ if __name__ == '__main__':
 
 
     config = configs.get()
     config = configs.get()
     logging_level = logging.getLevelName(args.log_level)
     logging_level = logging.getLevelName(args.log_level)
-    logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
+    setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
 
 
     # set db config
     # set db config
     agent_db_config = config['database']['ai_agent']
     agent_db_config = config['database']['ai_agent']
@@ -700,7 +783,7 @@ if __name__ == '__main__':
     chat_history_db_config = config['storage']['chat_history']
     chat_history_db_config = config['storage']['chat_history']
 
 
     # init user manager
     # init user manager
-    user_manager = MySQLUserManager(agent_db_config, growth_db_config, staff_db_config['table'])
+    user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
     app.user_manager = user_manager
     app.user_manager = user_manager
 
 
     # init session manager
     # init session manager
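A hedged sketch of calling the new management endpoints added above (host/port and IDs are placeholders; payload shapes follow the handlers in this diff):

import requests

BASE = "http://localhost:8080"  # placeholder; the real host/port depend on deployment
# Soft-delete an agent configuration (the handler sets is_delete = 1).
requests.post(f"{BASE}/api/deleteNativeAgentConfiguration", json={"agent_id": 1})
# Switch a conversation into or out of human-intervention mode.
requests.post(f"{BASE}/api/enterHumanIntervention", json={"staff_id": "staff_x", "user_id": "user_y"})
requests.post(f"{BASE}/api/quitHumanIntervention", json={"staff_id": "staff_x", "user_id": "user_y"})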

+ 0 - 1
pqai_agent_server/utils/__init__.py

@@ -1,5 +1,4 @@
 from .common import wrap_response
 from .common import wrap_response
-from .common import quit_human_intervention_status
 
 
 from .prompt_util import (
 from .prompt_util import (
     run_openai_chat,
     run_openai_chat,

+ 14 - 3
pqai_agent_server/utils/common.py

@@ -11,7 +11,18 @@ def wrap_response(code, msg=None, data=None):
     return jsonify(resp)
     return jsonify(resp)
 
 
 
 
-def quit_human_intervention_status(user_id, staff_id):
+def quit_human_intervention(user_id, staff_id) -> bool:
     url = f"http://ai-wechat-hook-internal.piaoquantv.com/manage/insertEvent?sender={user_id}&receiver={staff_id}&type=103&content=SYSTEM"
     url = f"http://ai-wechat-hook-internal.piaoquantv.com/manage/insertEvent?sender={user_id}&receiver={staff_id}&type=103&content=SYSTEM"
-    response = requests.get(url, timeout=20)
-    return response.json()
+    response = requests.post(url, timeout=20, headers={"Content-Type": "application/json"})
+    if response.status_code == 200 and response.json().get("code") == 0:
+        return True
+    else:
+        return False
+
+def enter_human_intervention(user_id, staff_id) -> bool:
+    url = f"http://ai-wechat-hook-internal.piaoquantv.com/manage/insertEvent?sender={user_id}&receiver={staff_id}&type=104&content=SYSTEM"
+    response = requests.post(url, timeout=20, headers={"Content-Type": "application/json"})
+    if response.status_code == 200 and response.json().get("code") == 0:
+        return True
+    else:
+        return False
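Both helpers now return a bool instead of the raw hook-service JSON; a minimal caller sketch (IDs are placeholders; the 103 = quit / 104 = enter mapping is read from the URLs above):

from pqai_agent_server.utils.common import enter_human_intervention, quit_human_intervention

if enter_human_intervention(user_id="user_y", staff_id="staff_x"):
    pass  # conversation handed over to a human operator (event type 104)
if quit_human_intervention(user_id="user_y", staff_id="staff_x"):
    pass  # automated replies resume (event type 103)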

+ 3 - 4
pqai_agent_server/utils/prompt_util.py

@@ -5,15 +5,14 @@ from typing import List, Dict
 
 
 from openai import OpenAI
 from openai import OpenAI
 
 
-from pqai_agent import logging_service, chat_service
+from pqai_agent import chat_service
+from pqai_agent.logging import logger
 from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.user_profile_extractor import UserProfileExtractor
 from pqai_agent.user_profile_extractor import UserProfileExtractor
 from pqai_agent.dialogue_manager import DialogueManager
 from pqai_agent.dialogue_manager import DialogueManager
 from pqai_agent.mq_message import MessageType
 from pqai_agent.mq_message import MessageType
 from pqai_agent.utils.prompt_utils import format_agent_profile
 from pqai_agent.utils.prompt_utils import format_agent_profile
 
 
-logger = logging_service.logger
-
 
 
 def compose_openai_chat_messages_no_time(dialogue_history, multimodal=False):
 def compose_openai_chat_messages_no_time(dialogue_history, multimodal=False):
     messages = []
     messages = []
@@ -44,7 +43,7 @@ def compose_openai_chat_messages_no_time(dialogue_history, multimodal=False):
 def create_llm_client(model_name):
 def create_llm_client(model_name):
     volcengine_models = [
     volcengine_models = [
         chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
         chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
         chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
         chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
         chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
         chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
     ]
     ]
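A minimal call sketch reflecting the constant rename (illustrative; it assumes create_llm_client keeps returning an OpenAI-compatible client for Volcengine models, as elsewhere in this module):

from pqai_agent import chat_service
from pqai_agent_server.utils.prompt_util import create_llm_client

# The renamed constant must stay in this list so the call keeps routing to the Volcengine client.
client = create_llm_client(chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K)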

+ 43 - 0
scripts/extract_push_action_logs.py

@@ -0,0 +1,43 @@
+import re
+
+def extract_agent_run_steps(log_path):
+    pattern = re.compile(
+        r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - agent run\[\d+\] - DEBUG - current step content:'
+    )
+    timestamp_pattern = re.compile(r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - ')
+    results = []
+    current = []
+    collecting = False
+
+    with open(log_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            if pattern.match(line):
+                if collecting and current:
+                    results.append(''.join(current).rstrip())
+                    current = []
+                collecting = True
+                current.append(line)
+            elif collecting:
+                if timestamp_pattern.match(line):
+                    results.append(''.join(current).rstrip())
+                    current = []
+                    collecting = False
+                else:
+                    current.append(line)
+        # 文件结尾处理
+        if collecting and current:
+            results.append(''.join(current).rstrip())
+    return results
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) != 2:
+        print("Usage: python extract_agent_run_step.py <logfile>")
+        sys.exit(1)
+    log_file = sys.argv[1]
+    steps = extract_agent_run_steps(log_file)
+    for i, step in enumerate(steps, 1):
+        print(f"--- Step {i} ---")
+        print(step)
+        print()
+
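For reference, the log-line shape the extractor keys on, plus a minimal programmatic call (constructed from the regex above; real payloads will differ):

# Header line matched by `pattern`; continuation lines are accumulated
# until the next timestamped log line:
#   2025-06-01 12:00:00,123 - agent run[42] - DEBUG - current step content:
steps = extract_agent_run_steps("agent_service.log")  # log path is a placeholder
print(f"{len(steps)} step blocks extracted")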

+ 40 - 3
tests/unit_test.py

@@ -4,8 +4,12 @@
 
 
 import pytest
 import pytest
 from unittest.mock import Mock, MagicMock
 from unittest.mock import Mock, MagicMock
-from pqai_agent.agent_service import AgentService, MemoryQueueBackend
+
+import pqai_agent.abtest.client
+import pqai_agent.configs
+from pqai_agent.agent_service import AgentService
 from pqai_agent.dialogue_manager import DialogueState, TimeContext
 from pqai_agent.dialogue_manager import DialogueState, TimeContext
+from pqai_agent.message_queue_backend import MemoryQueueBackend
 from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.user_manager import LocalUserManager
 from pqai_agent.user_manager import LocalUserManager
@@ -44,11 +48,16 @@ def test_env():
         user_relation_manager=user_relation_manager
         user_relation_manager=user_relation_manager
     )
     )
     service.user_profile_extractor.extract_profile_info = Mock(return_value=None)
     service.user_profile_extractor.extract_profile_info = Mock(return_value=None)
+    service.can_send_to_user = Mock(return_value=True)
+    service.start()
 
 
     # 替换LLM调用为模拟响应
     # 替换LLM调用为模拟响应
     service._call_chat_api = Mock(return_value="模拟响应")
     service._call_chat_api = Mock(return_value="模拟响应")
 
 
-    return service, queues
+    yield service, queues
+
+    service.shutdown(sync=True)
+    pqai_agent.abtest.client.get_client().shutdown()
 
 
 def test_agent_state_change(test_env):
 def test_agent_state_change(test_env):
     service, _ = test_env
     service, _ = test_env
@@ -220,10 +229,38 @@ def test_initiative_conversation(test_env):
 def test_response_type_detector(test_env):
 def test_response_type_detector(test_env):
     case1 = '大哥,那可得提前了解下天气,以便安排行程~我帮您查查明天北京天气?'
     case1 = '大哥,那可得提前了解下天气,以便安排行程~我帮您查查明天北京天气?'
     assert ResponseTypeDetector.is_chinese_only(case1) == True
     assert ResponseTypeDetector.is_chinese_only(case1) == True
-    assert ResponseTypeDetector.if_message_suitable_for_voice(case1) == False
+    assert ResponseTypeDetector.if_message_suitable_for_voice(case1) == True
     case2 = 'hi'
     case2 = 'hi'
     assert ResponseTypeDetector.is_chinese_only(case2) == False
     assert ResponseTypeDetector.is_chinese_only(case2) == False
     case3 = '这是链接:http://domain.com'
     case3 = '这是链接:http://domain.com'
     assert ResponseTypeDetector.is_chinese_only(case3) == False
     assert ResponseTypeDetector.is_chinese_only(case3) == False
     case4 = '大哥,那可得提前了解下天气'
     case4 = '大哥,那可得提前了解下天气'
     assert ResponseTypeDetector.if_message_suitable_for_voice(case4) == True
     assert ResponseTypeDetector.if_message_suitable_for_voice(case4) == True
+
+    global_config = pqai_agent.configs.get()
+    global_config.get('debug_flags', {}).update({'disable_llm_api_call': False})
+
+    response_detector = ResponseTypeDetector()
+    dialogue1 = [
+        {'role': 'user', 'content': '你好', 'timestamp': 1744979571000, 'type': MessageType.TEXT},
+        {'role': 'assistant', 'content': '你好呀', 'timestamp': 1744979581000},
+    ]
+    assert response_detector.detect_type(dialogue1[:-1], dialogue1[-1]) == MessageType.TEXT
+
+    dialogue2 = [
+        {'role': 'user', 'content': '你可以读一个故事给我听吗', 'timestamp': 1744979591000},
+        {'role': 'assistant', 'content': '当然可以啦!想听什么?', 'timestamp': 1744979601000},
+        {'role': 'user', 'content': '我想听小王子', 'timestamp': 1744979611000},
+        {'role': 'assistant', 'content': '《小王子》讲述了一位年轻王子离开自己的小世界去探索宇宙的冒险经历。 在旅途中,他遇到了各种各样的人,包括被困的飞行员、狐狸和聪明的蛇。 王子通过这些遭遇学到了关于爱情、友谊和超越表面的必要性的重要教训。', 'timestamp': 1744979611000},
+    ]
+    assert response_detector.detect_type(dialogue2[:-1], dialogue2[-1]) == MessageType.VOICE
+
+    dialogue3 = [
+        {'role': 'user', 'content': '他说的是西洋参呢,晓不得到底是不是西洋参。那个样,那个茶是抽的真空的紧包包。我泡他两包,两包泡到十几盒,13盒,我还拿回来的。', 'timestamp': 1744979591000},
+        {'role': 'assistant', 'content': '咋啦?是突然想到啥啦,还是有其他事想和我分享分享?', 'timestamp': 1744979601000},
+        {'role': 'user', 'content': '不要打字,还不要打。听不到。不要打字,不要打字,打字我认不到。打字我认不到,不要打字不要打字,打字我认不到。', 'timestamp': 1744979611000},
+        {'role': 'assistant', 'content': '真是不好意思', 'timestamp': 1744979611000},
+    ]
+    assert response_detector.detect_type(dialogue3[:-1], dialogue3[-1]) == MessageType.VOICE
+
+    global_config.get('debug_flags', {}).update({'disable_llm_api_call': True})