
Merge branch 'master' into dev-xym-add-test-task

# Conflicts:
#	pqai_agent_server/api_server.py
xueyiming 9 hours ago
parent
commit
b4b7f6f14d
40 changed files with 582 additions and 210 deletions
  1. pqai_agent/abtest/client.py (+2 -2)
  2. pqai_agent/abtest/models.py (+1 -1)
  3. pqai_agent/agent_config_manager.py (+1 -1)
  4. pqai_agent/agent_service.py (+31 -12)
  5. pqai_agent/agents/message_push_agent.py (+1 -1)
  6. pqai_agent/agents/message_reply_agent.py (+1 -1)
  7. pqai_agent/agents/multimodal_chat_agent.py (+6 -2)
  8. pqai_agent/agents/simple_chat_agent.py (+22 -5)
  9. pqai_agent/chat_service.py (+96 -24)
  10. pqai_agent/clients/relation_stage_client.py (+1 -1)
  11. pqai_agent/configs/dev.yaml (+11 -6)
  12. pqai_agent/configs/prod.yaml (+7 -4)
  13. pqai_agent/data_models/agent_configuration.py (+2 -2)
  14. pqai_agent/data_models/agent_push_record.py (+2 -2)
  15. pqai_agent/data_models/service_module.py (+1 -1)
  16. pqai_agent/database.py (+1 -1)
  17. pqai_agent/dialogue_manager.py (+1 -1)
  18. pqai_agent/history_dialogue_service.py (+1 -1)
  19. pqai_agent/logging.py (+9 -7)
  20. pqai_agent/message_queue_backend.py (+2 -3)
  21. pqai_agent/prompt_templates.py (+1 -0)
  22. pqai_agent/push_service.py (+64 -45)
  23. pqai_agent/rate_limiter.py (+1 -1)
  24. pqai_agent/response_type_detector.py (+2 -2)
  25. pqai_agent/service_module_manager.py (+1 -1)
  26. pqai_agent/toolkit/__init__.py (+1 -1)
  27. pqai_agent/toolkit/function_tool.py (+1 -1)
  28. pqai_agent/toolkit/image_describer.py (+1 -1)
  29. pqai_agent/toolkit/lark_sheet_record_for_human_intervention.py (+1 -1)
  30. pqai_agent/toolkit/message_notifier.py (+1 -1)
  31. pqai_agent/user_manager.py (+51 -3)
  32. pqai_agent/user_profile_extractor.py (+3 -3)
  33. pqai_agent/utils/__init__.py (+8 -0)
  34. pqai_agent_server/agent_server.py (+7 -4)
  35. pqai_agent_server/api_server.py (+140 -57)
  36. pqai_agent_server/utils/__init__.py (+0 -1)
  37. pqai_agent_server/utils/common.py (+14 -3)
  38. pqai_agent_server/utils/prompt_util.py (+3 -4)
  39. scripts/extract_push_action_logs.py (+43 -0)
  40. tests/unit_test.py (+40 -3)

+ 2 - 2
pqai_agent/abtest/client.py

@@ -6,7 +6,7 @@ from pqai_agent.abtest.models import Project, Domain, Layer, Experiment, Experim
     ExperimentContext, ExperimentResult
 from alibabacloud_paiabtest20240119.models import ListProjectsRequest, ListProjectsResponseBodyProjects, \
     ListDomainsRequest, ListFeaturesRequest, ListLayersRequest, ListExperimentsRequest, ListExperimentVersionsRequest
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 class ExperimentClient:
     def __init__(self, client: Client):
@@ -267,7 +267,7 @@ def get_client():
     return g_client
 
 if __name__ == '__main__':
-    from pqai_agent.logging_service import setup_root_logger
+    from pqai_agent.logging import setup_root_logger
     setup_root_logger(level='DEBUG')
     experiment_client = get_client()
 

+ 1 - 1
pqai_agent/abtest/models.py

@@ -3,7 +3,7 @@ import json
 from dataclasses import dataclass, field
 import hashlib
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 
 class FNV:

+ 1 - 1
pqai_agent/agent_config_manager.py

@@ -1,7 +1,7 @@
 from typing import Dict, Optional
 
 from pqai_agent.data_models.agent_configuration import AgentConfiguration
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 class AgentConfigManager:
     def __init__(self, session_maker):

+ 31 - 12
pqai_agent/agent_service.py

@@ -1,7 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
-
+import json
 import re
 import signal
 import sys
@@ -15,15 +15,16 @@ import traceback
 import apscheduler.triggers.cron
 import rocketmq
 from apscheduler.schedulers.background import BackgroundScheduler
+from rocketmq import FilterExpression
 from sqlalchemy.orm import sessionmaker
 
-from pqai_agent import configs
+from pqai_agent import configs, push_service
 from pqai_agent.abtest.utils import get_abtest_info
 from pqai_agent.agent_config_manager import AgentConfigManager
 from pqai_agent.agents.message_reply_agent import MessageReplyAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.exceptions import NoRetryException
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent import chat_service
 from pqai_agent.chat_service import CozeChat, ChatServiceType
 from pqai_agent.dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
@@ -96,7 +97,8 @@ class AgentService:
 
         # Push相关
         self.push_task_producer = None
-        self.push_task_consumer = None
+        self.push_generate_task_consumer = None
+        self.push_send_task_consumer = None
         self._init_push_task_queue()
         self.next_push_disabled = True
         self._resume_unfinished_push_task()
@@ -344,7 +346,7 @@ class AgentService:
             logger.debug(f"staff[{staff_id}], user[{user_id}]: no messages to send")
 
     def can_send_to_user(self, staff_id, user_id) -> bool:
-        user_tags = self.user_relation_manager.get_user_tags(user_id)
+        user_tags = self.user_manager.get_user_tags([user_id]).get(user_id, [])
         white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags", []))
         hit_white_list_tags = len(set(user_tags).intersection(white_list_tags)) > 0
         staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs", []))
@@ -384,20 +386,31 @@ class AgentService:
         mq_conf = configs.get()['mq']
         rmq_client_conf = rocketmq.ClientConfiguration(mq_conf['endpoints'], credentials, mq_conf['instance_id'])
         rmq_topic = mq_conf['push_tasks_topic']
-        rmq_group = mq_conf['push_tasks_group']
+        rmq_group_generate = mq_conf['push_generate_task_group']
+        rmq_group_send = mq_conf['push_send_task_group']
         self.push_task_rmq_topic = rmq_topic
         self.push_task_producer = rocketmq.Producer(rmq_client_conf, (rmq_topic,))
         self.push_task_producer.startup()
-        self.push_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group, await_duration=5)
-        self.push_task_consumer.startup()
-        self.push_task_consumer.subscribe(rmq_topic)
+        # FIXME: 不应该暴露到agent service中
+        self.push_generate_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group_generate, await_duration=5)
+        self.push_generate_task_consumer.startup()
+        self.push_generate_task_consumer.subscribe(
+            rmq_topic, filter_expression=FilterExpression(push_service.TaskType.GENERATE.value)
+        )
+        self.push_send_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group_send, await_duration=5)
+        self.push_send_task_consumer.startup()
+        self.push_send_task_consumer.subscribe(
+            rmq_topic, filter_expression=FilterExpression(push_service.TaskType.SEND.value)
+        )
 
 
     def _resume_unfinished_push_task(self):
         def run_unfinished_push_task():
             logger.info("start to resume unfinished push task")
             push_task_worker_pool = PushTaskWorkerPool(
-                self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
+                self, self.push_task_rmq_topic, self.push_generate_task_consumer,
+                self.push_send_task_consumer, self.push_task_producer
+            )
             push_task_worker_pool.start()
             push_task_worker_pool.wait_to_finish()
             self.next_push_disabled = False
@@ -427,7 +440,8 @@ class AgentService:
             push_scan_threads.append(scan_thread)
 
         push_task_worker_pool = PushTaskWorkerPool(
-            self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
+            self, self.push_task_rmq_topic,
+            self.push_generate_task_consumer, self.push_send_task_consumer, self.push_task_producer)
         push_task_worker_pool.start()
         for thread in push_scan_threads:
             thread.join()
@@ -461,9 +475,14 @@ class AgentService:
         agent_config = get_agent_abtest_config('chat', main_agent.user_id,
                                                self.service_module_manager, self.agent_config_manager)
         if agent_config:
+            try:
+                tool_names = json.loads(agent_config.tools)
+            except json.JSONDecodeError:
+                logger.error(f"Invalid JSON in agent tools: {agent_config.tools}")
+                tool_names = []
             chat_agent = MessageReplyAgent(model=agent_config.execution_model,
                                            system_prompt=agent_config.system_prompt,
-                                           tools=get_tools(agent_config.tools))
+                                           tools=get_tools(tool_names))
         else:
             chat_agent = MessageReplyAgent()
         chat_responses = chat_agent.generate_message(

+ 1 - 1
pqai_agent/agents/message_push_agent.py

@@ -2,7 +2,7 @@ from typing import Optional, List, Dict
 
 from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier

+ 1 - 1
pqai_agent/agents/message_reply_agent.py

@@ -2,7 +2,7 @@ from typing import Optional, List, Dict
 
 from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier

+ 6 - 2
pqai_agent/agents/multimodal_chat_agent.py

@@ -2,8 +2,9 @@ import datetime
 from abc import abstractmethod
 from typing import Optional, List, Dict
 
+from pqai_agent import configs
 from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit import get_tool
 from pqai_agent.toolkit.function_tool import FunctionTool
@@ -28,7 +29,10 @@ class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
         pass
 
     def _generate_message(self, context: Dict, dialogue_history: List[Dict],
-                         query_prompt_template: str) -> List[Dict]:
+                          query_prompt_template: str) -> List[Dict]:
+        if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
+            return [{'type': 'text', 'content': '测试消息 -> {nickname}'.format(**context)}]
+
         formatted_dialogue = MultiModalChatAgent.compose_dialogue(dialogue_history)
         query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
         self.run(query)

+ 22 - 5
pqai_agent/agents/simple_chat_agent.py

@@ -1,9 +1,10 @@
 import json
 from typing import List, Optional
 
+import pqai_agent.utils
 from pqai_agent.agent import DEFAULT_MAX_RUN_STEPS
 from pqai_agent.chat_service import OpenAICompatible
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.function_tool import FunctionTool
 
 
@@ -23,6 +24,8 @@ class SimpleOpenAICompatibleChatAgent:
         self.generate_cfg = generate_cfg or {}
         self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
         self.tool_call_records = []
+        self.total_input_tokens = 0
+        self.total_output_tokens = 0
         logger.debug(self.tool_map)
 
     def add_tool(self, tool: FunctionTool):
@@ -33,23 +36,26 @@ class SimpleOpenAICompatibleChatAgent:
         self.tool_map[tool.name] = tool
 
     def run(self, user_input: str) -> str:
+        run_id = pqai_agent.utils.random_str()[:12]
         messages = [{"role": "system", "content": self.system_prompt}]
         tools = [tool.get_openai_tool_schema() for tool in self.tools]
         messages.append({"role": "user", "content": user_input})
 
         n_steps = 0
-        logger.debug(f"start agent loop. messages: {messages}")
+        logger.debug(f"run_id[{run_id}] start agent loop. messages: {messages}")
         while n_steps < self.max_run_step:
             response = self.llm_client.chat.completions.create(model=self.model, messages=messages, tools=tools, **self.generate_cfg)
             message = response.choices[0].message
+            self.total_input_tokens += response.usage.prompt_tokens
+            self.total_output_tokens += response.usage.completion_tokens
             messages.append(message)
-            logger.debug(f"current step content: {message.content}")
+            logger.debug(f"run_id[{run_id}] current step content: {message.content}")
 
             if message.tool_calls:
                 for tool_call in message.tool_calls:
                     function_name = tool_call.function.name
                     arguments = json.loads(tool_call.function.arguments)
-                    logger.debug(f"call function[{function_name}], parameter: {arguments}")
+                    logger.debug(f"run_id[{run_id}] call function[{function_name}], parameter: {arguments}")
 
                     if function_name in self.tool_map:
                         result = self.tool_map[function_name](**arguments)
@@ -64,10 +70,21 @@ class SimpleOpenAICompatibleChatAgent:
                             "result": result
                         })
                     else:
-                        logger.error(f"Function {function_name} not found in tool map.")
+                        logger.error(f"run_id[{run_id}] Function {function_name} not found in tool map.")
                         raise Exception(f"Function {function_name} not found in tool map.")
             else:
                 return message.content
             n_steps += 1
 
         raise Exception("Max run steps exceeded")
+
+    def get_total_input_tokens(self) -> int:
+        """获取总输入token数"""
+        return self.total_input_tokens
+
+    def get_total_output_tokens(self) -> int:
+        """获取总输出token数"""
+        return self.total_output_tokens
+
+    def get_total_cost(self) -> float:
+        return OpenAICompatible.calculate_cost(self.model, self.total_input_tokens, self.total_output_tokens)
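
Note: the new token counters accumulate across every completion call inside a single run(). A minimal sketch of turning those totals into a cost, using the per-million-token prices added in the chat_service.py change below (the token counts here are made up for illustration):

    from pqai_agent.chat_service import OpenAICompatible, VOLCENGINE_MODEL_DEEPSEEK_V3

    # hypothetical (prompt_tokens, completion_tokens) usage from two agent steps
    step_usages = [(1200, 300), (800, 150)]
    total_in = sum(u[0] for u in step_usages)    # 2000, as tracked by total_input_tokens
    total_out = sum(u[1] for u in step_usages)   # 450, as tracked by total_output_tokens
    cost = OpenAICompatible.calculate_cost(VOLCENGINE_MODEL_DEEPSEEK_V3, total_in, total_out)
    print(cost)  # (2 * 2000 + 8 * 450) / 1_000_000 = 0.0076 CNY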

+ 96 - 24
pqai_agent/chat_service.py

@@ -11,7 +11,7 @@ from enum import Enum, auto
 import httpx
 
 from pqai_agent import configs
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 import cozepy
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 import time
@@ -22,9 +22,9 @@ COZE_CN_BASE_URL = 'https://api.coze.cn'
 VOLCENGINE_API_TOKEN = '5e275c38-44fd-415f-abcf-4b59f6377f72'
 VOLCENGINE_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
 VOLCENGINE_MODEL_DEEPSEEK_V3 = "deepseek-v3-250324"
-VOLCENGINE_MODEL_DOUBAO_PRO_1_5 = 'ep-20250307150409-4blz9'
-VOLCENGINE_MODEL_DOUBAO_PRO_32K = 'ep-20250414202859-6nkz5'
-VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO = 'ep-20250421193334-nz5wd'
+VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K = 'doubao-1-5-pro-32k-250115'
+VOLCENGINE_MODEL_DOUBAO_PRO_32K = 'doubao-pro-32k-241215'
+VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO = 'doubao-1-5-vision-pro-32k-250115'
 DEEPSEEK_API_TOKEN = 'sk-67daad8f424f4854bda7f1fed7ef220b'
 DEEPSEEK_BASE_URL = 'https://api.deepseek.com/'
 DEEPSEEK_CHAT_MODEL = 'deepseek-chat'
@@ -37,35 +37,81 @@ OPENAI_MODEL_GPT_4o_mini = 'gpt-4o-mini'
 OPENROUTER_API_TOKEN = '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'
 OPENROUTER_BASE_URL = 'https://openrouter.ai/api/v1/'
 OPENROUTER_MODEL_CLAUDE_3_7_SONNET = 'anthropic/claude-3.7-sonnet'
+ALIYUN_API_TOKEN = 'sk-47381479425f4485af7673d3d2fd92b6'
+ALIYUN_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
+
 
 class ChatServiceType(Enum):
     OPENAI_COMPATIBLE = auto()
     COZE_CHAT = auto()
 
+class ModelPrice:
+    EXCHANGE_RATE_TO_CNY = {
+        "USD": 7.2,  # Example conversion rate, adjust as needed
+    }
+
+    def __init__(self, input_price: float, output_price: float, currency: str = 'CNY'):
+        """
+        :param input_price: input price for per million tokens
+        :param output_price: output price for per million tokens
+        """
+        self.input_price = input_price
+        self.output_price = output_price
+        self.currency = currency
+
+    def get_total_cost(self, input_tokens: int, output_tokens: int, convert_to_cny: bool = True) -> float:
+        """
+        Calculate the total cost based on input and output tokens.
+        :param input_tokens: Number of input tokens
+        :param output_tokens: Number of output tokens
+        :param convert_to_cny: Whether to convert the cost to CNY (default is True)
+        :return: Total cost in the specified currency
+        """
+        total_cost = (self.input_price * input_tokens / 1_000_000) + (self.output_price * output_tokens / 1_000_000)
+        if convert_to_cny and self.currency != 'CNY':
+            conversion_rate = self.EXCHANGE_RATE_TO_CNY.get(self.currency, 1.0)
+            total_cost *= conversion_rate
+        return total_cost
+
+    def __repr__(self):
+        return f"ModelPrice(input_price={self.input_price}, output_price={self.output_price}, currency={self.currency})"
+
 class OpenAICompatible:
+    volcengine_models = [
+        VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+        VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
+        VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
+        VOLCENGINE_MODEL_DEEPSEEK_V3
+    ]
+    deepseek_models = [
+        DEEPSEEK_CHAT_MODEL,
+    ]
+    openai_models = [
+        OPENAI_MODEL_GPT_4o_mini,
+        OPENAI_MODEL_GPT_4o
+    ]
+    openrouter_models = [
+        OPENROUTER_MODEL_CLAUDE_3_7_SONNET,
+    ]
+
+    model_prices = {
+        VOLCENGINE_MODEL_DEEPSEEK_V3: ModelPrice(input_price=2, output_price=8),
+        VOLCENGINE_MODEL_DOUBAO_PRO_32K: ModelPrice(input_price=0.8, output_price=2),
+        VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K: ModelPrice(input_price=0.8, output_price=2),
+        VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO: ModelPrice(input_price=3, output_price=9),
+        DEEPSEEK_CHAT_MODEL: ModelPrice(input_price=2, output_price=8),
+        OPENAI_MODEL_GPT_4o: ModelPrice(input_price=2.5, output_price=10, currency='USD'),
+        OPENAI_MODEL_GPT_4o_mini: ModelPrice(input_price=0.15, output_price=0.6, currency='USD'),
+        OPENROUTER_MODEL_CLAUDE_3_7_SONNET: ModelPrice(input_price=3, output_price=15, currency='USD'),
+    }
+
     @staticmethod
     def create_client(model_name, **kwargs) -> OpenAI:
-        volcengine_models = [
-            VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-            VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
-            VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
-            VOLCENGINE_MODEL_DEEPSEEK_V3
-        ]
-        deepseek_models = [
-            DEEPSEEK_CHAT_MODEL,
-        ]
-        openai_models = [
-            OPENAI_MODEL_GPT_4o_mini,
-            OPENAI_MODEL_GPT_4o
-        ]
-        openrouter_models = [
-            OPENROUTER_MODEL_CLAUDE_3_7_SONNET,
-        ]
-        if model_name in volcengine_models:
+        if model_name in OpenAICompatible.volcengine_models:
             llm_client = OpenAI(api_key=VOLCENGINE_API_TOKEN, base_url=VOLCENGINE_BASE_URL, **kwargs)
-        elif model_name in deepseek_models:
+        elif model_name in OpenAICompatible.deepseek_models:
             llm_client = OpenAI(api_key=DEEPSEEK_API_TOKEN, base_url=DEEPSEEK_BASE_URL, **kwargs)
-        elif model_name in openai_models:
+        elif model_name in OpenAICompatible.openai_models:
             socks_conf = configs.get().get('system', {}).get('outside_proxy', {}).get('socks5', {})
             if socks_conf:
                 http_client = httpx.Client(
@@ -74,12 +120,38 @@ class OpenAICompatible:
                 )
                 kwargs['http_client'] = http_client
             llm_client = OpenAI(api_key=OPENAI_API_TOKEN, base_url=OPENAI_BASE_URL, **kwargs)
-        elif model_name in openrouter_models:
+        elif model_name in OpenAICompatible.openrouter_models:
             llm_client = OpenAI(api_key=OPENROUTER_API_TOKEN, base_url=OPENROUTER_BASE_URL, **kwargs)
         else:
             raise Exception("Unsupported model: %s" % model_name)
         return llm_client
 
+    @staticmethod
+    def get_price(model_name: str) -> ModelPrice:
+        """
+        Get the price for a given model.
+        :param model_name: Name of the model
+        :return: ModelPrice object containing input and output prices
+        """
+        if model_name not in OpenAICompatible.model_prices:
+            raise ValueError(f"Model {model_name} not found in price list.")
+        return OpenAICompatible.model_prices[model_name]
+
+    @staticmethod
+    def calculate_cost(model_name: str, input_tokens: int, output_tokens: int, convert_to_cny: bool = True) -> float:
+        """
+        Calculate the cost for a given model based on input and output tokens.
+        :param model_name: Name of the model
+        :param input_tokens: Number of input tokens
+        :param output_tokens: Number of output tokens
+        :param convert_to_cny: Whether to convert the cost to CNY (default is True)
+        :return: Total cost in the model's currency
+        """
+        if model_name not in OpenAICompatible.model_prices:
+            raise ValueError(f"Model {model_name} not found in price list.")
+        price = OpenAICompatible.model_prices[model_name]
+        return price.get_total_cost(input_tokens, output_tokens, convert_to_cny)
+
 class CrossAccountJWTOAuthApp(JWTOAuthApp):
     def __init__(self, account_id: str, client_id: str, private_key: str, public_key_id: str, base_url):
         self.account_id = account_id
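
Note: for USD-priced models, ModelPrice converts to CNY by default via EXCHANGE_RATE_TO_CNY. A small worked example against the price table above (the token counts are illustrative):

    from pqai_agent.chat_service import OpenAICompatible, OPENAI_MODEL_GPT_4o

    price = OpenAICompatible.get_price(OPENAI_MODEL_GPT_4o)          # $2.5 in / $10 out per 1M tokens
    usd = price.get_total_cost(10_000, 2_000, convert_to_cny=False)  # 0.025 + 0.020 = 0.045 USD
    cny = price.get_total_cost(10_000, 2_000)                        # 0.045 * 7.2 = 0.324 CNY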

+ 1 - 1
pqai_agent/clients/relation_stage_client.py

@@ -2,7 +2,7 @@ from typing import Optional
 
 import requests
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 class RelationStageClient:
     UNKNOWN_RELATION_STAGE = '未知'

+ 11 - 6
pqai_agent/configs/dev.yaml

@@ -23,6 +23,9 @@ storage:
   staff:
     database: ai_agent
     table: qywx_employee
+  agent_user_relation:
+    database: ai_agent
+    table: qywx_employee_customer
   user_relation:
     database: growth
     table:
@@ -51,8 +54,8 @@ chat_api:
     private_key_path: oauth/coze_privkey.pem
     account_id: 649175100044793
   openai_compatible:
-    text_model: ep-20250414202859-6nkz5
-    multimodal_model: ep-20250421193334-nz5wd
+    text_model: doubao-pro-32k-241215
+    multimodal_model: doubao-1-5-vision-pro-32k-250115
 
 system:
   outside_proxy:
@@ -62,11 +65,12 @@ system:
   scheduler_mode: local
   human_intervention_alert_url: https://open.feishu.cn/open-apis/bot/v2/hook/379fcd1a-0fed-4e58-8cd0-40b6d1895721
   max_reply_workers: 2
-  max_push_workers: 1
-  chat_agent_version: 1
+  push_task_workers: 1
+  chat_agent_version: 2
+  log_dir: .
 
 debug_flags:
-  disable_llm_api_call: True
+  disable_llm_api_call: False
   use_local_user_storage: True
   console_input: True
   disable_active_conversation: True
@@ -82,4 +86,5 @@ mq:
   scheduler_topic: agent_scheduler_event_dev
   scheduler_group: agent_scheduler_event_dev
   push_tasks_topic: agent_push_tasks_dev
-  push_tasks_group: agent_push_tasks_dev
+  push_send_task_group: agent_push_tasks_dev
+  push_generate_task_group: agent_push_generate_task_dev

+ 7 - 4
pqai_agent/configs/prod.yaml

@@ -37,7 +37,7 @@ storage:
     table: qywx_chat_history
   push_record:
     database: ai_agent
-    table: agent_push_record_dev
+    table: agent_push_record
 
 chat_api:
   coze:
@@ -46,8 +46,8 @@ chat_api:
     private_key_path: oauth/coze_privkey.pem
     account_id: 649175100044793
   openai_compatible:
-    text_model: ep-20250414202859-6nkz5
-    multimodal_model: ep-20250421193334-nz5wd
+    text_model: doubao-pro-32k-241215
+    multimodal_model: doubao-1-5-vision-pro-32k-250115
 
 system:
   outside_proxy:
@@ -57,6 +57,8 @@ system:
   scheduler_mode: mq
   human_intervention_alert_url: https://open.feishu.cn/open-apis/bot/v2/hook/c316b559-1c6a-4c4e-97c9-50b44e4c2a9d
   max_reply_workers: 5
+  push_task_workers: 5
+  log_dir: /var/log/agent_service
 
 agent_behavior:
   message_aggregation_sec: 20
@@ -79,4 +81,5 @@ mq:
   scheduler_topic: agent_scheduler_event
   scheduler_group: agent_scheduler_event
   push_tasks_topic: agent_push_tasks
-  push_tasks_group: agent_push_tasks
+  push_send_task_group: agent_push_tasks
+  push_generate_task_group: agent_push_generate_task

+ 2 - 2
pqai_agent/data_models/agent_configuration.py

@@ -1,7 +1,7 @@
 from enum import Enum
 
 from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base
 
 Base = declarative_base()
 
@@ -27,4 +27,4 @@ class AgentConfiguration(Base):
     create_user = Column(String(32), nullable=True, comment="创建用户")
     update_user = Column(String(32), nullable=True, comment="更新用户")
     create_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", comment="创建时间")
-    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP", comment="更新时间")
+    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", server_onupdate="CURRENT_TIMESTAMP", comment="更新时间")

+ 2 - 2
pqai_agent/data_models/agent_push_record.py

@@ -1,12 +1,12 @@
 from sqlalchemy import Column, Integer, Text, BigInteger
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base
 
 from pqai_agent import configs
 
 Base = declarative_base()
 
 class AgentPushRecord(Base):
-    __tablename__ = configs.get().get('storage', {}).get('push_record', {}).get('table_name', 'agent_push_record_dev')
+    __tablename__ = configs.get().get('storage', {}).get('push_record', {}).get('table', 'agent_push_record_dev')
     id = Column(Integer, primary_key=True)
     staff_id = Column(Integer)
     user_id = Column(Integer)

+ 1 - 1
pqai_agent/data_models/service_module.py

@@ -1,7 +1,7 @@
 from enum import Enum
 
 from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base
 
 Base = declarative_base()
 

+ 1 - 1
pqai_agent/database.py

@@ -5,7 +5,7 @@
 # Copyright © 2024 StrayWarrior <i@straywarrior.com>
 
 import pymysql
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 class MySQLManager:
     def __init__(self, config):

+ 1 - 1
pqai_agent/dialogue_manager.py

@@ -16,7 +16,7 @@ from sqlalchemy.orm import sessionmaker, Session
 from pqai_agent import configs
 from pqai_agent.clients.relation_stage_client import RelationStageClient
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.database import MySQLManager
 from pqai_agent import chat_service, prompt_templates
 from pqai_agent.history_dialogue_service import HistoryDialogueService

+ 1 - 1
pqai_agent/history_dialogue_service.py

@@ -7,7 +7,7 @@ import requests
 from pymysql.cursors import DictCursor
 
 from pqai_agent.database import MySQLManager
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 import time
 
 from pqai_agent import configs

+ 9 - 7
pqai_agent/logging_service.py → pqai_agent/logging.py

@@ -26,25 +26,27 @@ class ColoredFormatter(logging.Formatter):
         return message
 
 def setup_root_logger(level=logging.DEBUG, logfile_name='service.log'):
-    formatter = ColoredFormatter(
-        '%(asctime)s - %(name)s %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
-    )
+    logging_format = '%(asctime)s - %(name)s %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
+    plain_formatter = logging.Formatter(logging_format)
+    color_formatter = ColoredFormatter(logging_format)
     console_handler = logging.StreamHandler()
     console_handler.setLevel(logging.DEBUG)
-    console_handler.setFormatter(formatter)
+    console_handler.setFormatter(color_formatter)
 
     root_logger = logging.getLogger()
     root_logger.handlers.clear()
     root_logger.addHandler(console_handler)
-    if configs.get_env() == 'prod':
+
+    log_dir = configs.get().get('system', {}).get('log_dir', '')
+    if log_dir:
         file_handler = RotatingFileHandler(
-            f'/var/log/agent_service/{logfile_name}',
+            f'{log_dir}/{logfile_name}',
             maxBytes=64 * 1024 * 1024,
             backupCount=5,
             encoding='utf-8'
         )
         file_handler.setLevel(logging.DEBUG)
-        file_handler.setFormatter(formatter)
+        file_handler.setFormatter(plain_formatter)
         root_logger.addHandler(file_handler)
 
     agent_logger = logging.getLogger('agent')
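
Note: file logging is now driven by system.log_dir in the config instead of being hard-wired to the prod environment. A minimal sketch, assuming a config with log_dir set (as in the dev.yaml change above):

    from pqai_agent.logging import setup_root_logger, logger

    setup_root_logger(level='DEBUG')   # console handler uses the colored formatter
    logger.info("hello")               # also written, plain-formatted, to {log_dir}/service.log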

+ 2 - 3
pqai_agent/message_queue_backend.py

@@ -9,8 +9,7 @@ import rocketmq
 from rocketmq import ClientConfiguration, Credentials, SimpleConsumer
 
 from pqai_agent import configs
-from pqai_agent import logging_service
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.mq_message import MqMessage, MessageType, MessageChannel
 
 
@@ -87,7 +86,7 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
             return None
         rmq_message = messages[0]
         body = rmq_message.body.decode('utf-8')
-        logger.debug("[{}]recv message body: {}".format(self.topic, body))
+        logger.debug(f"[{self.topic}]recv message, group[{rmq_message.message_group}], body: {body}")
         try:
             message = MqMessage.from_json(body)
             message._rmq_message = rmq_message

+ 1 - 0
pqai_agent/prompt_templates.py

@@ -279,6 +279,7 @@ RESPONSE_TYPE_DETECT_PROMPT = """
 * 默认使用文本
 * 如果用户明确提到使用语音形式,尽量选择语音
 * 用户自身偏向于使用语音形式沟通时,可选择语音
+* 如果用户不认字或有阅读障碍,且内容适合语音朗读,可选择语音
 * 注意分析即将发送的消息内容,如果有不适合使用语音朗读的内容,不要选择使用语音
 * 注意对话中包含的时间!注意时间流逝和情境切换!判断合适的回复方式!
 

+ 64 - 45
pqai_agent/push_service.py

@@ -16,7 +16,7 @@ from pqai_agent.abtest.utils import get_abtest_info
 from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit import get_tools
 from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
@@ -54,7 +54,10 @@ class PushScanThread:
         first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
         # 合并白名单,减少配置成本
         white_list_tags.update(first_initiate_tags)
-        for staff_user in self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id):
+        all_staff_users = self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id)
+        all_users = list({pair['user_id'] for pair in all_staff_users})
+        all_user_tags = self.service.user_manager.get_user_tags(all_users)
+        for staff_user in all_staff_users:
             staff_id = staff_user['staff_id']
             user_id = staff_user['user_id']
             # 通过AB实验配置控制用户组是否启用push
@@ -62,8 +65,8 @@ class PushScanThread:
             # if abtest_params.get('agent_push_enabled', 'false').lower() != 'true':
             #     logger.debug(f"User {user_id} not enabled agent push, skipping.")
             #     continue
-            user_tags = self.service.user_relation_manager.get_user_tags(user_id)
-            if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
+            user_tags = all_user_tags.get(user_id, list())
+            if not white_list_tags.intersection(user_tags):
                 should_initiate = False
             else:
                 agent = self.service.get_agent_instance(staff_id, user_id)
@@ -78,66 +81,79 @@ class PushScanThread:
 
 class PushTaskWorkerPool:
     def __init__(self, agent_service: 'AgentService', mq_topic: str,
-                 mq_consumer: rocketmq.SimpleConsumer, mq_producer: rocketmq.Producer):
+                 mq_consumer_generate: rocketmq.SimpleConsumer,
+                 mq_consumer_send: rocketmq.SimpleConsumer,
+                 mq_producer: rocketmq.Producer):
         self.agent_service = agent_service
         max_workers = configs.get()['system'].get('push_task_workers', 5)
         self.max_push_workers = max_workers
         self.generate_executor = ThreadPoolExecutor(max_workers=max_workers)
         self.send_executors = {}
         self.rmq_topic = mq_topic
-        self.consumer = mq_consumer
+        self.generate_consumer = mq_consumer_generate
+        self.send_consumer = mq_consumer_send
         self.producer = mq_producer
-        self.loop_thread = None
+        self.generate_loop_thread = None
+        self.send_loop_thread = None
         self.is_generator_running = True
-        self.generate_send_done = False # set by wait_to_finish
-        self.no_more_generate_task = False # set by self
+        self.generate_send_done = False # set by wait_to_finish,表示所有生成任务均已进入队列
+        self.no_more_generate_task = False # generate_send_done被设置之后队列中未再收到生成任务时设置
 
     def start(self):
-        self.loop_thread = Thread(target=self.process_push_tasks)
-        self.loop_thread.start()
+        self.send_loop_thread = Thread(target=self.process_send_tasks)
+        self.send_loop_thread.start()
+        self.generate_loop_thread = Thread(target=self.process_generate_tasks)
+        self.generate_loop_thread.start()
 
-    def process_push_tasks(self):
-        # RMQ consumer疑似有bug,创建后立即消费可能报NPE
+    def process_send_tasks(self):
         time.sleep(1)
         while True:
-            msgs = self.consumer.receive(1, 300)
+            msgs = self.send_consumer.receive(1, 60)
             if not msgs:
                 # 没有生成任务在执行且没有消息,才可退出
-                if self.generate_send_done:
-                    if not self.no_more_generate_task:
-                        logger.debug("no message received, there should be no more generate task")
-                        self.no_more_generate_task = True
-                        continue
-                    else:
-                        if self.is_generator_running:
-                            logger.debug("Waiting for generator threads to finish")
-                            continue
-                        else:
-                            break
+                if self.no_more_generate_task and not self.is_generator_running:
+                    break
                 else:
                     continue
             msg = msgs[0]
             task = json.loads(msg.body.decode('utf-8'))
             msg_time = datetime.fromtimestamp(task['timestamp'] / 1000).strftime("%Y-%m-%d %H:%M:%S")
             logger.debug(f"recv message:{msg_time} - {task}")
-            if task['task_type'] == TaskType.GENERATE.value:
-                # FIXME: 临时方案,避免消息在消费后等待超时并重复消费
-                if self.generate_executor._work_queue.qsize() > self.max_push_workers * 5:
-                    logger.warning("Too many generate tasks in queue, consume this task later")
-                    while self.generate_executor._work_queue.qsize() > self.max_push_workers * 5:
-                        time.sleep(10)
-                    # do not submit and ack this message
-                    continue
-                self.generate_executor.submit(self.handle_generate_task, task, msg)
-            elif task['task_type'] == TaskType.SEND.value:
+            if task['task_type'] == TaskType.SEND.value:
                 staff_id = task['staff_id']
                 if staff_id not in self.send_executors:
                     self.send_executors[staff_id] = ThreadPoolExecutor(max_workers=1)
                 self.send_executors[staff_id].submit(self.handle_send_task, task, msg)
             else:
                 logger.error(f"Unknown task type: {task['task_type']}")
-                self.consumer.ack(msg)
-        logger.info("PushGenerateWorkerPool stopped")
+                self.send_consumer.ack(msg)
+        logger.info("PushGenerateWorkerPool send thread stopped")
+
+    def process_generate_tasks(self):
+        time.sleep(1)
+        while True:
+            if self.generate_executor._work_queue.qsize() > self.max_push_workers * 2:
+                logger.warning("Too many generate tasks in queue, consume later")
+                time.sleep(10)
+                continue
+            msgs = self.generate_consumer.receive(1, 300)
+            if not msgs:
+                # 生成任务已经创建完成 且 未收到新任务,才可退出
+                if self.generate_send_done:
+                    logger.debug("no message received, there should be no more generate task")
+                    self.no_more_generate_task = True
+                    break
+                else:
+                    continue
+            msg = msgs[0]
+            task = json.loads(msg.body.decode('utf-8'))
+            msg_time = datetime.fromtimestamp(task['timestamp'] / 1000).strftime("%Y-%m-%d %H:%M:%S")
+            logger.debug(f"recv message:{msg_time} - {task}")
+            if task['task_type'] == TaskType.GENERATE.value:
+                self.generate_executor.submit(self.handle_generate_task, task, msg)
+            else:
+                self.generate_consumer.ack(msg)
+        logger.info("PushGenerateWorkerPool generator thread stopped")
 
     def wait_to_finish(self):
         self.generate_send_done = True
@@ -146,7 +162,8 @@ class PushTaskWorkerPool:
             time.sleep(1)
         self.generate_executor.shutdown(wait=True)
         self.is_generator_running = False
-        self.loop_thread.join()
+        self.generate_loop_thread.join()
+        self.send_loop_thread.join()
 
     def handle_send_task(self, task: Dict, msg: rocketmq.Message):
         try:
@@ -155,13 +172,13 @@ class PushTaskWorkerPool:
             agent = self.agent_service.get_agent_instance(staff_id, user_id)
             # 二次校验是否需要发送
             if not agent.should_initiate_conversation():
-                logger.debug(f"user[{user_id}], do not initiate conversation")
-                self.consumer.ack(msg)
+                logger.debug(f"user[{user_id}], should not initiate, skip sending task")
+                self.send_consumer.ack(msg)
                 return
             contents: List[Dict] = json.loads(task['content'])
             if not contents:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: empty content, do not send")
-                self.consumer.ack(msg)
+                self.send_consumer.ack(msg)
                 return
             recent_dialogue = agent.dialogue_history[-10:]
             agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
@@ -197,11 +214,11 @@ class PushTaskWorkerPool:
                 agent.update_last_active_interaction_time(current_ts)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: generate empty response")
-            self.consumer.ack(msg)
+            self.send_consumer.ack(msg)
         except Exception as e:
             fmt_exc = traceback.format_exc()
             logger.error(f"Error processing message sending: {e}, {fmt_exc}")
-            self.consumer.ack(msg)
+            self.send_consumer.ack(msg)
 
     def handle_generate_task(self, task: Dict, msg: rocketmq.Message):
         try:
@@ -231,15 +248,17 @@ class PushTaskWorkerPool:
                 ),
                 query_prompt_template=query_prompt_template
             )
+            cost = push_agent.get_total_cost()
+            logger.debug(f"staff[{staff_id}], user[{user_id}]: push message generation cost: {cost}")
             if message_to_user:
                 rmq_message = generate_task_rmq_message(
                     self.rmq_topic, staff_id, user_id, TaskType.SEND, json.dumps(message_to_user))
                 self.producer.send(rmq_message)
             else:
                 logger.info(f"staff[{staff_id}], user[{user_id}]: no push message generated")
-            self.consumer.ack(msg)
+            self.generate_consumer.ack(msg)
         except Exception as e:
             fmt_exc = traceback.format_exc()
             logger.error(f"Error processing message generation: {e}, {fmt_exc}")
             # FIXME: 是否需要ACK
-            self.consumer.ack(msg)
+            self.generate_consumer.ack(msg)
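
Note: the generate loop now checks executor backpressure before calling receive, instead of after a task has already been pulled off the queue. A self-contained sketch of that check (the worker count and the sleep stand-in are hypothetical):

    import time
    from concurrent.futures import ThreadPoolExecutor

    max_push_workers = 2
    generate_executor = ThreadPoolExecutor(max_workers=max_push_workers)
    for _ in range(10):
        generate_executor.submit(time.sleep, 0.5)     # stand-in for handle_generate_task
    # same condition as in process_generate_tasks: stop receiving from MQ while
    # the local queue already holds more than max_push_workers * 2 pending tasks
    if generate_executor._work_queue.qsize() > max_push_workers * 2:
        print("too many generate tasks queued; sleep and retry before the next receive")
    generate_executor.shutdown(wait=True)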

+ 1 - 1
pqai_agent/rate_limiter.py

@@ -5,7 +5,7 @@
 import time
 from typing import Optional, Union, Dict
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.mq_message import MessageType
 
 

+ 2 - 2
pqai_agent/response_type_detector.py

@@ -12,7 +12,7 @@ from pqai_agent import chat_service
 from pqai_agent import configs
 from pqai_agent import prompt_templates
 from pqai_agent.dialogue_manager import DialogueManager
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.mq_message import MessageType
 
 
@@ -36,7 +36,7 @@ class ResponseTypeDetector:
             api_key=chat_service.VOLCENGINE_API_TOKEN,
             base_url=chat_service.VOLCENGINE_BASE_URL
         )
-        self.model_name = chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
+        self.model_name = chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K
 
     def detect_type(self, dialogue_history: List[Dict], next_message: Dict, enable_random=False,
                     random_rate=0.25):

+ 1 - 1
pqai_agent/service_module_manager.py

@@ -1,5 +1,5 @@
 from pqai_agent.data_models.service_module import ServiceModule, ModuleAgentType
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 class ServiceModuleManager:
     def __init__(self, session_maker):

+ 1 - 1
pqai_agent/toolkit/__init__.py

@@ -1,7 +1,7 @@
 # 必须要在这里导入模块,以便对应的模块执行register_toolkit
 from typing import Sequence, List
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.tool_registry import ToolRegistry
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier

+ 1 - 1
pqai_agent/toolkit/function_tool.py

@@ -8,7 +8,7 @@ from pydantic import BaseModel, create_model
 from pydantic.fields import FieldInfo
 from jsonschema.validators import Draft202012Validator as JSONValidator
 import re
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 
 def to_pascal(snake: str) -> str:

+ 1 - 1
pqai_agent/toolkit/image_describer.py

@@ -3,7 +3,7 @@ import threading
 
 from pqai_agent import chat_service
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.tool_registry import register_toolkit

+ 1 - 1
pqai_agent/toolkit/lark_sheet_record_for_human_intervention.py

@@ -4,7 +4,7 @@ from typing import List
 import requests
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 
 class LarkSheetRecordForHumanIntervention(BaseToolkit):
     r"""A toolkit for recording human intervention events into a Feishu spreadsheet."""

+ 1 - 1
pqai_agent/toolkit/message_notifier.py

@@ -1,6 +1,6 @@
 from typing import List, Dict
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.tool_registry import register_toolkit

+ 51 - 3
pqai_agent/user_manager.py

@@ -1,8 +1,9 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
+from abc import abstractmethod
 
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from typing import Dict, Optional, List
 import json
 import time
@@ -33,6 +34,10 @@ class UserManager(abc.ABC):
         #FIXME(zhoutian): 重新设计用户和员工数据管理模型
         pass
 
+    @abstractmethod
+    def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
+        pass
+
     @staticmethod
     def get_default_profile(**kwargs) -> Dict:
         default_profile = {
@@ -133,6 +138,9 @@ class LocalUserManager(UserManager):
             logger.error("staff profile not found: {}".format(e))
             return {}
 
+    def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
+        return {}
+
     def list_users(self, **kwargs) -> List[Dict]:
         pass
 
@@ -249,6 +257,37 @@ class MySQLUserManager(UserManager):
         sql = f"UPDATE {self.staff_table} SET agent_profile = %s WHERE third_party_user_id = '{staff_id}'"
         self.db.execute(sql, (json.dumps(profile),))
 
+    def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
+        """
+        获取用户的标签列表
+        :param user_ids: 用户ID
+        :param batch_size: 批量查询的大小
+        :return: 标签名称列表
+        """
+        batches = (len(user_ids) + batch_size - 1) // batch_size
+        ret = {}
+        for i in range(batches):
+            idx_begin = i * batch_size
+            idx_end = min((i + 1) * batch_size, len(user_ids))
+            batch_user_ids = user_ids[idx_begin:idx_end]
+            sql = f"""
+                SELECT a.third_party_user_id, b.tag_id, b.name FROM qywx_user_tag a
+                    JOIN qywx_tag b ON a.tag_id = b.tag_id
+                """
+            if len(batch_user_ids) == 1:
+                sql += f" AND a.third_party_user_id = '{batch_user_ids[0]}'"""
+            else:
+                sql += f" AND a.third_party_user_id IN {str(tuple(batch_user_ids))}"
+            rows = self.db.select(sql, pymysql.cursors.DictCursor)
+            # group by user_id
+            for row in rows:
+                user_id = row['third_party_user_id']
+                tag_name = row['name']
+                if user_id not in ret:
+                    ret[user_id] = []
+                ret[user_id].append(tag_name)
+        return ret
+
     def list_users(self, **kwargs) -> List[Dict]:
         user_union_id = kwargs.get('user_union_id', None)
         user_name = kwargs.get('user_name', None)
@@ -333,6 +372,7 @@ class MySQLUserRelationManager(UserRelationManager):
         self.relation_table = relation_table
         self.agent_user_table = agent_user_table
         self.user_table = user_table
+        self.agent_user_relation_table = 'qywx_employee_customer'
 
     def list_staffs(self):
         sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
@@ -342,7 +382,14 @@ class MySQLUserRelationManager(UserRelationManager):
     def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
         return []
 
-    def list_staff_users(self, staff_id: str = None, tag_id: int = None):
+    def list_staff_users(self, staff_id: str = None, tag_id: int = None) -> List[Dict]:
+        sql = f"SELECT employee_id as staff_id, customer_id as user_id FROM {self.agent_user_relation_table} WHERE 1 = 1"
+        if staff_id:
+            sql += f" AND employee_id = '{staff_id}'"
+        agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
+        return agent_staff_data
+
+    def list_staff_users_v1(self, staff_id: str = None, tag_id: int = None):
         sql = f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
         if staff_id:
             sql += f" AND third_party_user_id = '{staff_id}'"
@@ -382,7 +429,8 @@ class MySQLUserRelationManager(UserRelationManager):
                 sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
                 batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
                 if len(agent_user_data) != len(batch_union_ids):
-                    # logger.debug(f"staff[{wxid}] some users not found in agent database")
+                    diff_num = len(batch_union_ids) - len(batch_agent_user_data)
+                    logger.debug(f"staff[{staff_id}] {diff_num} users not found in agent database")
                     pass
                 agent_user_data.extend(batch_agent_user_data)
             staff_user_pairs = [
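
Note: get_user_tags batches the tag lookup so a long user_id list never becomes one giant IN clause. The batching arithmetic in isolation (the ids are fabricated):

    user_ids = [f"u{i}" for i in range(1, 1202)]                   # 1201 hypothetical ids
    batch_size = 500
    batches = (len(user_ids) + batch_size - 1) // batch_size       # ceil(1201 / 500) = 3
    for i in range(batches):
        batch = user_ids[i * batch_size : min((i + 1) * batch_size, len(user_ids))]
        print(len(batch))                                          # 500, 500, 201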

+ 3 - 3
pqai_agent/user_profile_extractor.py

@@ -8,7 +8,7 @@ from typing import Dict, Optional, List
 from pqai_agent import chat_service, configs
 from pqai_agent.prompt_templates import USER_PROFILE_EXTRACT_PROMPT, USER_PROFILE_EXTRACT_PROMPT_V2
 from openai import OpenAI
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger
 from pqai_agent.utils import prompt_utils
 
 
@@ -198,8 +198,8 @@ class UserProfileExtractor:
 
 if __name__ == '__main__':
     from pqai_agent import configs
-    from pqai_agent import logging_service
-    logging_service.setup_root_logger()
+    from pqai_agent import logging
+    logging.setup_root_logger()
     config = configs.get()
     config['debug_flags']['disable_llm_api_call'] = False
     extractor = UserProfileExtractor()

+ 8 - 0
pqai_agent/utils/__init__.py

@@ -0,0 +1,8 @@
+import hashlib
+import random
+import time
+
+def random_str() -> str:
+    """生成一个随机的MD5字符串"""
+    random_string = str(random.randint(0, 1000000)) + str(time.time())
+    return hashlib.md5(random_string.encode('utf-8')).hexdigest()
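
Note: random_str returns a 32-character hex digest; simple_chat_agent.py truncates it to 12 characters for the per-run log tag. A quick usage sketch:

    import pqai_agent.utils

    run_id = pqai_agent.utils.random_str()[:12]   # e.g. 'a3f1c09b2e4d' (value is random)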

+ 7 - 4
pqai_agent_server/agent_server.py

@@ -2,10 +2,10 @@ import logging
 import sys
 import time
 
-from pqai_agent import configs, logging_service
+from pqai_agent import configs
 from pqai_agent.agent_service import AgentService
 from pqai_agent.chat_service import ChatServiceType
-from pqai_agent.logging_service import logger
+from pqai_agent.logging import logger, setup_root_logger
 from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend, MemoryQueueBackend
 from pqai_agent.push_service import PushTaskWorkerPool, PushScanThread
@@ -14,7 +14,7 @@ from pqai_agent.user_manager import LocalUserManager, LocalUserRelationManager,
 
 if __name__ == "__main__":
     config = configs.get()
-    logging_service.setup_root_logger()
+    setup_root_logger()
     logger.warning("current env: {}".format(configs.get_env()))
     scheduler_logger = logging.getLogger('apscheduler')
     scheduler_logger.setLevel(logging.WARNING)
@@ -92,13 +92,16 @@ if __name__ == "__main__":
             continue
         message_id += 1
         sender = '7881301903997433'
-        receiver = '1688855931724582'
+        receiver = '1688854974625870'
         if text in (MessageType.AGGREGATION_TRIGGER.name,
                     MessageType.HUMAN_INTERVENTION_END.name):
             message = MqMessage.build(
                 MessageType.__members__.get(text),
                 MessageChannel.CORP_WECHAT,
                 sender, receiver, None, int(time.time() * 1000))
+        elif text == 'S_PUSH':
+            service._check_initiative_conversations()
+            continue
         else:
             message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
                                       sender, receiver, text, int(time.time() * 1000)

+ 140 - 57
pqai_agent_server/api_server.py

@@ -1,6 +1,8 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
+import json
+import time
 import logging
 from argparse import ArgumentParser
 
@@ -9,6 +11,10 @@ from flask import Flask, request, jsonify
 from sqlalchemy.orm import sessionmaker
 
 from pqai_agent import configs
+
+from pqai_agent import chat_service, prompt_templates
+from pqai_agent.logging import logger, setup_root_logger
+from pqai_agent.toolkit import global_tool_map
 from pqai_agent import logging_service, chat_service, prompt_templates
 from pqai_agent.data_models.agent_configuration import AgentConfiguration
 from pqai_agent.data_models.service_module import ServiceModule
@@ -21,6 +27,8 @@ from pqai_agent_server.const.status_enum import TestTaskStatus
 from pqai_agent_server.const.type_enum import EvaluateType
 from pqai_agent_server.dataset_service import DatasetService
 from pqai_agent_server.models import MySQLSessionManager
+import pqai_agent_server.utils
+from pqai_agent_server.utils import wrap_response
 from pqai_agent_server.task_server import TaskManager
 from pqai_agent_server.utils import (
     run_extractor_prompt,
@@ -30,7 +38,6 @@ from pqai_agent_server.utils import (
 from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
 
 app = Flask('agent_api_server')
-logger = logging_service.logger
 const = AgentApiConst()
 
 
@@ -93,34 +100,23 @@ def get_dialogue_history():
 
 @app.route('/api/listModels', methods=['GET'])
 def list_models():
-    models = [
-        {
-            'model_type': 'openai_compatible',
-            'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
-            'display_name': 'DeepSeek V3 on 火山'
-        },
-        {
-            'model_type': 'openai_compatible',
-            'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-            'display_name': '豆包Pro 32K'
-        },
-        {
-            'model_type': 'openai_compatible',
-            'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
-            'display_name': '豆包Pro 1.5'
-        },
-        {
-            'model_type': 'openai_compatible',
-            'model_name': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
-            'display_name': 'DeepSeek V3联网 on 火山'
-        },
+    models = {
+        "deepseek-chat": chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
+        "gpt-4o": chat_service.OPENAI_MODEL_GPT_4o,
+        "gpt-4o-mini": chat_service.OPENAI_MODEL_GPT_4o_mini,
+        "doubao-pro-32k": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+        "doubao-pro-1.5": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
+        "doubao-1.5-vision-pro": chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
+    }
+    ret_data = [
         {
             'model_type': 'openai_compatible',
-            'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
-            'display_name': '豆包1.5视觉理解Pro'
-        },
+            'model_name': model_name,
+            'display_name': model_display_name
+        }
+        for model_display_name, model_name in models.items()
     ]
-    return wrap_response(200, data=models)
+    return wrap_response(200, data=ret_data)
 
 
 @app.route('/api/listScenes', methods=['GET'])
@@ -148,8 +144,8 @@ def get_base_prompt():
     model_map = {
         'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
         'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-        'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
-        'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+        'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
+        'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
         'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
     }
     if scene not in prompt_map:
@@ -180,7 +176,6 @@ def run_prompt():
         logger.error(e)
         return wrap_response(500, msg='Error: {}'.format(e))
 
-
 @app.route('/api/formatForPrompt', methods=['POST'])
 def format_data_for_prompt():
     try:
@@ -300,8 +295,8 @@ def send_message():
     return wrap_response(200, msg="暂不实现功能")
 
 
-@app.route("/api/quitHumanInterventionStatus", methods=["POST"])
-def quit_human_interventions_status():
+@app.route("/api/quitHumanIntervention", methods=["POST"])
+def quit_human_intervention():
     """
     Exit the human-intervention state
     :return:
@@ -311,10 +306,27 @@ def quit_human_interventions_status():
     user_id = req_data["user_id"]
     if not user_id or not staff_id:
         return wrap_response(404, msg="user_id and staff_id are required")
-    response = quit_human_intervention_status(user_id, staff_id)
+    if pqai_agent_server.utils.common.quit_human_intervention(user_id, staff_id):
+        return wrap_response(200, msg="success")
+    else:
+        return wrap_response(500, msg="error")
 
-    return wrap_response(200, data=response)
 
+@app.route("/api/enterHumanIntervention", methods=["POST"])
+def enter_human_intervention():
+    """
+    Enter the human-intervention state
+    :return:
+    """
+    req_data = request.json
+    staff_id = req_data["staff_id"]
+    user_id = req_data["user_id"]
+    if not user_id or not staff_id:
+        return wrap_response(404, msg="user_id and staff_id are required")
+    if pqai_agent_server.utils.common.enter_human_intervention(user_id, staff_id):
+        return wrap_response(200, msg="success")
+    else:
+        return wrap_response(500, msg="error")
 
 ## Agent management endpoints
 @app.route("/api/getNativeAgentList", methods=["GET"])
@@ -336,23 +348,28 @@ def get_native_agent_list():
             query = query.filter(AgentConfiguration.create_user == create_user)
         if update_user:
             query = query.filter(AgentConfiguration.update_user == update_user)
+        total = query.count()
         query = query.offset(offset).limit(int(page_size))
         data = query.all()
-    ret_data = [
-        {
-            'id': agent.id,
-            'name': agent.name,
-            'display_name': agent.display_name,
-            'type': agent.type,
-            'execution_model': agent.execution_model,
-            'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
-            'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
-        }
-        for agent in data
-    ]
+    ret_data = {
+        'total': total,
+        'agent_list': [
+            {
+                'id': agent.id,
+                'name': agent.name,
+                'display_name': agent.display_name,
+                'type': agent.type,
+                'execution_model': agent.execution_model,
+                'create_user': agent.create_user,
+                'update_user': agent.update_user,
+                'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+                'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
+            }
+            for agent in data
+        ]
+    }
     return wrap_response(200, data=ret_data)
 
-
 @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
 def get_native_agent_configuration():
     """
@@ -376,15 +393,14 @@ def get_native_agent_configuration():
             'execution_model': agent.execution_model,
             'system_prompt': agent.system_prompt,
             'task_prompt': agent.task_prompt,
-            'tools': agent.tools,
-            'sub_agents': agent.sub_agents,
-            'extra_params': agent.extra_params,
+            'tools': json.loads(agent.tools),
+            'sub_agents': json.loads(agent.sub_agents),
+            'extra_params': json.loads(agent.extra_params),
             'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
             'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
         }
         return wrap_response(200, data=data)
 
-
 @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
 def save_native_agent_configuration():
     """
@@ -399,9 +415,19 @@ def save_native_agent_configuration():
     execution_model = req_data.get('execution_model', None)
     system_prompt = req_data.get('system_prompt', None)
     task_prompt = req_data.get('task_prompt', None)
-    tools = req_data.get('tools', [])
-    sub_agents = req_data.get('sub_agents', [])
+    tools = json.dumps(req_data.get('tools', []))
+    sub_agents = json.dumps(req_data.get('sub_agents', []))
     extra_params = req_data.get('extra_params', {})
+    operate_user = req_data.get('user', None)
+    if isinstance(extra_params, dict):
+        extra_params = json.dumps(extra_params)
+    elif isinstance(extra_params, str):
+        try:
+            json.loads(extra_params)
+        except json.JSONDecodeError:
+            return wrap_response(400, msg='extra_params should be a valid JSON object or string')
+    if not extra_params:
+        extra_params = '{}'
 
     if not name:
         return wrap_response(400, msg='name is required')
@@ -421,6 +447,7 @@ def save_native_agent_configuration():
             agent.tools = tools
             agent.sub_agents = sub_agents
             agent.extra_params = extra_params
+            agent.update_user = operate_user
         else:
             agent = AgentConfiguration(
                 name=name,
@@ -431,7 +458,9 @@ def save_native_agent_configuration():
                 task_prompt=task_prompt,
                 tools=tools,
                 sub_agents=sub_agents,
-                extra_params=extra_params
+                extra_params=extra_params,
+                create_user=operate_user,
+                update_user=operate_user
             )
             session.add(agent)
 
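
A hypothetical request body for /api/saveNativeAgentConfiguration, illustrating the new handling: tools and sub_agents are serialized with json.dumps before storage, extra_params is accepted either as a JSON object or a JSON-encoded string, and the optional user field is written to create_user/update_user. All values below are made up:

    # Illustrative payload only; names and values are not taken from any real configuration.
    payload = {
        "name": "demo_push_agent",            # required
        "display_name": "Demo Push Agent",
        "type": 0,
        "execution_model": "some-model-id",
        "system_prompt": "You are ...",
        "task_prompt": "Decide whether to push ...",
        "tools": ["image_describer"],          # stored as JSON text via json.dumps
        "sub_agents": [],
        "extra_params": {"max_steps": 5},      # a dict or '{"max_steps": 5}' both pass validation
        "user": "alice",                       # recorded as create_user / update_user
    }
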
@@ -439,6 +468,35 @@ def save_native_agent_configuration():
         return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
 
 
+
+@app.route("/api/deleteNativeAgentConfiguration", methods=["POST"])
+def delete_native_agent_configuration():
+    """
+    Soft-delete the specified Agent configuration (sets is_delete=1)
+    :return:
+    """
+    req_data = request.json
+    agent_id = req_data.get('agent_id', None)
+    if not agent_id:
+        return wrap_response(400, msg='agent_id is required')
+    try:
+        agent_id = int(agent_id)
+    except ValueError:
+        return wrap_response(400, msg='agent_id must be an integer')
+
+    with app.session_maker() as session:
+        agent = session.query(AgentConfiguration).filter(
+            AgentConfiguration.id == agent_id,
+            AgentConfiguration.is_delete == 0
+        ).first()
+        if not agent:
+            return wrap_response(404, msg='Agent not found')
+        agent.is_delete = 1
+        session.commit()
+        return wrap_response(200, msg='Agent configuration deleted successfully')
+
+
+
 @app.route("/api/getModuleList", methods=["GET"])
 def get_module_list():
     """
@@ -463,7 +521,6 @@ def get_module_list():
     ]
     return wrap_response(200, data=ret_data)
 
-
 @app.route("/api/getModuleConfiguration", methods=["GET"])
 def get_module_configuration():
     """
@@ -490,7 +547,6 @@ def get_module_configuration():
         }
         return wrap_response(200, data=data)
 
-
 @app.route("/api/saveModuleConfiguration", methods=["POST"])
 def save_module_configuration():
     """
@@ -673,6 +729,33 @@ def get_conversation_data_list():
     return wrap_response(200, data=response)
 
 
+@app.route("/api/getToolList", methods=["GET"])
+def get_tool_list():
+    """
+    Return the list of all available tools
+    :return:
+    """
+    tools = []
+    for tool_name, tool in global_tool_map.items():
+        tools.append({
+            'name': tool_name,
+            'description': tool.get_function_description(),
+            'parameters': tool.parameters if hasattr(tool, 'parameters') else {}
+        })
+    return wrap_response(200, data=tools)
+
+@app.route("/api/getModuleAgentTypes", methods=["GET"])
+def get_agent_types():
+    """
+    Return all available Agent types
+    :return:
+    """
+    agent_types = [
+        {'type': 0, 'display_name': '原生'},
+        {'type': 1, 'display_name': 'Coze'}
+    ]
+    return wrap_response(200, data=agent_types)
+
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 def handle_bad_request(e):
     logger.error(e)
@@ -689,7 +772,7 @@ if __name__ == '__main__':
 
     config = configs.get()
     logging_level = logging.getLevelName(args.log_level)
-    logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
+    setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
 
     # set db config
     agent_db_config = config['database']['ai_agent']
@@ -700,7 +783,7 @@ if __name__ == '__main__':
     chat_history_db_config = config['storage']['chat_history']
 
     # init user manager
-    user_manager = MySQLUserManager(agent_db_config, growth_db_config, staff_db_config['table'])
+    user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
     app.user_manager = user_manager
 
     # init session manager

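With the pagination change to getNativeAgentList above, the response data becomes an object carrying the total row count next to the requested page; a rough sketch of the shape, with illustrative values only:

    # Illustrative response data for /api/getNativeAgentList?page=1&page_size=20
    data = {
        "total": 42,                 # query.count() taken before offset/limit
        "agent_list": [
            {
                "id": 1,
                "name": "demo_agent",
                "display_name": "Demo Agent",
                "type": 0,
                "execution_model": "some-model-id",
                "create_user": "alice",
                "update_user": "bob",
                "create_time": "2025-06-20 10:00:00",
                "update_time": "2025-06-21 09:30:00",
            },
        ],
    }
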
+ 0 - 1
pqai_agent_server/utils/__init__.py

@@ -1,5 +1,4 @@
 from .common import wrap_response
-from .common import quit_human_intervention_status
 
 from .prompt_util import (
     run_openai_chat,

+ 14 - 3
pqai_agent_server/utils/common.py

@@ -11,7 +11,18 @@ def wrap_response(code, msg=None, data=None):
     return jsonify(resp)
 
 
-def quit_human_intervention_status(user_id, staff_id):
+def quit_human_intervention(user_id, staff_id) -> bool:
     url = f"http://ai-wechat-hook-internal.piaoquantv.com/manage/insertEvent?sender={user_id}&receiver={staff_id}&type=103&content=SYSTEM"
-    response = requests.get(url, timeout=20)
-    return response.json()
+    response = requests.post(url, timeout=20, headers={"Content-Type": "application/json"})
+    if response.status_code == 200 and response.json().get("code") == 0:
+        return True
+    else:
+        return False
+
+def enter_human_intervention(user_id, staff_id) -> bool:
+    url = f"http://ai-wechat-hook-internal.piaoquantv.com/manage/insertEvent?sender={user_id}&receiver={staff_id}&type=104&content=SYSTEM"
+    response = requests.post(url, timeout=20, headers={"Content-Type": "application/json"})
+    if response.status_code == 200 and response.json().get("code") == 0:
+        return True
+    else:
+        return False

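A minimal usage sketch for the two helpers; the event types (103/104) and hook endpoint come from the code above, while the user/staff IDs are placeholders:

    from pqai_agent_server.utils.common import enter_human_intervention, quit_human_intervention

    # Both helpers POST to the wechat-hook insertEvent endpoint and report success as a bool.
    if enter_human_intervention(user_id="wxid_demo_user", staff_id="staff_demo"):
        print("conversation handed over to a human")

    if quit_human_intervention(user_id="wxid_demo_user", staff_id="staff_demo"):
        print("conversation handed back to the agent")
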
+ 3 - 4
pqai_agent_server/utils/prompt_util.py

@@ -5,15 +5,14 @@ from typing import List, Dict
 
 from openai import OpenAI
 
-from pqai_agent import logging_service, chat_service
+from pqai_agent import chat_service
+from pqai_agent.logging import logger
 from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.user_profile_extractor import UserProfileExtractor
 from pqai_agent.dialogue_manager import DialogueManager
 from pqai_agent.mq_message import MessageType
 from pqai_agent.utils.prompt_utils import format_agent_profile
 
-logger = logging_service.logger
-
 
 def compose_openai_chat_messages_no_time(dialogue_history, multimodal=False):
     messages = []
@@ -44,7 +43,7 @@ def compose_openai_chat_messages_no_time(dialogue_history, multimodal=False):
 def create_llm_client(model_name):
     volcengine_models = [
         chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
         chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
         chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
     ]

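Callers pick up the renamed constant when building a client; a sketch under the assumption that create_llm_client returns an OpenAI-compatible client routed to Volcengine for models listed in volcengine_models:

    from pqai_agent import chat_service
    from pqai_agent_server.utils.prompt_util import create_llm_client

    # Selects the Volcengine route because the constant appears in volcengine_models.
    client = create_llm_client(chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K)
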
+ 43 - 0
scripts/extract_push_action_logs.py

@@ -0,0 +1,43 @@
+import re
+
+def extract_agent_run_steps(log_path):
+    pattern = re.compile(
+        r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - agent run\[\d+\] - DEBUG - current step content:'
+    )
+    timestamp_pattern = re.compile(r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - ')
+    results = []
+    current = []
+    collecting = False
+
+    with open(log_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            if pattern.match(line):
+                if collecting and current:
+                    results.append(''.join(current).rstrip())
+                    current = []
+                collecting = True
+                current.append(line)
+            elif collecting:
+                if timestamp_pattern.match(line):
+                    results.append(''.join(current).rstrip())
+                    current = []
+                    collecting = False
+                else:
+                    current.append(line)
+        # flush the final collected block when the file ends
+        if collecting and current:
+            results.append(''.join(current).rstrip())
+    return results
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) != 2:
+        print("Usage: python extract_agent_run_step.py <logfile>")
+        sys.exit(1)
+    log_file = sys.argv[1]
+    steps = extract_agent_run_steps(log_file)
+    for i, step in enumerate(steps, 1):
+        print(f"--- Step {i} ---")
+        print(step)
+        print()
+

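The extractor keys off DEBUG lines of the form "... - agent run[N] - DEBUG - current step content:" and folds every following non-timestamped line into the same step. A quick self-check with an invented sample log:

    # Write a tiny sample log whose first line matches the step-start pattern.
    sample = (
        "2025-06-20 10:15:30,123 - agent run[42] - DEBUG - current step content:\n"
        "step 1: inspect user profile\n"
        "step 2: decide whether to push\n"
        "2025-06-20 10:15:31,456 - agent run[42] - INFO - done\n"
    )
    with open("sample.log", "w", encoding="utf-8") as f:
        f.write(sample)

    # $ python scripts/extract_push_action_logs.py sample.log
    # --- Step 1 ---
    # 2025-06-20 10:15:30,123 - agent run[42] - DEBUG - current step content:
    # step 1: inspect user profile
    # step 2: decide whether to push
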
+ 40 - 3
tests/unit_test.py

@@ -4,8 +4,12 @@
 
 import pytest
 from unittest.mock import Mock, MagicMock
-from pqai_agent.agent_service import AgentService, MemoryQueueBackend
+
+import pqai_agent.abtest.client
+import pqai_agent.configs
+from pqai_agent.agent_service import AgentService
 from pqai_agent.dialogue_manager import DialogueState, TimeContext
+from pqai_agent.message_queue_backend import MemoryQueueBackend
 from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.user_manager import LocalUserManager
@@ -44,11 +48,16 @@ def test_env():
         user_relation_manager=user_relation_manager
     )
     service.user_profile_extractor.extract_profile_info = Mock(return_value=None)
+    service.can_send_to_user = Mock(return_value=True)
+    service.start()
 
     # 替换LLM调用为模拟响应
     service._call_chat_api = Mock(return_value="模拟响应")
 
-    return service, queues
+    yield service, queues
+
+    service.shutdown(sync=True)
+    pqai_agent.abtest.client.get_client().shutdown()
 
 def test_agent_state_change(test_env):
     service, _ = test_env
@@ -220,10 +229,38 @@ def test_initiative_conversation(test_env):
 def test_response_type_detector(test_env):
     case1 = '大哥,那可得提前了解下天气,以便安排行程~我帮您查查明天北京天气?'
     assert ResponseTypeDetector.is_chinese_only(case1) == True
-    assert ResponseTypeDetector.if_message_suitable_for_voice(case1) == False
+    assert ResponseTypeDetector.if_message_suitable_for_voice(case1) == True
     case2 = 'hi'
     assert ResponseTypeDetector.is_chinese_only(case2) == False
     case3 = '这是链接:http://domain.com'
     assert ResponseTypeDetector.is_chinese_only(case3) == False
     case4 = '大哥,那可得提前了解下天气'
     assert ResponseTypeDetector.if_message_suitable_for_voice(case4) == True
+
+    global_config = pqai_agent.configs.get()
+    global_config.get('debug_flags', {}).update({'disable_llm_api_call': False})
+
+    response_detector = ResponseTypeDetector()
+    dialogue1 = [
+        {'role': 'user', 'content': '你好', 'timestamp': 1744979571000, 'type': MessageType.TEXT},
+        {'role': 'assistant', 'content': '你好呀', 'timestamp': 1744979581000},
+    ]
+    assert response_detector.detect_type(dialogue1[:-1], dialogue1[-1]) == MessageType.TEXT
+
+    dialogue2 = [
+        {'role': 'user', 'content': '你可以读一个故事给我听吗', 'timestamp': 1744979591000},
+        {'role': 'assistant', 'content': '当然可以啦!想听什么?', 'timestamp': 1744979601000},
+        {'role': 'user', 'content': '我想听小王子', 'timestamp': 1744979611000},
+        {'role': 'assistant', 'content': '《小王子》讲述了一位年轻王子离开自己的小世界去探索宇宙的冒险经历。 在旅途中,他遇到了各种各样的人,包括被困的飞行员、狐狸和聪明的蛇。 王子通过这些遭遇学到了关于爱情、友谊和超越表面的必要性的重要教训。', 'timestamp': 1744979611000},
+    ]
+    assert response_detector.detect_type(dialogue2[:-1], dialogue2[-1]) == MessageType.VOICE
+
+    dialogue3 = [
+        {'role': 'user', 'content': '他说的是西洋参呢,晓不得到底是不是西洋参。那个样,那个茶是抽的真空的紧包包。我泡他两包,两包泡到十几盒,13盒,我还拿回来的。', 'timestamp': 1744979591000},
+        {'role': 'assistant', 'content': '咋啦?是突然想到啥啦,还是有其他事想和我分享分享?', 'timestamp': 1744979601000},
+        {'role': 'user', 'content': '不要打字,还不要打。听不到。不要打字,不要打字,打字我认不到。打字我认不到,不要打字不要打字,打字我认不到。', 'timestamp': 1744979611000},
+        {'role': 'assistant', 'content': '真是不好意思', 'timestamp': 1744979611000},
+    ]
+    assert response_detector.detect_type(dialogue3[:-1], dialogue3[-1]) == MessageType.VOICE
+
+    global_config.get('debug_flags', {}).update({'disable_llm_api_call': True})
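
Note that the new ResponseTypeDetector cases toggle debug_flags.disable_llm_api_call off, so they exercise the real model endpoint; a sketch of running just that test, assuming valid LLM credentials are configured:

    # Runs only the detector test; requires working LLM credentials because the dialogue
    # cases temporarily set debug_flags.disable_llm_api_call to False.
    import pytest

    pytest.main(["tests/unit_test.py", "-k", "test_response_type_detector", "-s"])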