Browse Source

Merge remote-tracking branch 'origin/master' into feature/dev/20250605-dev

luojunhui 4 weeks ago
parent
commit
a6cb55b0e5
40 changed files with 1486 additions and 235 deletions
  1. 1 0
      .gitignore
  2. 0 0
      pqai_agent/abtest/__init__.py
  3. 288 0
      pqai_agent/abtest/client.py
  4. 259 0
      pqai_agent/abtest/models.py
  5. 7 0
      pqai_agent/abtest/utils.py
  6. 25 0
      pqai_agent/agent_config_manager.py
  7. 38 8
      pqai_agent/agent_service.py
  8. 12 38
      pqai_agent/agents/message_push_agent.py
  9. 12 34
      pqai_agent/agents/message_reply_agent.py
  10. 53 0
      pqai_agent/agents/multimodal_chat_agent.py
  11. 11 1
      pqai_agent/agents/simple_chat_agent.py
  12. 36 0
      pqai_agent/clients/relation_stage_client.py
  13. 25 22
      pqai_agent/configs/dev.yaml
  14. 24 22
      pqai_agent/configs/prod.yaml
  15. 30 0
      pqai_agent/data_models/agent_configuration.py
  16. 22 0
      pqai_agent/data_models/service_module.py
  17. 14 5
      pqai_agent/dialogue_manager.py
  18. 2 2
      pqai_agent/history_dialogue_service.py
  19. 55 0
      pqai_agent/prompt_templates.py
  20. 31 7
      pqai_agent/push_service.py
  21. 3 2
      pqai_agent/response_type_detector.py
  22. 27 0
      pqai_agent/service_module_manager.py
  23. 42 0
      pqai_agent/toolkit/__init__.py
  24. 61 0
      pqai_agent/toolkit/coze_function_tools.py
  25. 2 0
      pqai_agent/toolkit/image_describer.py
  26. 2 0
      pqai_agent/toolkit/message_notifier.py
  27. 3 0
      pqai_agent/toolkit/pq_video_searcher.py
  28. 2 0
      pqai_agent/toolkit/search_toolkit.py
  29. 27 0
      pqai_agent/toolkit/tool_registry.py
  30. 0 25
      pqai_agent/user_manager.py
  31. 85 21
      pqai_agent/user_profile_extractor.py
  32. 18 0
      pqai_agent/utils/agent_abtest_utils.py
  33. 10 2
      pqai_agent/utils/db_utils.py
  34. 11 6
      pqai_agent/utils/prompt_utils.py
  35. 5 3
      pqai_agent_server/agent_server.py
  36. 220 4
      pqai_agent_server/api_server.py
  37. 13 27
      pqai_agent_server/utils/prompt_util.py
  38. 4 2
      scripts/disable_user_daily_push.py
  39. 2 1
      scripts/profile_cleaner.py
  40. 4 3
      scripts/resend_lost_message.py

+ 1 - 0
.gitignore

@@ -1,3 +1,4 @@
+image_descriptions_cache/
 # ---> Python
 # Byte-compiled / optimized / DLL files
 __pycache__/

+ 0 - 0
pqai_agent/abtest/__init__.py


+ 288 - 0
pqai_agent/abtest/client.py

@@ -0,0 +1,288 @@
+# Python: experiment_client.py
+import threading
+from typing import List, Dict, Optional
+from alibabacloud_paiabtest20240119.client import Client
+from pqai_agent.abtest.models import Project, Domain, Layer, Experiment, ExperimentVersion, \
+    ExperimentContext, ExperimentResult
+from alibabacloud_paiabtest20240119.models import ListProjectsRequest, ListProjectsResponseBodyProjects, \
+    ListDomainsRequest, ListFeaturesRequest, ListLayersRequest, ListExperimentsRequest, ListExperimentVersionsRequest
+from pqai_agent.logging_service import logger
+
class ExperimentClient:
    """Client for Alibaba Cloud PAI A/B Test experiment matching.

    Pulls the full project/domain/layer/experiment hierarchy from the PAI
    ABTest service and matches users (by hashed uid) to experiment versions.
    `project_map` is rebuilt into a fresh dict and swapped in wholesale on each
    refresh, so concurrent readers always see a consistent snapshot.
    """

    def __init__(self, client: Client):
        # Low-level PAI ABTest OpenAPI client (already configured upstream).
        self.client = client
        # project name -> Project; replaced atomically by load_experiment_data().
        self.project_map = {}
        self.running = False
        self.worker_thread = None

    def start(self):
        """Start the background thread that refreshes experiment data every 60s."""
        self.running = True
        self.worker_thread = threading.Thread(target=self._worker_loop)
        self.worker_thread.start()

    def shutdown(self, blocking=False):
        """Stop the refresh loop.

        With blocking=True this joins the worker (which may take up to 60s,
        the sleep interval); otherwise the thread reference is simply dropped
        and the worker exits on its own after its current sleep.
        """
        self.running = False
        if self.worker_thread:
            if blocking:
                self.worker_thread.join()
            else:
                self.worker_thread = None

    def _worker_loop(self):
        # Refresh loop. Sleeps first, then reloads, so the synchronous initial
        # load done by the creator (see get_client) is not immediately repeated.
        while self.running:
            # Sleep or wait for a condition to avoid busy waiting
            threading.Event().wait(60)
            try:
                self.load_experiment_data()
                logger.debug("Experiment data loaded successfully.")
            except Exception as e:
                logger.error(f"Error loading experiment data: {e}")
        logger.info("ExperimentClient worker thread exit.")

    def load_experiment_data(self):
        """Fetch the complete experiment hierarchy from the PAI ABTest API.

        Builds a fresh project map (projects -> domains -> layers ->
        experiments -> versions) and only assigns it to self.project_map at
        the very end, so readers never observe a half-built hierarchy.
        """
        project_map = {}

        # Fetch all projects
        list_project_req = ListProjectsRequest()
        list_project_req.all = True
        projects_response = self.client.list_projects(list_project_req)
        projects: List[ListProjectsResponseBodyProjects] = projects_response.body.projects

        for project_data in projects:
            project = Project(name=project_data.name, project_id=project_data.project_id)
            # logger.debug(f"[Project] {project_data}")

            # Fetch the project's domains
            list_domain_req = ListDomainsRequest()
            list_domain_req.project_id = project.id
            domains_response = self.client.list_domains(list_domain_req)

            for domain_data in domains_response.body.domains:
                domain = Domain(domain_id=domain_data.domain_id,
                                name=domain_data.name,
                                flow=domain_data.flow,
                                buckets=domain_data.buckets,
                                bucket_type=domain_data.bucket_type,
                                is_default_domain=domain_data.is_default_domain,
                                exp_layer_id=domain_data.layer_id,
                                debug_users=domain_data.debug_users)
                # logger.debug(f"[Domain] {domain_data}")
                if domain.is_default_domain:
                    project.set_default_domain(domain)
                domain.init()
                project.add_domain(domain)

                # Fetch the domain's features (no practical use yet)
                # NOTE(review): raw API response objects are stored here, but
                # _match_domain later calls feature.match(...) on them —
                # confirm the response type actually provides that method.
                list_feature_req = ListFeaturesRequest()
                list_feature_req.domain_id = str(domain.id)
                features_response = self.client.list_features(list_feature_req)
                for feature_data in features_response.body.features:
                    domain.add_feature(feature_data)

                # Fetch the domain's layers
                list_layer_req = ListLayersRequest()
                list_layer_req.domain_id = str(domain.id)
                layers_response = self.client.list_layers(list_layer_req)
                for layer_data in layers_response.body.layers:
                    # logger.debug(f'[Layer] {layer_data}')
                    layer = Layer(id=int(layer_data.layer_id), name=layer_data.name)
                    project.add_layer(layer)

                    # Fetch the layer's experiments (running ones only)
                    list_experiment_req = ListExperimentsRequest()
                    list_experiment_req.layer_id = str(layer.id)
                    # FIXME: magic code
                    list_experiment_req.status = 'Running'
                    experiments_response = self.client.list_experiments(list_experiment_req)

                    for experiment_data in experiments_response.body.experiments:
                        # logger.debug(f'[Experiment] {experiment_data}')
                        # FIXME: the Java SDK applies special handling here
                        crowd_ids = experiment_data.crowd_ids if experiment_data.crowd_ids else ""
                        experiment = Experiment(id=int(experiment_data.experiment_id), bucket_type=experiment_data.bucket_type,
                                                flow=experiment_data.flow, buckets=experiment_data.buckets,
                                                crowd_ids=crowd_ids.split(","),
                                                debug_users=experiment_data.debug_users,
                                                filter_condition=experiment_data.condition
                                                )
                        experiment.init()

                        # Fetch the experiment's versions
                        list_exp_ver_req = ListExperimentVersionsRequest()
                        list_exp_ver_req.experiment_id = int(experiment.id)
                        versions_response = self.client.list_experiment_versions(list_exp_ver_req)
                        for version_data in versions_response.body.experiment_versions:
                            # logger.debug(f'[ExperimentVersion] {version_data}')
                            version = ExperimentVersion(exp_version_id=version_data.experiment_version_id,
                                                        exp_id=experiment.id,
                                                        flow=int(version_data.flow),
                                                        buckets=version_data.buckets,
                                                        debug_users=version_data.debug_users,
                                                        exp_version_name=version_data.name,
                                                        config=version_data.config)
                            version.init()
                            experiment.add_experiment_version(version)
                        layer.add_experiment(experiment)
                    domain.add_layer(layer)

            # Build the reverse layer -> domain mapping to form the nested structure
            for domain in project.domains:
                if domain.is_default_domain:
                    continue
                # domain.exp_layer_id is the id of the layer this domain belongs to
                layer: Layer = project.layer_map.get(domain.exp_layer_id, None)
                if not layer:
                    continue
                layer.add_domain(domain)

            project_map[project.name] = project

        self.project_map = project_map

    def match_experiment(self, project_name, experiment_context) -> ExperimentResult:
        """Match a user (experiment_context) against one project's tree.

        Returns an ExperimentResult; for an unknown project the result carries
        only the project name and context, with no matched versions.
        """
        if project_name not in self.project_map:
            experiment_result = ExperimentResult(experiment_context=experiment_context)
            experiment_result.project_name = project_name
            return experiment_result

        project = self.project_map[project_name]
        experiment_result = ExperimentResult(project=project, experiment_context=experiment_context)

        # Matching always starts from the project's default domain.
        self._match_domain(project.default_domain, experiment_result)
        matched_versions = [str(ver.id) for ver in experiment_result.experiment_versions]
        logger.debug(f"Matched experiment, uid[{experiment_context.uid}], versions[{','.join(matched_versions)}], params: {experiment_result.params}")
        experiment_result.init()
        return experiment_result

    def _match_domain(self, domain: Domain, experiment_result: ExperimentResult):
        """Collect matching feature params for a domain, then descend into its layers."""
        if not domain:
            return

        for feature in domain.features:
            if feature.match(experiment_result.experiment_context):
                experiment_result.add_params(feature.params)

        for layer in domain.layers:
            self._match_layer(layer, experiment_result)

    def _match_layer(self, layer, experiment_result):
        """Pick at most one experiment (and possibly a sub-domain) in a layer.

        Order: experiment debug users (terminal), domain debug users, then
        hash-based experiment matching, then hash-based sub-domain matching.
        """
        if not layer:
            return

        for experiment in layer.experiments:
            if experiment.match_debug_users(experiment_result.experiment_context):
                logger.debug(f"Matched debug user for experiment: {experiment.id}")
                self._match_experiment(experiment, experiment_result)
                return

        # NOTE(review): unlike the experiment debug case above, a domain
        # debug-user hit does not return early — hash matching below still
        # runs. Verify this mirrors the reference (Java) SDK behavior.
        for domain in layer.domains:
            if domain.match_debug_users(experiment_result.experiment_context):
                # logger.debug(f"Matched debug user for domain: {domain.id}")
                self._match_domain(domain, experiment_result)

        # Re-hash the uid per layer so bucket assignments are independent
        # across layers.
        hash_key = f"{experiment_result.experiment_context.uid}_LAYER{layer.id}"
        hash_value = self._hash_value(hash_key)

        exp_context = ExperimentContext(uid=hash_value,
                                        filter_params=experiment_result.experiment_context.filter_params)

        matched_experiments = [exp for exp in layer.experiments if exp.match(exp_context)]

        if len(matched_experiments) == 1:
            self._match_experiment(matched_experiments[0], experiment_result)
        elif len(matched_experiments) > 1:
            # Prefer a condition-based experiment when several match.
            for experiment in matched_experiments:
                if experiment.bucket_type == "Condition":
                    self._match_experiment(experiment, experiment_result)
                    return
            logger.warning(f"Warning: Multiple experiments matched under layer {layer.id}.")
            self._match_experiment(matched_experiments[0], experiment_result)

        # Sub-domains nested in this layer are matched in addition to (not
        # instead of) any experiment matched above.
        matched_domains = []
        for domain in layer.domains:
            if domain.match(exp_context):
                logger.debug(f"Matched domain {domain.id} for uid {experiment_result.experiment_context.uid}.")
                matched_domains.append(domain)
        if len(matched_domains) == 1:
            self._match_domain(matched_domains[0], experiment_result)
            return
        elif len(matched_domains) > 1:
            for domain in matched_domains:
                if domain.bucket_type == "Condition":
                    self._match_domain(domain, experiment_result)
                    return
            logger.warning(f"Warning: Multiple domains matched under layer {layer.id}, using the first one.")
            self._match_domain(matched_domains[0], experiment_result)
            return

    def _match_experiment(self, experiment: Experiment, experiment_result: ExperimentResult):
        """Pick the version of an experiment for this user, debug users first."""
        if not experiment:
            return

        for version in experiment.experiment_versions:
            if version.match_debug_users(experiment_result.experiment_context):
                logger.debug(f"Matched debug user for experiment version: {version.id}")
                experiment_result.add_params(version.params)
                experiment_result.add_experiment_version(version)
                return

        # Per-experiment re-hash keeps version assignment independent of the
        # layer-level hash.
        hash_key = f"{experiment_result.experiment_context.uid}_EXPERIMENT{experiment.id}"
        hash_value = self._hash_value(hash_key)

        exp_context = ExperimentContext(uid=hash_value,
                                        filter_params=experiment_result.experiment_context.filter_params)

        for version in experiment.experiment_versions:
            if version.match(exp_context):
                experiment_result.add_params(version.params)
                experiment_result.add_experiment_version(version)
                return

    def _hash_value(self, hash_key) -> int:
        """Hash a key to an int: MD5 hex digest fed through 64-bit FNV-1."""
        import hashlib
        from pqai_agent.abtest.models import FNV
        md5_hash = hashlib.md5(hash_key.encode()).hexdigest().encode()
        return FNV.fnv1_64(md5_hash)

    def __del__(self):
        # Best-effort cleanup so an abandoned client does not leave its
        # refresh loop flagged as running.
        if self.running and self.worker_thread:
            self.shutdown()
+
g_client: Optional[ExperimentClient] = None

def get_client() -> ExperimentClient:
    """Return the process-wide ExperimentClient singleton.

    On first call this builds the API client, loads experiment data once
    synchronously, then starts the 60-second background refresh thread.
    NOTE(review): not thread-safe — two concurrent first calls may each build
    a client; acceptable only if first use happens on one thread. Confirm.
    """
    global g_client
    if not g_client:
        import os
        # SECURITY: these credentials were previously hardcoded and are now in
        # VCS history — rotate the keys and provide them via environment.
        # The hardcoded values remain only as a temporary fallback.
        ak_id = os.environ.get('PAI_ABTEST_AK_ID', 'LTAI5tFGqgC8f3mh1fRCrAEy')
        ak_secret = os.environ.get('PAI_ABTEST_AK_SECRET', 'XhOjK9XmTYRhVAtf6yii4s4kZwWzvV')
        region = os.environ.get('PAI_ABTEST_REGION', 'cn-hangzhou')
        from alibabacloud_tea_openapi.models import Config
        endpoint = f"paiabtest.{region}.aliyuncs.com"
        conf = Config(access_key_id=ak_id, access_key_secret=ak_secret, region_id=region,
                      endpoint=endpoint, type="access_key")
        api_client = Client(conf)
        g_client = ExperimentClient(api_client)
        # Synchronous first load so callers get data immediately; the worker
        # thread then refreshes every 60s.
        g_client.load_experiment_data()
        g_client.start()
    return g_client
+
if __name__ == '__main__':
    # Manual smoke test: dump the loaded hierarchy and match one uid.
    from pqai_agent.logging_service import setup_root_logger
    setup_root_logger(level='DEBUG')
    experiment_client = get_client()

    for project_name, project in experiment_client.project_map.items():
        print(f"Project: {project_name}, ID: {project.id}")
        for domain in project.domains:
            print(f"  Domain: {domain.id}, Default: {domain.is_default_domain}")
            for layer in domain.layers:
                print(f"    Layer: {layer.id}")
                for experiment in layer.experiments:
                    print(f"      Experiment: {experiment.id}")
                    for version in experiment.experiment_versions:
                        print(f"        Version: {version.id}, Config: {version.config}")

    exp_context = ExperimentContext(uid='123')
    result = experiment_client.match_experiment('PQAgent', exp_context)
    print(result)
    # Fix: use the local handle instead of reaching into the module-level
    # g_client singleton. NOTE: the worker sleeps in 60s intervals, so the
    # process may linger up to a minute before the thread notices shutdown.
    experiment_client.shutdown()

+ 259 - 0
pqai_agent/abtest/models.py

@@ -0,0 +1,259 @@
+from typing import List, Dict, Optional, Set
+import json
+from dataclasses import dataclass, field
+import hashlib
+
+from pqai_agent.logging_service import logger
+
+
class FNV:
    """64-bit FNV-1 hash (multiply by the prime first, then XOR the byte)."""
    INIT64 = int("cbf29ce484222325", 16)
    PRIME64 = int("100000001b3", 16)
    MOD64 = 2**64

    @staticmethod
    def fnv1_64(data: bytes) -> int:
        """Return the FNV-1 64-bit hash of *data* as a non-negative int."""
        acc = FNV.INIT64
        prime, mod = FNV.PRIME64, FNV.MOD64
        for b in data:
            acc = ((acc * prime) % mod) ^ b
        return acc
+
class DiversionBucket:
    """Strategy interface: decides whether a context falls into a bucket set."""

    def match(self, experiment_context):
        """Return True if the context is diverted into this bucket."""
        raise NotImplementedError("Subclasses must implement this method")

class UidDiversionBucket(DiversionBucket):
    """Random diversion: maps the uid into one of `total_buckets` slots."""

    def __init__(self, total_buckets: int, buckets: str):
        self.total_buckets = total_buckets
        # "1,3,7" -> {1, 3, 7}; empty/None -> no buckets assigned.
        self.buckets = {int(b) for b in buckets.split(",")} if buckets else set()

    def match(self, experiment_context):
        """True when uid % total_buckets lands in the assigned bucket set."""
        slot = int(experiment_context.uid) % self.total_buckets
        return slot in self.buckets


class FilterDiversionBucket(DiversionBucket):
    """Condition-based diversion; evaluation is not implemented yet."""

    def __init__(self, filter_condition: str):
        self.filter_condition = filter_condition

    def match(self, experiment_context):
        raise NotImplementedError("not implemented")
+
class Feature:
    """A parameter bundle attached to a domain (placeholder implementation)."""

    def __init__(self, params=None):
        # Arbitrary feature parameters; None when not supplied.
        self.params = params

    def init(self):
        """Hook for feature-specific initialization; currently a no-op."""
        pass
+
+
class ExperimentContext:
    """Per-request matching input: a user id plus optional filter params."""

    def __init__(self, uid=None, filter_params=None):
        self.uid = uid
        # Falsy (None/empty) filter_params becomes a fresh empty dict.
        self.filter_params = filter_params or {}

    def __str__(self):
        return f"ExperimentContext(uid={self.uid}, filter_params={self.filter_params})"
+
class Domain:
    """A traffic domain: carves out `flow` percent of traffic and contains
    its own features and layers.

    NOTE(review): `debug_crowd_ids` is stored but never used by this class,
    and `bucket_type` does not influence match() here (unlike Experiment) —
    confirm whether condition-based domains are expected.
    """

    def __init__(self, domain_id, name, flow: int, buckets: str, bucket_type: str, debug_crowd_ids=None, is_default_domain=False, exp_layer_id=None,
                 debug_users=""):
        self.id = int(domain_id)
        self.name = name
        self.debug_crowd_ids = debug_crowd_ids
        self.is_default_domain = is_default_domain
        # Id of the layer this (non-default) domain is nested under.
        self.exp_layer_id = int(exp_layer_id) if exp_layer_id is not None else None
        self.features = []
        self.layers = []
        self.debug_users = debug_users  # comma-separated uid whitelist
        self.flow = flow                # percentage of traffic (0-100)
        self.buckets = buckets          # comma-separated bucket ids, e.g. "0,1"
        self.diversion_bucket = None    # built in init()
        self.bucket_type = bucket_type
        self.debug_user_set = set()

    def add_debug_users(self, users: List[str]):
        """Whitelist additional uids that always match this domain."""
        self.debug_user_set.update(users)

    def match_debug_users(self, experiment_context):
        """True if the context's uid is a whitelisted debug user."""
        return experiment_context.uid in self.debug_user_set

    def add_feature(self, feature: "Feature"):
        self.features.append(feature)

    def add_layer(self, layer):
        self.layers.append(layer)

    def init(self):
        """Finalize after construction: parse debug users, build the bucket.

        Fix: an empty `debug_users` string previously produced {""} via
        "".split(","), silently whitelisting the empty uid; empty input now
        yields an empty debug set (consistent with Experiment.init).
        """
        if self.debug_users:
            self.debug_user_set.update(self.debug_users.split(","))
        # 100 total buckets, matching the percentage semantics of `flow`.
        self.diversion_bucket = UidDiversionBucket(100, self.buckets)

    def match(self, experiment_context):
        """Decide whether the context's traffic enters this domain."""
        if self.flow == 0:        # no traffic at all
            return False
        elif self.flow == 100:    # full traffic, skip hashing
            return True
        if self.diversion_bucket:
            return self.diversion_bucket.match(experiment_context)
        return False
+
+
@dataclass
class Layer:
    """A traffic layer within a domain; holds experiments and nested sub-domains."""
    id: int
    name: str
    # Experiments running in this layer.
    experiments: List['Experiment'] = field(default_factory=list)
    # Non-default domains nested under this layer (wired up by the reverse
    # layer->domain mapping built in ExperimentClient.load_experiment_data).
    domains: List[Domain] = field(default_factory=list)

    def add_experiment(self, experiment):
        """Register an experiment in this layer."""
        self.experiments.append(experiment)

    def add_domain(self, domain):
        """Register a nested sub-domain in this layer."""
        self.domains.append(domain)
+
+
@dataclass
class Experiment:
    """An experiment inside a layer, split into ExperimentVersions.

    Matching order (driven by ExperimentClient): debug users first, then the
    random uid-hash bucket or the condition filter, depending on bucket_type.
    """
    id: int
    flow: int                    # percentage of layer traffic (0-100)
    crowd_ids: List[str]         # NOTE(review): stored but unused in this class
    debug_users: str             # comma-separated uids that always match
    buckets: str                 # comma-separated bucket ids (Random type)
    filter_condition: str        # condition expression (Condition type)
    bucket_type: str = "Random"
    debug_user_set: Set[str] = field(default_factory=set)
    diversion_bucket: Optional[DiversionBucket] = None
    experiment_versions: List['ExperimentVersion'] = field(default_factory=list)

    def add_debug_users(self, users: List[str]):
        """Whitelist uids that bypass bucketing for this experiment."""
        self.debug_user_set.update(users)

    def match_debug_users(self, experiment_context):
        """True if the context's uid is whitelisted for debugging."""
        return experiment_context.uid in self.debug_user_set

    def add_experiment_version(self, version):
        """Attach a version (variant) to this experiment."""
        self.experiment_versions.append(version)

    def match(self, experiment_context: ExperimentContext) -> bool:
        """Decide whether the context's traffic enters this experiment.

        For Random experiments flow 0/100 short-circuits, otherwise the
        uid-hash bucket decides. NOTE(review): for Condition experiments this
        delegates to FilterDiversionBucket.match, which currently raises
        NotImplementedError — confirm Condition experiments are not yet used.
        """
        if self.bucket_type == "Random":
            if self.flow == 0:
                return False
            elif self.flow == 100:
                return True
        if self.diversion_bucket:
            return self.diversion_bucket.match(experiment_context)
        return False

    def init(self):
        """Finalize after construction: parse debug users, build the bucket."""
        # Initialize the debug-user set (empty string adds nothing)
        if self.debug_users:
            self.debug_user_set.update(self.debug_users.split(","))
        # Initialize the diversion bucket
        if self.bucket_type == "Random":  # ExpBucketTypeRand
            self.diversion_bucket = UidDiversionBucket(100, self.buckets)
        elif self.bucket_type == "Condition" and self.filter_condition:  # ExpBucketTypeCond
            self.diversion_bucket = FilterDiversionBucket(self.filter_condition)
+
+
class ExperimentVersion:
    """One variant of an experiment, carrying its share of flow and the
    key/value params decoded from its JSON `config`.
    """

    def __init__(self, exp_version_id, flow, buckets: str, exp_id: int, exp_version_name=None,
                 debug_users: str = '', config=None, debug_crowd_ids=None):
        self.id = int(exp_version_id)
        self.exp_version_name = exp_version_name
        self.exp_id = int(exp_id)
        # JSON string of [{"key": ..., "value": ...}, ...]; may be None/empty.
        self.config = config
        self.debug_crowd_ids = debug_crowd_ids  # NOTE(review): currently unused
        self.debug_users = debug_users
        self.params = {}
        self.flow = flow        # percentage of experiment traffic (0-100)
        self.buckets = buckets  # comma-separated bucket ids
        self.debug_user_set = set()
        self.diversion_bucket = None

    def add_debug_users(self, users: List[str]):
        """Whitelist additional uids that always match this version."""
        self.debug_user_set.update(users)

    def match_debug_users(self, experiment_context):
        """True if the context's uid is a whitelisted debug user."""
        return experiment_context.uid in self.debug_user_set

    def match(self, experiment_context: "ExperimentContext"):
        """Decide whether the context's traffic falls into this version."""
        if self.flow == 0:
            return False
        elif self.flow == 100:
            return True
        if self.diversion_bucket:
            return self.diversion_bucket.match(experiment_context)
        return False

    def init(self):
        """Finalize after construction: debug users, bucket, and params.

        Fixes: an empty `debug_users` string no longer whitelists the empty
        uid (consistent with Experiment.init), and a missing/empty `config`
        no longer crashes json.loads with a TypeError.
        """
        if self.debug_users:
            self.debug_user_set.update(self.debug_users.split(","))
        self.diversion_bucket = UidDiversionBucket(100, self.buckets)
        if self.config:
            params = json.loads(self.config)
            for kv in params:
                self.params[kv['key']] = kv['value']
+
+
class Project:
    """An AB-test project: the root container for domains and layers."""

    def __init__(self, name=None, project_id=None):
        self.name = name
        self.id = int(project_id)
        self.domains = []
        self.layers = []
        # The top-level domain matching always starts from.
        self.default_domain: Optional[Domain] = None
        # id-keyed lookups kept in sync with the ordered lists above.
        self.layer_map = {}
        self.domain_map = {}

    def add_domain(self, domain):
        """Track a domain both in order and by id."""
        self.domain_map[domain.id] = domain
        self.domains.append(domain)

    def add_layer(self, layer):
        """Track a layer both in order and by id."""
        self.layer_map[layer.id] = layer
        self.layers.append(layer)

    def set_default_domain(self, domain: Domain):
        """Remember the project's default (entry) domain."""
        self.default_domain = domain
+
+
class ExperimentResult:
    """Outcome of matching one user against a project's experiment tree."""

    def __init__(self, project=None, experiment_context=None):
        self.project = project
        self.project_name = project.name if project else None
        self.experiment_context = experiment_context
        self.params = {}
        self.experiment_versions: List[ExperimentVersion] = []
        self.exp_id = ""

    def add_params(self, params: Dict[str, str]):
        """Merge params into the result; later values win on duplicate keys."""
        for key, value in params.items():
            if key in self.params:
                logger.warning(f"Duplicate key '{key}' in params, overwriting value: {self.params[key]} with {value}")
            self.params[key] = value

    def add_experiment_version(self, version):
        """Record one matched experiment version."""
        self.experiment_versions.append(version)

    def init(self):
        """Compose exp_id like 'ER<proj>_E<exp>#EV<ver>...'.

        exp_id is only assigned when at least one version matched; otherwise
        it stays the empty string.
        """
        if not self.experiment_versions:
            return
        pieces = [f"ER{self.project.id}"] if self.project else []
        for experiment_version in self.experiment_versions:
            pieces.append(f"_E{experiment_version.exp_id}")
            pieces.append(f"#EV{experiment_version.id}")
        self.exp_id = "".join(pieces)

    def __str__(self):
        return f"ExperimentResult(project={self.project_name}, params={self.params}, experiment_context={self.experiment_context}, experiment_versions={self.experiment_versions})"

+ 7 - 0
pqai_agent/abtest/utils.py

@@ -0,0 +1,7 @@
+from pqai_agent.abtest.models import ExperimentContext
+from pqai_agent.abtest.client import get_client
+
def get_abtest_info(uid: str):
    """Match one user id against the 'PQAgent' AB-test project.

    Returns the ExperimentResult produced by the shared ExperimentClient.
    """
    context = ExperimentContext(uid=uid)
    return get_client().match_experiment('PQAgent', context)

+ 25 - 0
pqai_agent/agent_config_manager.py

@@ -0,0 +1,25 @@
+from typing import Dict, Optional
+
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.logging_service import logger
+
class AgentConfigManager:
    """In-memory cache of non-deleted agent configurations keyed by id.

    refresh_configs() is invoked once at construction and is also scheduled
    periodically by the caller; on any load failure the previous snapshot is
    kept untouched.
    """

    def __init__(self, session_maker):
        self.session_maker = session_maker
        self.agent_configs: Dict[int, AgentConfiguration] = {}
        self.refresh_configs()

    def refresh_configs(self):
        """Reload all non-deleted configurations, replacing the cache wholesale."""
        try:
            with self.session_maker() as session:
                rows = session.query(AgentConfiguration).filter_by(is_delete=False).all()
                fresh = {row.id: row for row in rows}
            # Swap in the new mapping only after a fully successful load.
            self.agent_configs = fresh
            logger.debug(f"Refreshed agent configurations: {list(self.agent_configs.keys())}")
        except Exception as e:
            logger.error(f"Failed to refresh agent configurations: {e}")

    def get_config(self, agent_id: int) -> Optional[AgentConfiguration]:
        """Return the cached configuration for agent_id, or None if unknown."""
        return self.agent_configs.get(agent_id, None)

+ 38 - 8
pqai_agent/agent_service.py

@@ -18,6 +18,8 @@ from apscheduler.schedulers.background import BackgroundScheduler
 from sqlalchemy.orm import sessionmaker
 
 from pqai_agent import configs
+from pqai_agent.abtest.utils import get_abtest_info
+from pqai_agent.agent_config_manager import AgentConfigManager
 from pqai_agent.agents.message_reply_agent import MessageReplyAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.exceptions import NoRetryException
@@ -29,11 +31,14 @@ from pqai_agent.history_dialogue_service import HistoryDialogueDatabase
 from pqai_agent.push_service import PushScanThread, PushTaskWorkerPool
 from pqai_agent.rate_limiter import MessageSenderRateLimiter
 from pqai_agent.response_type_detector import ResponseTypeDetector
+from pqai_agent.service_module_manager import ServiceModuleManager
+from pqai_agent.toolkit import get_tools
 from pqai_agent.user_manager import UserManager, UserRelationManager
 from pqai_agent.message_queue_backend import MessageQueueBackend, AliyunRocketMQQueueBackend
 from pqai_agent.user_profile_extractor import UserProfileExtractor
 from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
-from pqai_agent.utils.db_utils import create_sql_engine
+from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
+from pqai_agent.utils.db_utils import create_ai_agent_db_engine
 
 
 class AgentService:
@@ -59,9 +64,9 @@ class AgentService:
         self.user_profile_extractor = UserProfileExtractor()
         self.response_type_detector = ResponseTypeDetector()
         self.agent_registry: Dict[str, DialogueManager] = {}
-        self.history_dialogue_db = HistoryDialogueDatabase(self.config['storage']['user']['mysql'])
-        self.agent_db_engine = create_sql_engine(self.config['storage']['agent_state']['mysql'])
-        self.AgentDBSession = sessionmaker(bind=self.agent_db_engine)
+        self.history_dialogue_db = HistoryDialogueDatabase(self.config['database']['ai_agent'])
+        self.agent_db_engine = create_ai_agent_db_engine()
+        self.agent_db_session_maker = sessionmaker(bind=self.agent_db_engine)
 
         chat_config = self.config['chat_api']['openai_compatible']
         self.text_model_name = chat_config['text_model']
@@ -98,6 +103,10 @@ class AgentService:
 
         self.send_rate_limiter = MessageSenderRateLimiter()
 
+        # Agent配置和实验相关
+        self.service_module_manager = ServiceModuleManager(self.agent_db_session_maker)
+        self.agent_config_manager = AgentConfigManager(self.agent_db_session_maker)
+
     def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
         if not schedule_params:
             schedule_params = {'hour': '8,16,20'}
@@ -123,6 +132,11 @@ class AgentService:
             )
             self.msg_scheduler_thread = threading.Thread(target=self.process_scheduler_events)
             self.msg_scheduler_thread.start()
+        # 定时更新模块配置任务
+        self.scheduler.add_job(self.service_module_manager.refresh_configs, 'interval',
+                               seconds=60, id='refresh_module_configs')
+        self.scheduler.add_job(self.agent_config_manager.refresh_configs, 'interval',
+                               seconds=60, id='refresh_agent_configs')
         self.scheduler.start()
 
     def process_scheduler_events(self):
@@ -149,7 +163,7 @@ class AgentService:
         agent_key = 'agent_{}_{}'.format(staff_id, user_id)
         if agent_key not in self.agent_registry:
             self.agent_registry[agent_key] = DialogueManager(
-                staff_id, user_id, self.user_manager, self.agent_state_cache, self.AgentDBSession)
+                staff_id, user_id, self.user_manager, self.agent_state_cache, self.agent_db_session_maker)
         agent = self.agent_registry[agent_key]
         agent.refresh_profile()
         return agent
@@ -240,7 +254,12 @@ class AgentService:
             sys.exit(0)
 
     def _update_user_profile(self, user_id, user_profile, recent_dialogue: List[Dict]):
-        profile_to_update = self.user_profile_extractor.extract_profile_info(user_profile, recent_dialogue)
+        agent_info = get_agent_abtest_config('profile_extractor', user_id, self.service_module_manager, self.agent_config_manager)
+        if agent_info:
+            prompt_template = agent_info.task_prompt
+        else:
+            prompt_template = None
+        profile_to_update = self.user_profile_extractor.extract_profile_info_v2(user_profile, recent_dialogue, prompt_template)
         if not profile_to_update:
             logger.debug("user_id: {}, no profile info extracted".format(user_id))
             return
@@ -319,7 +338,7 @@ class AgentService:
                     item["type"] = message_type
         if contents:
             for response in contents:
-                self.send_multimodal_response(staff_id, user_id, response, skip_check=True)
+                self.send_multimodal_response(staff_id, user_id, response)
             agent.update_last_active_interaction_time(current_ts)
         else:
             logger.debug(f"staff[{staff_id}], user[{user_id}]: no messages to send")
@@ -396,8 +415,12 @@ class AgentService:
             return
 
         push_scan_threads = []
+        whitelist_staffs = apollo_config.get_json_value("agent_initiate_whitelist_staffs", [])
         for staff in self.user_relation_manager.list_staffs():
             staff_id = staff['third_party_user_id']
+            if staff_id not in whitelist_staffs:
+                logger.info(f"staff[{staff_id}] is not in whitelist, skip")
+                continue
             scan_thread = threading.Thread(target=PushScanThread(
                 staff_id, self, self.push_task_rmq_topic, self.push_task_producer).run)
             scan_thread.start()
@@ -435,7 +458,14 @@ class AgentService:
             return None
 
     def _get_chat_response_v2(self, main_agent: DialogueManager) -> List[Dict]:
-        chat_agent = MessageReplyAgent()
+        agent_config = get_agent_abtest_config('chat', main_agent.user_id,
+                                               self.service_module_manager, self.agent_config_manager)
+        if agent_config:
+            chat_agent = MessageReplyAgent(model=agent_config.execution_model,
+                                           system_prompt=agent_config.system_prompt,
+                                           tools=get_tools(agent_config.tools))
+        else:
+            chat_agent = MessageReplyAgent()
         chat_responses = chat_agent.generate_message(
             context=main_agent.get_prompt_context(None),
             dialogue_history=main_agent.dialogue_history[-100:]

+ 12 - 38
pqai_agent/agents/message_push_agent.py

@@ -1,11 +1,8 @@
-import datetime
-import time
 from typing import Optional, List, Dict
 
-from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
 from pqai_agent.logging_service import logger
-from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
@@ -99,6 +96,7 @@ QUERY_PROMPT_TEMPLATE = """现在,请通过多步思考,以客服的角色
 - 姓名:{name}
 - 头像:{avatar}
 - 偏好的称呼:{preferred_nickname}
+- 性别:{gender}
 - 年龄:{age}
 - 地区:{region}
 - 健康状况:{health_conditions}
@@ -120,48 +118,24 @@ QUERY_PROMPT_TEMPLATE = """现在,请通过多步思考,以客服的角色
 Now, start to process your task. Please think step by step.
  """
 
-class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
+class MessagePushAgent(MultiModalChatAgent):
     """A specialized agent for message push tasks."""
 
     def __init__(self, model: Optional[str] = VOLCENGINE_MODEL_DEEPSEEK_V3, system_prompt: Optional[str] = None,
                  tools: Optional[List[FunctionTool]] = None,
                  generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
         system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
-        tools = tools or []
-        tools = tools.copy()
-        tools.extend([
-            *ImageDescriber().get_tools(),
-            *MessageNotifier().get_tools(),
-        ])
+        if tools is None:
+            tools = [
+                *ImageDescriber().get_tools(),
+                *MessageNotifier().get_tools()
+            ]
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict], timestamp_type: str='ms') -> str:
-        formatted_dialogue = MessagePushAgent.compose_dialogue(dialogue_history, timestamp_type)
-        query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
-        self.run(query)
-        for tool_call in reversed(self.tool_call_records):
-            if tool_call['name'] == MessageNotifier.message_notify_user.__name__:
-                # time.sleep(1)
-                print("Function call return", tool_call['arguments']['message'])
-                return tool_call['arguments']['message']
-        return ''
-
-    @staticmethod
-    def compose_dialogue(dialogue: List[Dict], timestamp_type: str='ms') -> str:
-        role_map = {'user': '用户', 'assistant': '客服'}
-        messages = []
-        for msg in dialogue:
-            if not msg['content']:
-                continue
-            if msg['role'] not in role_map:
-                continue
-            if timestamp_type == 'ms':
-                format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
-            else:
-                format_dt = datetime.datetime.fromtimestamp(msg['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
-            msg_type = msg.get('type', MessageType.TEXT).description
-            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
-        return '\n'.join(messages)
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: Optional[str] = None) -> List[Dict]:
+        query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
+        return self._generate_message(context, dialogue_history, query_prompt_template)
 
 class DummyMessagePushAgent(MessagePushAgent):
     """A dummy agent for testing purposes."""

+ 12 - 34
pqai_agent/agents/message_reply_agent.py

@@ -1,10 +1,8 @@
-import datetime
 from typing import Optional, List, Dict
 
-from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
 from pqai_agent.logging_service import logger
-from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
@@ -86,44 +84,24 @@ QUERY_PROMPT_TEMPLATE = """现在,请以客服的角色分析以下会话并
 Now, start to process your task. Please think step by step.
  """
 
-class MessageReplyAgent(SimpleOpenAICompatibleChatAgent):
+class MessageReplyAgent(MultiModalChatAgent):
     """A specialized agent for message reply tasks."""
 
     def __init__(self, model: Optional[str] = VOLCENGINE_MODEL_DEEPSEEK_V3, system_prompt: Optional[str] = None,
                  tools: Optional[List[FunctionTool]] = None,
                  generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
         system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
-        tools = tools or []
-        tools = tools.copy()
-        tools.extend([
-            *ImageDescriber().get_tools(),
-            *MessageNotifier().get_tools()
-        ])
+        if tools is None:
+            tools = [
+                *ImageDescriber().get_tools(),
+                *MessageNotifier().get_tools()
+            ]
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
-        formatted_dialogue = MessageReplyAgent.compose_dialogue(dialogue_history)
-        query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
-        self.run(query)
-        result = []
-        for tool_call in self.tool_call_records:
-            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
-                result.append(tool_call['arguments']['message'])
-        return result
-
-    @staticmethod
-    def compose_dialogue(dialogue: List[Dict]) -> str:
-        role_map = {'user': '用户', 'assistant': '客服'}
-        messages = []
-        for msg in dialogue:
-            if not msg['content']:
-                continue
-            if msg['role'] not in role_map:
-                continue
-            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
-            msg_type = msg.get('type', MessageType.TEXT).description
-            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
-        return '\n'.join(messages)
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: Optional[str] = None) -> List[Dict]:
+        query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
+        return self._generate_message(context, dialogue_history, query_prompt_template)
 
 class DummyMessageReplyAgent(MessageReplyAgent):
     """A dummy agent for testing purposes."""
@@ -131,7 +109,7 @@ class DummyMessageReplyAgent(MessageReplyAgent):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict], query_prompt_template = None) -> List[Dict]:
         logger.debug(f"DummyMessageReplyAgent.generate_message called, context: {context}")
         result = [{"type": "text", "content": "测试消息: {agent_name} -> {nickname}".format(**context)},
                   {"type": "image", "content": "https://example.com/test_image.jpg"}]

+ 53 - 0
pqai_agent/agents/multimodal_chat_agent.py

@@ -0,0 +1,53 @@
+import datetime
+from abc import abstractmethod
+from typing import Optional, List, Dict
+
+from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.logging_service import logger
+from pqai_agent.mq_message import MessageType
+from pqai_agent.toolkit import get_tool
+from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.message_notifier import MessageNotifier
+
+
class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
    """Shared base for chat agents that emit multimodal (text/image/...) messages.

    Guarantees that the two message-output tools are always registered, and
    provides the common query-building / tool-call-collection loop used by
    the push and reply subclasses.
    """

    def __init__(self, model: str, system_prompt: str,
                 tools: Optional[List[FunctionTool]] = None,
                 generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
        super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
        # The output tools are mandatory for this agent family; register them
        # only when the caller did not already supply equivalents.
        if 'output_multimodal_message' not in self.tool_map:
            self.add_tool(get_tool('output_multimodal_message'))
        if 'message_notify_user' not in self.tool_map:
            self.add_tool(get_tool('message_notify_user'))

    @abstractmethod
    def generate_message(self, context: Dict, dialogue_history: List[Dict],
                         query_prompt_template: str) -> List[Dict]:
        """Build a query from context/history and return the multimodal messages to send."""
        pass

    def _generate_message(self, context: Dict, dialogue_history: List[Dict],
                         query_prompt_template: str) -> List[Dict]:
        """Common implementation: format the prompt, run the LLM tool loop, and
        collect every message emitted via the output_multimodal_message tool."""
        formatted_dialogue = MultiModalChatAgent.compose_dialogue(dialogue_history)
        query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
        self.run(query)
        result = []
        for tool_call in self.tool_call_records:
            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
                result.append(tool_call['arguments']['message'])
        return result

    @staticmethod
    def compose_dialogue(dialogue: List[Dict], timestamp_type: str = 'ms') -> str:
        """Render dialogue history as '[role][datetime][type]content' lines.

        Messages with empty content or roles other than user/assistant are skipped.
        timestamp_type: 'ms' (default) for millisecond epoch timestamps; any other
        value treats timestamps as seconds (restores the behaviour the old
        MessagePushAgent.compose_dialogue supported before this refactor).
        """
        role_map = {'user': '用户', 'assistant': '客服'}
        messages = []
        for msg in dialogue:
            if not msg['content']:
                continue
            if msg['role'] not in role_map:
                continue
            ts = msg['timestamp'] / 1000 if timestamp_type == 'ms' else msg['timestamp']
            format_dt = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S')
            msg_type = msg.get('type', MessageType.TEXT).description
            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
        return '\n'.join(messages)

+ 11 - 1
pqai_agent/agents/simple_chat_agent.py

@@ -15,12 +15,22 @@ class SimpleOpenAICompatibleChatAgent:
         self.model = model
         self.llm_client = OpenAICompatible.create_client(model)
         self.system_prompt = system_prompt
-        self.tools = tools or []
+        if tools:
+            self.tools = [*tools]
+        else:
+            self.tools = []
         self.tool_map = {tool.name: tool for tool in self.tools}
         self.generate_cfg = generate_cfg or {}
         self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
         self.tool_call_records = []
 
+    def add_tool(self, tool: FunctionTool):
+        """添加一个工具到Agent中"""
+        if tool.name in self.tool_map:
+            logger.warning(f"Tool {tool.name} already exists, replacing it.")
+        self.tools.append(tool)
+        self.tool_map[tool.name] = tool
+
     def run(self, user_input: str) -> str:
         messages = [{"role": "system", "content": self.system_prompt}]
         tools = [tool.get_openai_tool_schema() for tool in self.tools]

+ 36 - 0
pqai_agent/clients/relation_stage_client.py

@@ -0,0 +1,36 @@
+from typing import Optional
+
+import requests
+
+from pqai_agent.logging_service import logger
+
class RelationStageClient:
    """HTTP client for the staff-user relation-stage service.

    All failure modes (network error, non-200 status, malformed payload)
    degrade to UNKNOWN_RELATION_STAGE instead of raising, so callers can
    always embed the returned value directly into prompts.
    """

    UNKNOWN_RELATION_STAGE = '未知'

    def __init__(self, base_url: Optional[str] = None):
        base_url = base_url or "http://ai-wechat-hook-internal.piaoquantv.com/analyse/getUserEmployeeRelStage"
        self.base_url = base_url

    def get_relation_stage(self, staff_id: str, user_id: str) -> str:
        """Return the relation stage for (staff_id, user_id), or UNKNOWN_RELATION_STAGE."""
        try:
            # Bounded timeout: this runs on the chat hot path and must not hang
            # a worker thread. params= also URL-encodes the IDs properly.
            response = requests.get(self.base_url,
                                    params={'employeeId': staff_id, 'userId': user_id},
                                    timeout=5)
        except requests.RequestException as e:
            logger.error(f"Request failed: {e}")
            return self.UNKNOWN_RELATION_STAGE
        if response.status_code != 200:
            logger.error(f"Request error [{response.status_code}]: {response.text}")
            return self.UNKNOWN_RELATION_STAGE
        try:
            data = response.json()
        except ValueError:
            # Body was not valid JSON; do not let the decode error escape.
            logger.error(f"Invalid JSON in response: {response.text}")
            return self.UNKNOWN_RELATION_STAGE
        if not data.get('success', False):
            logger.error(f"Error in response: {data.get('message', 'no message returned')}")
            return self.UNKNOWN_RELATION_STAGE
        if 'data' not in data:
            logger.error("No 'data' field in response")
            return self.UNKNOWN_RELATION_STAGE
        return data.get('data')
+
+if __name__ == "__main__":
+    # Example usage
+    client = RelationStageClient()
+    stage = client.get_relation_stage("1688856125791790", "7881301780233975")
+    if stage:
+        print(f"Relation stage: {stage}")
+    else:
+        print("Failed to retrieve relation stage.")

+ 25 - 22
pqai_agent/configs/dev.yaml

@@ -1,40 +1,42 @@
+database:
+    ai_agent:
+        host: rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com
+        port: 3306
+        user: wqsd
+        password: wqsd@2025
+        database: ai_agent
+        charset: utf8mb4
+    growth:
+        host: rm-bp17q95335a99272b.mysql.rds.aliyuncs.com
+        port: 3306
+        user: crawler
+        password: crawler123456@
+        database: growth
+        charset: utf8mb4
+
 storage:
   history_dialogue:
     api_base_url: http://ai-wechat-hook-internal.piaoquantv.com/wechat/message/getConversation
   user:
-    mysql:
-      host: rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com
-      port: 3306
-      user: wqsd
-      password: wqsd@2025
-      database: ai_agent
-      charset: utf8mb4
+    database: ai_agent
     table: third_party_user
+  staff:
+    database: ai_agent
+    table: qywx_employee
   user_relation:
-    mysql:
-      host: rm-bp17q95335a99272b.mysql.rds.aliyuncs.com
-      port: 3306
-      user: crawler
-      password: crawler123456@
-      database: growth
-      charset: utf8mb4
+    database: growth
     table:
       staff: we_com_staff
       relation: we_com_staff_with_user
       user: we_com_user
-  staff:
-    table: qywx_employee
   agent_state:
-    mysql:
-      host: rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com
-      port: 3306
-      user: wqsd
-      password: wqsd@2025
-      database: ai_agent
+    database: ai_agent
     table: agent_state
   chat_history:
+    database: ai_agent
     table: qywx_chat_history
   push_record:
+    database: ai_agent
     table: agent_push_record_dev
 
 agent_behavior:
@@ -61,6 +63,7 @@ system:
   human_intervention_alert_url: https://open.feishu.cn/open-apis/bot/v2/hook/379fcd1a-0fed-4e58-8cd0-40b6d1895721
   max_reply_workers: 2
   max_push_workers: 1
+  chat_agent_version: 1
 
 debug_flags:
   disable_llm_api_call: True

+ 24 - 22
pqai_agent/configs/prod.yaml

@@ -1,40 +1,42 @@
+database:
+  ai_agent:
+    host: rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com
+    port: 3306
+    user: wqsd
+    password: wqsd@2025
+    database: ai_agent
+    charset: utf8mb4
+  growth:
+    host: rm-bp17q95335a99272b.mysql.rds.aliyuncs.com
+    port: 3306
+    user: crawler
+    password: crawler123456@
+    database: growth
+    charset: utf8mb4
+
 storage:
   history_dialogue:
     api_base_url: http://ai-wechat-hook-internal.piaoquantv.com/wechat/message/getConversation
   user:
-    mysql:
-      host: rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com
-      port: 3306
-      user: wqsd
-      password: wqsd@2025
-      database: ai_agent
-      charset: utf8mb4
+    database: ai_agent
     table: third_party_user
+  staff:
+    database: ai_agent
+    table: qywx_employee
   user_relation:
-    mysql:
-      host: rm-bp17q95335a99272b.mysql.rds.aliyuncs.com
-      port: 3306
-      user: crawler
-      password: crawler123456@
-      database: growth
-      charset: utf8mb4
+    database: growth
     table:
       staff: we_com_staff
       relation: we_com_staff_with_user
       user: we_com_user
-  staff:
-    table: qywx_employee
   agent_state:
-    mysql:
-      host: rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com
-      port: 3306
-      user: wqsd
-      password: wqsd@2025
-      database: ai_agent
+    database: ai_agent
     table: agent_state
   chat_history:
+    database: ai_agent
     table: qywx_chat_history
   push_record:
+    database: ai_agent
     table: agent_push_record_dev
 
 chat_api:

+ 30 - 0
pqai_agent/data_models/agent_configuration.py

@@ -0,0 +1,30 @@
+from enum import Enum
+
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
class AgentType(int, Enum):
    """How an agent is driven.

    REACTIVE agents respond to incoming events; PLANNING agents plan and act
    autonomously. The int mixin keeps values DB/JSON-friendly.
    """

    REACTIVE = 0  # reactive / event-driven
    PLANNING = 1  # autonomous planning
+
from sqlalchemy import text  # SQL-expression wrapper for server-side defaults

class AgentConfiguration(Base):
    """ORM model for the agent_configuration table (one row per agent preset)."""

    __tablename__ = "agent_configuration"

    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
    name = Column(String(64), nullable=False, comment="唯一名称")
    display_name = Column(String(64), nullable=True, comment="可选,显示名")
    type = Column(SmallInteger, nullable=False, default=0, comment="Agent类型,0-响应式,1-自主规划式")
    execution_model = Column(String(64), nullable=True, comment="执行LLM")
    system_prompt = Column(Text, nullable=True, comment="系统设定prompt模板")
    task_prompt = Column(Text, nullable=True, comment="执行任务prompt模板")
    tools = Column(Text, nullable=True, comment="JSON数组,tool name")
    sub_agents = Column(Text, nullable=True, comment="JSON数组,agent ID")
    extra_params = Column(Text, nullable=True, comment="JSON KV对象")
    is_delete = Column(Boolean, nullable=False, default=False, comment="逻辑删除标识")
    create_user = Column(String(32), nullable=True, comment="创建用户")
    update_user = Column(String(32), nullable=True, comment="更新用户")
    # server_default must be a text() construct: a plain string is rendered as a
    # quoted literal in DDL, not as the CURRENT_TIMESTAMP function.
    create_time = Column(TIMESTAMP, nullable=True, server_default=text("CURRENT_TIMESTAMP"), comment="创建时间")
    # NOTE: onupdate="CURRENT_TIMESTAMP" (a plain string) was a *client-side*
    # scalar default — ORM updates would write the literal string into the
    # column. server_onupdate marks the value as database-generated instead;
    # the MySQL table itself carries the ON UPDATE clause.
    update_time = Column(TIMESTAMP, nullable=True, server_default=text("CURRENT_TIMESTAMP"),
                         server_onupdate=text("CURRENT_TIMESTAMP"), comment="更新时间")

+ 22 - 0
pqai_agent/data_models/service_module.py

@@ -0,0 +1,22 @@
+from enum import Enum
+
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
class ModuleAgentType(int, Enum):
    """Which kind of agent backs a service module.

    NATIVE agents are implemented in this codebase; COZE agents are hosted on
    the Coze platform. The int mixin keeps values DB/JSON-friendly.
    """

    NATIVE = 0  # native, in-process agent
    COZE = 1    # Coze-hosted agent
+
+
from sqlalchemy import text  # SQL-expression wrapper for server-side defaults

class ServiceModule(Base):
    """ORM model for the service_module table (module name -> default agent binding)."""

    __tablename__ = "service_module"
    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
    name = Column(String(64), nullable=False, comment="唯一名称")
    display_name = Column(String(64), nullable=True, comment="显示名")
    default_agent_type = Column(SmallInteger, nullable=True, comment="默认Agent类型,0-原生,1-Coze")
    default_agent_id = Column(BigInteger, nullable=True, comment="默认Agent ID")
    is_delete = Column(Boolean, nullable=False, default=False, comment="逻辑删除标识")
    # server_default must be a text() construct: a plain string is rendered as a
    # quoted literal in DDL. onupdate="CURRENT_TIMESTAMP" as a plain string was a
    # client-side scalar default (ORM updates would write the literal string);
    # server_onupdate marks the value as database-generated instead.
    create_time = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP"), comment="创建时间")
    update_time = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP"),
                         server_onupdate=text("CURRENT_TIMESTAMP"), comment="更新时间")

+ 14 - 5
pqai_agent/dialogue_manager.py

@@ -14,6 +14,7 @@ import cozepy
 from sqlalchemy.orm import sessionmaker, Session
 
 from pqai_agent import configs
+from pqai_agent.clients.relation_stage_client import RelationStageClient
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
 from pqai_agent.logging_service import logger
 from pqai_agent.database import MySQLManager
@@ -74,7 +75,7 @@ class DialogueStateChange:
 class DialogueStateCache:
     def __init__(self):
         self.config = configs.get()
-        self.db = MySQLManager(self.config['storage']['agent_state']['mysql'])
+        self.db = MySQLManager(self.config['database']['ai_agent'])
         self.table = self.config['storage']['agent_state']['table']
 
     def get_state(self, staff_id: str, user_id: str) -> Tuple[DialogueState, DialogueState]:
@@ -102,7 +103,7 @@ class DialogueStateCache:
 
 class DialogueManager:
     def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache,
-                 AgentDBSession: sessionmaker[Session]):
+                 agent_db_session_maker: sessionmaker[Session]):
         config = configs.get()
 
         self.staff_id = staff_id
@@ -125,7 +126,10 @@ class DialogueManager:
         self.history_dialogue_service = HistoryDialogueService(
             config['storage']['history_dialogue']['api_base_url']
         )
-        self.AgentDBSession = AgentDBSession
+        # FIXME: 实际为无状态接口,不需要每个DialogueManager持有一个单独实例
+        self.relation_stage_client = RelationStageClient()
+        self.relation_stage = self.relation_stage_client.get_relation_stage(staff_id, user_id)
+        self.agent_db_session_maker = agent_db_session_maker
         self._recover_state()
         # 由于本地状态管理过于复杂,引入事务机制做状态回滚
         self._uncommited_state_change = []
@@ -155,6 +159,10 @@ class DialogueManager:
 
     def refresh_profile(self):
         self.staff_profile = self.user_manager.get_staff_profile(self.staff_id)
+        relation_stage = self.relation_stage_client.get_relation_stage(self.staff_id, self.user_id)
+        if relation_stage and relation_stage != self.relation_stage:
+            logger.info(f"staff[{self.staff_id}], user[{self.user_id}]: relation stage changed from {self.relation_stage} to {relation_stage}")
+            self.relation_stage = relation_stage
 
     def _recover_state(self):
         self.current_state, self.previous_state = self.state_cache.get_state(self.staff_id, self.user_id)
@@ -174,7 +182,7 @@ class DialogueManager:
         else:
             # 默认设置
             self.last_interaction_time_ms = int(time.time() * 1000) - minutes_to_get * 60 * 1000
-        with self.AgentDBSession() as session:
+        with self.agent_db_session_maker() as session:
             # 读取数据库中的最后一次交互时间
             query = session.query(AgentPushRecord).filter(
                 AgentPushRecord.staff_id == self.staff_id,
@@ -530,7 +538,6 @@ class DialogueManager:
             return True
         return False
 
-
     def is_in_human_intervention(self) -> bool:
         """检查是否处于人工介入状态"""
         return self.current_state == DialogueState.HUMAN_INTERVENTION
@@ -559,7 +566,9 @@ class DialogueManager:
             "last_interaction_interval": self._get_hours_since_last_interaction(2),
             "if_first_interaction": True if self.previous_state == DialogueState.INITIALIZED else False,
             "if_active_greeting": False if user_message else True,
+            "relation_stage": self.relation_stage,
             "formatted_staff_profile": prompt_utils.format_agent_profile(self.staff_profile),
+            "formatted_user_profile": prompt_utils.format_user_profile(self.user_profile),
             **self.user_profile,
             **legacy_staff_profile
         }

+ 2 - 2
pqai_agent/history_dialogue_service.py

@@ -89,6 +89,6 @@ if __name__ == '__main__':
     service = HistoryDialogueService(api_url)
     resp = service.get_dialogue_history(staff_id='1688857241615085', user_id='7881299616070168', recent_minutes=5*1440)
     print(resp)
-    user_db_config = configs.get()['storage']['user']['mysql']
-    db = HistoryDialogueDatabase(user_db_config)
+    agent_db_config = configs.get()['database']['ai_agent']
+    db = HistoryDialogueDatabase(agent_db_config)
     # print(db.get_dialogue_history_backward('1688854492669990', '7881301263964433', 1747397155000))

+ 55 - 0
pqai_agent/prompt_templates.py

@@ -211,6 +211,61 @@ USER_PROFILE_EXTRACT_PROMPT = """
 请使用update_user_profile函数返回需要更新的信息,注意不要返回不需要更新的信息!
 """
 
# V2 user-profile extraction prompt. Unlike V1 (which returns updates through the
# update_user_profile function-call), V2 expects the model to output a bare JSON
# object of *changed* fields only ({{}} escapes survive .format()). Template
# placeholders: {formatted_user_profile}, {dialogue_history}.
USER_PROFILE_EXTRACT_PROMPT_V2 = """
请在已有的用户画像的基础上,仔细分析以下用户和客服的对话内容,完善用户的画像信息。

# 对话历史格式
[用户][2025-05-29 22:06:14][文本] 内容...
[客服][2025-05-29 22:06:20][文本] 内容...
[用户][2025-05-29 22:06:33][文本] 内容...
## 特别说明
* 对话历史已通过[用户]/[客服]标签严格区分发言角色,除开头的角色标签外,其它均为对话的内容!
* 消息开头可能出现"丽丽:"等冒号分隔结构,是对另一方的称呼,不是要将其视为对话发起人的身份标识!

# 特征key定义及含义
- name: 姓名
- preferred_nickname: 用户希望对其的称呼
- gender: 性别
- age: 年龄
- region: 地区。用户常驻的地区,不是用户临时所在地
- health_conditions: 健康状况
- interests: 兴趣爱好
- interaction_frequency: 联系频率。每2天联系小于1次为low,每天联系1次为medium,未来均不再联系为stopped
- flexible_params: 动态特征

# 当前已提取信息(可能为空或有错误)
{formatted_user_profile}

# 对话历史
{dialogue_history}

# 任务
在微信场景中,要与用户保持紧密沟通并提升互动质量,从历史沟通内容中系统性地提取极高置信度的用户信息

# 要求
* 尽可能准确地识别用户的年龄、兴趣爱好、健康状况
* 关注用户生活、家庭等隐性信息
* 信息提取一定要有很高的准确性!如果无法确定具体信息,一定不要猜测!一定注意是用户自己的情况,而不是用户谈到的其它人的情况!
* 用户消息中出现的任何名称都视为对客服或第三方的称呼!除非用户明确使用类似"我叫""本名是"等自述句式,否则永远不要提取为姓名!
* 一定不要混淆用户和客服分别说的话!客服说的话只用于提供上下文,帮助理解对话语境!所有信息必须以用户说的为准!
* preferred_nickname提取需满足:用户明确使用"请叫我X"/"叫我X"/"称呼我X"等指令句式。排除用户对其他人的称呼。
* 一定不要把用户对客服的称呼当作preferred_nickname!一定不要把用户对客服的称呼当作preferred_nickname!
* 注意兴趣爱好的定义!兴趣爱好是为了乐趣或放松而进行的活动或消遣,必须是用户明确提到喜欢参与的活动,必须为动词或动名词。
* 兴趣爱好只保留最关键的5项。请合并相似的兴趣,不要保留多项相似的兴趣!注意兴趣爱好的定义!一定不要把用户短期的话题和需求当作兴趣爱好!
* 当前已提取的兴趣爱好并不一定准确,请判断当前兴趣爱好是否符合常理,如果不是一项活动或者根据对话历史判断它不是用户的兴趣爱好,请删除!
* 每个特征按照低/中/高区分,只保留高置信度特征
* 你需要自己提取对沟通有帮助的特征,放入flexible_params,key直接使用中文
* 除了flexible_params,其它key请严格遵循<特征key定义>中的要求,不要使用未定义的key!

以JSON对象格式返回**需要更新**的信息,不要返回无需更新的信息!!如果无需更新任何信息,请返回{{}},不要输出其它内容。示例输出:
{{
    "name": "张三",
    "flexible_params": {{
        "沟通特点": "使用四川方言"
    }}
}}
"""
+
 RESPONSE_TYPE_DETECT_PROMPT = """
 # 角色设定
 * 你是一位熟悉中老年用户交流习惯的智能客服,能够精准理解用户需求,提供专业、实用且有温度的建议。

+ 31 - 7
pqai_agent/push_service.py

@@ -12,11 +12,14 @@ import rocketmq
 from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression
 
 from pqai_agent import configs
+from pqai_agent.abtest.utils import get_abtest_info
 from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
 from pqai_agent.logging_service import logger
 from pqai_agent.mq_message import MessageType
+from pqai_agent.toolkit import get_tools
+from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
 
 
 class TaskType(Enum):
@@ -54,12 +57,17 @@ class PushScanThread:
         for staff_user in self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id):
             staff_id = staff_user['staff_id']
             user_id = staff_user['user_id']
-            agent = self.service.get_agent_instance(staff_id, user_id)
-            should_initiate = agent.should_initiate_conversation()
+            # 通过AB实验配置控制用户组是否启用push
+            # abtest_params = get_abtest_info(user_id).params
+            # if abtest_params.get('agent_push_enabled', 'false').lower() != 'true':
+            #     logger.debug(f"User {user_id} not enabled agent push, skipping.")
+            #     continue
             user_tags = self.service.user_relation_manager.get_user_tags(user_id)
-
             if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
                 should_initiate = False
+            else:
+                agent = self.service.get_agent_instance(staff_id, user_id)
+                should_initiate = agent.should_initiate_conversation()
             if should_initiate:
                 logger.info(f"user[{user_id}], tags{user_tags}: generate a generation task for conversation initiation")
                 rmq_msg = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.GENERATE)
@@ -168,7 +176,7 @@ class PushTaskWorkerPool:
                     if response:
                         item["type"] = message_type
                         messages_to_send.append(item)
-            with self.agent_service.AgentDBSession() as session:
+            with self.agent_service.agent_db_session_maker() as session:
                 msg_list = [{"type": msg["type"].value, "content": msg["content"]} for msg in messages_to_send]
                 record = AgentPushRecord(staff_id=staff_id, user_id=user_id,
                                          content=json.dumps(msg_list, ensure_ascii=False),
@@ -192,12 +200,28 @@ class PushTaskWorkerPool:
             staff_id = task['staff_id']
             user_id = task['user_id']
             main_agent = self.agent_service.get_agent_instance(staff_id, user_id)
-            push_agent = MessagePushAgent()
+            agent_config = get_agent_abtest_config('push', user_id,
+                                                   self.agent_service.service_module_manager,
+                                                   self.agent_service.agent_config_manager)
+            if agent_config:
+                try:
+                    tool_names = json.loads(agent_config.tools)
+                except json.JSONDecodeError:
+                    logger.error(f"Invalid JSON in agent tools: {agent_config.tools}")
+                    tool_names = []
+                push_agent = MessagePushAgent(model=agent_config.execution_model,
+                                              system_prompt=agent_config.system_prompt,
+                                              tools=get_tools(tool_names))
+                query_prompt_template = agent_config.task_prompt
+            else:
+                push_agent = MessagePushAgent()
+                query_prompt_template = None
             message_to_user = push_agent.generate_message(
                 context=main_agent.get_prompt_context(None),
                 dialogue_history=self.agent_service.history_dialogue_db.get_dialogue_history_backward(
-                    staff_id, user_id, main_agent.last_interaction_time_ms, limit=100
-                )
+                    staff_id, user_id, main_agent.last_interaction_time_ms, limit=30
+                ),
+                query_prompt_template=query_prompt_template
             )
             if message_to_user:
                 rmq_message = generate_task_rmq_message(

+ 3 - 2
pqai_agent/response_type_detector.py

@@ -38,7 +38,8 @@ class ResponseTypeDetector:
         )
         self.model_name = chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
 
-    def detect_type(self, dialogue_history: List[Dict], next_message: Dict, enable_random=False):
+    def detect_type(self, dialogue_history: List[Dict], next_message: Dict, enable_random=False,
+                    random_rate=0.25):
         if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
             return MessageType.TEXT
         composed_dialogue = self.compose_dialogue(dialogue_history)
@@ -62,7 +63,7 @@ class ResponseTypeDetector:
             suitable_for_voice = self.if_message_suitable_for_voice(next_message_content)
             logger.debug(f"voice suitable[{suitable_for_voice}], message: {next_message_content}")
             if suitable_for_voice:
-                if random.random() < 0.6:
+                if random.random() < random_rate:
                     logger.info(f"enable voice response randomly for message: {next_message_content}")
                     return MessageType.VOICE
         return MessageType.TEXT

+ 27 - 0
pqai_agent/service_module_manager.py

@@ -0,0 +1,27 @@
+from pqai_agent.data_models.service_module import ServiceModule, ModuleAgentType
+from pqai_agent.logging_service import logger
+
+class ServiceModuleManager:
+    def __init__(self, session_maker):
+        self.session_maker = session_maker
+        self.module_configs = {}
+        self.refresh_configs()
+
+    def refresh_configs(self):
+        try:
+            with self.session_maker() as session:
+                data = session.query(ServiceModule).filter_by(is_delete=False).all()
+                module_configs = {}
+                for module in data:
+                    module_configs[module.name] = {
+                        'display_name': module.display_name,
+                        'default_agent_type': ModuleAgentType(module.default_agent_type),
+                        'default_agent_id': module.default_agent_id
+                    }
+                self.module_configs = module_configs
+                logger.debug(f"Refreshed module configurations: {module_configs}")
+        except Exception as e:
+            logger.error(f"Error refreshing module configs: {e}")
+
+    def get_module_config(self, module_name: str):
+        return self.module_configs.get(module_name)

+ 42 - 0
pqai_agent/toolkit/__init__.py

@@ -0,0 +1,42 @@
+# 必须要在这里导入模块,以便对应的模块执行register_toolkit
+from typing import Sequence, List
+
+from pqai_agent.logging_service import logger
+from pqai_agent.toolkit.tool_registry import ToolRegistry
+from pqai_agent.toolkit.image_describer import ImageDescriber
+from pqai_agent.toolkit.message_notifier import MessageNotifier
+from pqai_agent.toolkit.pq_video_searcher import PQVideoSearcher
+from pqai_agent.toolkit.search_toolkit import SearchToolkit
+
+global_tool_map = ToolRegistry.tool_map
+
+def get_tool(tool_name: str) -> 'FunctionTool':
+    """
+    Retrieve a tool by its name from the global tool map.
+
+    Args:
+        tool_name (str): The name of the tool to retrieve.
+
+    Returns:
+        FunctionTool: The tool instance if found, otherwise None.
+    """
+    return global_tool_map.get(tool_name, None)
+
+def get_tools(tool_names: Sequence[str]) -> List['FunctionTool']:
+    """
+    Retrieve multiple tools by their names from the global tool map.
+
+    Args:
+        tool_names (Sequence[str]): A sequence of tool names to retrieve.
+
+    Returns:
+        Sequence[FunctionTool]: A sequence of tool instances corresponding to the provided names.
+    """
+    ret = []
+    for name in tool_names:
+        tool = get_tool(name)
+        if tool is not None:
+            ret.append(tool)
+        else:
+            logger.warning(f"Tool '{name}' not found in the global tool map.")
+    return ret

+ 61 - 0
pqai_agent/toolkit/coze_function_tools.py

@@ -0,0 +1,61 @@
+import types
+from typing import List, Dict
+import textwrap
+from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.chat_service import Coze, TokenAuth, Message, CozeChat
+
+
+class FunctionToolFactory:
+    @staticmethod
+    def create_tool(bot_id: str, coze_client: CozeChat, func_name: str, func_desc: str) -> FunctionTool:
+        """
+        Create a FunctionTool for a specific Coze Bot.
+
+        Args:
+            bot_id (str): The ID of the Coze Bot.
+            coze_client (CozeChat): The Coze client instance to interact with the bot.
+            func_name (str): The name of the function to be used in the FunctionTool.
+            func_desc (str): A description of the function to be used in the FunctionTool.
+        Returns:
+            FunctionTool: A FunctionTool wrapping the Coze Bot interaction.
+        """
+
+        func_doc = f"""
+            {func_desc}
+
+            Args:
+                messages (List[Dict]): A list of messages to send to the bot.
+                custom_variables (Dict): Custom variables for the bot.
+            Returns:
+                str: The final response from the bot.
+            """
+        func_doc = textwrap.dedent(func_doc).strip()
+
+        def coze_func(messages: List[Dict], custom_variables: Dict = None) -> str:
+            # ATTENTION:
+            # custom_variables (Dict): Custom variables for the bot. THIS IS A TRICK.
+            # THIS PARAMETER SHOULD NOT BE VISIBLE TO THE AGENT AND FILLED BY THE SYSTEM.
+
+            # FIXME: Coze bot can return multimodal content.
+
+            response = coze_client.create(
+                bot_id=bot_id,
+                user_id='agent_tool_call',
+                messages=[Message.build_user_question_text(msg["text"]) for msg in messages],
+                custom_variables=custom_variables
+            )
+            if not response:
+                return 'Error in calling the function.'
+            return response
+
+        dynamic_function = types.FunctionType(
+            coze_func.__code__,
+            globals(),
+            name=func_name,
+            argdefs=coze_func.__defaults__,
+            closure=coze_func.__closure__
+        )
+        dynamic_function.__doc__ = func_doc
+
+        # Wrap the function in a FunctionTool
+        return FunctionTool(dynamic_function)

+ 2 - 0
pqai_agent/toolkit/image_describer.py

@@ -6,11 +6,13 @@ from pqai_agent.chat_service import VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO
 from pqai_agent.logging_service import logger
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
 # 不同实例间复用cache,但不是很好的实践
 _image_describer_caches = {}
 _cache_mutex = threading.Lock()
 
+@register_toolkit
 class ImageDescriber(BaseToolkit):
     def __init__(self, cache_dir: str = None):
         self.model = VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO

+ 2 - 0
pqai_agent/toolkit/message_notifier.py

@@ -3,8 +3,10 @@ from typing import List, Dict
 from pqai_agent.logging_service import logger
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
 
+@register_toolkit
 class MessageNotifier(BaseToolkit):
     def __init__(self):
         super().__init__()

+ 3 - 0
pqai_agent/toolkit/pq_video_searcher.py

@@ -3,7 +3,10 @@ import requests
 
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
+
+@register_toolkit
 class PQVideoSearcher(BaseToolkit):
     API_URL = "https://vlogapi.piaoquantv.com/longvideoapi/search/userandvideo/list"
     def search_pq_video(self, keywords: List[str]) -> List[Dict]:

+ 2 - 0
pqai_agent/toolkit/search_toolkit.py

@@ -4,8 +4,10 @@ import requests
 
 from pqai_agent.toolkit.base import BaseToolkit
 from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.tool_registry import register_toolkit
 
 
+@register_toolkit
 class SearchToolkit(BaseToolkit):
     r"""A class representing a toolkit for web search.
     """

+ 27 - 0
pqai_agent/toolkit/tool_registry.py

@@ -0,0 +1,27 @@
+from typing import Type, Dict
+from pqai_agent.toolkit.function_tool import FunctionTool
+
+class ToolRegistry:
+    tool_map: Dict[str, FunctionTool] = {}
+
+    @classmethod
+    def register_tools(cls, toolkit_class: Type):
+        """
+        Register tools from a toolkit class into the global tool_map.
+
+        Args:
+            toolkit_class (Type): A class that implements a `get_tools` method.
+        """
+        instance = toolkit_class()
+        if not hasattr(instance, 'get_tools') or not callable(instance.get_tools):
+            raise ValueError(f"{toolkit_class.__name__} must implement a callable `get_tools` method.")
+
+        tools = instance.get_tools()
+        for tool in tools:
+            if not hasattr(tool, 'name'):
+                raise ValueError(f"Tool {tool} must have a `name` attribute.")
+            cls.tool_map[tool.name] = tool
+
+def register_toolkit(cls):
+    ToolRegistry.register_tools(cls)
+    return cls

+ 0 - 25
pqai_agent/user_manager.py

@@ -55,8 +55,6 @@ class UserManager(abc.ABC):
             },
             "interaction_style": "standard",  # standard, verbose, concise
             "interaction_frequency": "medium",  # low, medium, high
-            "last_topics": [],
-            "created_at": int(time.time() * 1000),
             "human_intervention_history": []
         }
         for key, value in kwargs.items():
@@ -294,7 +292,6 @@ class MySQLUserManager(UserManager):
             "data": staff_list
         }
 
-
 class LocalUserRelationManager(UserRelationManager):
     def __init__(self):
         pass
@@ -436,25 +433,3 @@ class MySQLUserRelationManager(UserRelationManager):
         except Exception as e:
             logger.error(f"stop_user_daily_push failed: {e}")
             return False
-
-
-if __name__ == '__main__':
-    config = configs.get()
-    user_db_config = config['storage']['user']
-    staff_db_config = config['storage']['staff']
-    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
-    user_profile = user_manager.get_user_profile('7881301263964433')
-    print(user_profile)
-
-    wecom_db_config = config['storage']['user_relation']
-    user_relation_manager = MySQLUserRelationManager(
-        user_db_config['mysql'], wecom_db_config['mysql'],
-        config['storage']['staff']['table'],
-        user_db_config['table'],
-        wecom_db_config['table']['staff'],
-        wecom_db_config['table']['relation'],
-        wecom_db_config['table']['user']
-    )
-    # all_staff_users = user_relation_manager.list_staff_users()
-    user_tags = user_relation_manager.get_user_tags('7881302078008656')
-    print(user_tags)

+ 85 - 21
pqai_agent/user_profile_extractor.py

@@ -5,20 +5,36 @@
 import json
 from typing import Dict, Optional, List
 
-from pqai_agent import chat_service
-from pqai_agent import configs
-from pqai_agent.prompt_templates import USER_PROFILE_EXTRACT_PROMPT
+from pqai_agent import chat_service, configs
+from pqai_agent.prompt_templates import USER_PROFILE_EXTRACT_PROMPT, USER_PROFILE_EXTRACT_PROMPT_V2
 from openai import OpenAI
 from pqai_agent.logging_service import logger
+from pqai_agent.utils import prompt_utils
 
 
 class UserProfileExtractor:
-    def __init__(self):
-        self.llm_client = OpenAI(
-            api_key=chat_service.VOLCENGINE_API_TOKEN,
-            base_url=chat_service.VOLCENGINE_BASE_URL
-        )
-        self.model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
+    FIELDS = [
+        "name",
+        "preferred_nickname",
+        "gender",
+        "age",
+        "region",
+        "interests",
+        "health_conditions",
+        "interaction_frequency",
+        "flexible_params"
+    ]
+    def __init__(self, model_name=None, llm_client=None):
+        if not llm_client:
+            self.llm_client = OpenAI(
+                api_key=chat_service.VOLCENGINE_API_TOKEN,
+                base_url=chat_service.VOLCENGINE_BASE_URL
+            )
+        else:
+            self.llm_client = llm_client
+        if not model_name:
+            model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
+        self.model_name = model_name
 
     @staticmethod
     def get_extraction_function() -> Dict:
@@ -73,13 +89,14 @@ class UserProfileExtractor:
             }
         }
 
-    def generate_extraction_prompt(self, user_profile: Dict, dialogue_history: List[Dict]) -> str:
+    def generate_extraction_prompt(self, user_profile: Dict, dialogue_history: List[Dict], prompt_template = USER_PROFILE_EXTRACT_PROMPT) -> str:
         """
         生成用于信息提取的系统提示词
         """
         context = user_profile.copy()
         context['dialogue_history'] = self.compose_dialogue(dialogue_history)
-        return USER_PROFILE_EXTRACT_PROMPT.format(**context)
+        context['formatted_user_profile'] = prompt_utils.format_user_profile(user_profile)
+        return prompt_template.format(**context)
 
     @staticmethod
     def compose_dialogue(dialogue: List[Dict]) -> str:
@@ -130,15 +147,61 @@ class UserProfileExtractor:
             logger.error(f"用户画像提取出错: {e}")
             return None
 
+    def extract_profile_info_v2(self, user_profile: Dict, dialogue_history: List[Dict], prompt_template: Optional[str] = None) -> Optional[Dict]:
+        """
+        使用JSON输出提取用户画像信息
+        :param user_profile:
+        :param dialogue_history:
+        :param prompt_template: 可选的自定义提示模板
+        :return:
+        """
+        if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
+            logger.debug("skip LLM API call.")
+            return None
+
+        try:
+            logger.debug("try to extract profile from message: {}".format(dialogue_history))
+            prompt_template = prompt_template or USER_PROFILE_EXTRACT_PROMPT_V2
+            prompt = self.generate_extraction_prompt(user_profile, dialogue_history, prompt_template)
+            response = self.llm_client.chat.completions.create(
+                model=self.model_name,
+                messages=[
+                    {"role": "system", "content": '你是一个专业的用户画像分析助手。'},
+                    {"role": "user", "content": prompt}
+                ],
+                temperature=0
+            )
+            json_data = response.choices[0].message.content \
+                .replace("```json", "").replace("```", "").strip()
+            try:
+                profile_info = json.loads(json_data)
+            except json.JSONDecodeError as e:
+                logger.error(f"Error in JSON decode: {e}, original input: {json_data}")
+                return None
+            return profile_info
+
+        except Exception as e:
+            logger.error(f"用户画像提取出错: {e}")
+            return None
+
     def merge_profile_info(self, existing_profile: Dict, new_info: Dict) -> Dict:
         """
         合并新提取的用户信息到现有资料
         """
         merged_profile = existing_profile.copy()
-        merged_profile.update(new_info)
+        for field in new_info:
+            if field in self.FIELDS:
+                merged_profile[field] = new_info[field]
+            else:
+                logger.warning(f"Unknown field in new profile: {field}")
         return merged_profile
 
 if __name__ == '__main__':
+    from pqai_agent import configs
+    from pqai_agent import logging_service
+    logging_service.setup_root_logger()
+    config = configs.get()
+    config['debug_flags']['disable_llm_api_call'] = False
     extractor = UserProfileExtractor()
     current_profile = {
         'name': '',
@@ -152,11 +215,11 @@ if __name__ == '__main__':
         'interaction_frequency': 'medium'
     }
     messages= [
-        {'role': 'user', 'content': "没有任何问题放心,不会骚扰你了,再见"}
+        {'role': 'user', 'content': "没有任何问题放心,以后不要再发了,再见"}
     ]
 
-    resp = extractor.extract_profile_info(current_profile, messages)
-    print(resp)
+    # resp = extractor.extract_profile_info_v2(current_profile, messages)
+    # logger.warning(resp)
     message = "好的,孩子,我是老李头,今年68啦,住在北京海淀区。平时喜欢在微信上跟老伙伴们聊聊养生、下下象棋,偶尔也跟年轻人学学新鲜事儿。\n" \
               "你叫我李叔就行,有啥事儿咱们慢慢聊啊\n" \
               "哎,今儿个天气不错啊,我刚才还去楼下小公园溜达了一圈儿。碰到几个老伙计在打太极,我也跟着比划了两下,这老胳膊老腿的,原来老不舒服,活动活动舒坦多了!\n" \
@@ -165,9 +228,10 @@ if __name__ == '__main__':
     messages = []
     for line in message.split("\n"):
         messages.append({'role': 'user', 'content': line})
-    resp = extractor.extract_profile_info(current_profile, messages)
-    print(resp)
-    print(extractor.merge_profile_info(current_profile, resp))
+    resp = extractor.extract_profile_info_v2(current_profile, messages)
+    logger.warning(resp)
+    merged_profile = extractor.merge_profile_info(current_profile, resp)
+    logger.warning(merged_profile)
     current_profile = {
         'name': '李老头',
         'preferred_nickname': '李叔',
@@ -179,6 +243,6 @@ if __name__ == '__main__':
         'interests': ['养生', '下象棋'],
         'interaction_frequency': 'medium'
     }
-    resp = extractor.extract_profile_info(current_profile, messages)
-    print(resp)
-    print(extractor.merge_profile_info(current_profile, resp))
+    resp = extractor.extract_profile_info_v2(merged_profile, messages)
+    logger.warning(resp)
+    logger.warning(extractor.merge_profile_info(current_profile, resp))

+ 18 - 0
pqai_agent/utils/agent_abtest_utils.py

@@ -0,0 +1,18 @@
+from typing import Optional
+
+from pqai_agent.abtest.utils import get_abtest_info
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.service_module_manager import ServiceModuleManager
+from pqai_agent.agent_config_manager import AgentConfigManager
+
+def get_agent_abtest_config(module_name: str, uid: str,
+                            service_module_manager: ServiceModuleManager,
+                            agent_config_manager: AgentConfigManager) -> Optional[AgentConfiguration]:
+    abtest_info = get_abtest_info(uid)
+    module_config = service_module_manager.get_module_config(f'{module_name}_module')
+    if not module_config:
+        return None
+    agent_id = module_config['default_agent_id']
+    param_key = f'module_{module_name}_agent_id'
+    if param_key in abtest_info.params:
+        agent_id = int(abtest_info.params[param_key])
+    agent_config = agent_config_manager.get_config(agent_id)
+    return agent_config

+ 10 - 2
pqai_agent/utils/db_utils.py

@@ -1,6 +1,6 @@
 from urllib.parse import quote_plus
-
 from sqlalchemy import create_engine
+from pqai_agent import configs
 
 def create_sql_engine(config):
     user = config['user']
@@ -9,4 +9,12 @@ def create_sql_engine(config):
     db_name = config['database']
     charset = config.get('charset', 'utf8mb4')
     engine = create_engine(f'mysql+pymysql://{user}:{passwd}@{host}/{db_name}?charset={charset}')
-    return engine
+    return engine
+
+def create_ai_agent_db_engine():
+    config = configs.get()['database']['ai_agent']
+    return create_sql_engine(config)
+
+def create_growth_db_engine():
+    config = configs.get()['database']['growth']
+    return create_sql_engine(config)

+ 11 - 6
pqai_agent/utils/prompt_utils.py

@@ -39,21 +39,26 @@ def format_user_profile(profile: Dict) -> str:
     """
     fields = [
         ('nickname', '微信昵称'),
+        ('preferred_nickname', '希望对其的称呼'),
         ('name', '姓名'),
         ('avatar', '头像'),
-        ('preferred_nickname', '偏好的称呼'),
+        ('gender', '性别'),
         ('age', '年龄'),
         ('region', '地区'),
         ('health_conditions', '健康状况'),
-        ('medications', '用药信息'),
-        ('interests', '兴趣爱好')
+        ('interests', '兴趣爱好'),
+        ('interaction_frequency', '联系频率'),
+        ('flexible_params', '动态特征'),
     ]
     strings_to_join = []
     for field in fields:
-        if not profile.get(field[0], None):
+        value = profile.get(field[0], None)
+        if not value:
             continue
-        if isinstance(profile[field[0]], list):
-            value = ','.join(profile[field[0]])
+        if isinstance(value, list):
+            value = ','.join(value)
+        elif isinstance(value, dict):
+            value = ';'.join(f"{k}: {v}" for k, v in value.items())
         else:
             value = profile[field[0]]
         cur_string = f"- {field[1]}{value}"

+ 5 - 3
pqai_agent_server/agent_server.py

@@ -49,6 +49,8 @@ if __name__ == "__main__":
 
     # 初始化用户管理服务
     # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
+    agent_db_config = config['database']['ai_agent']
+    growth_db_config = config['database']['growth']
     user_db_config = config['storage']['user']
     staff_db_config = config['storage']['staff']
     wecom_db_config = config['storage']['user_relation']
@@ -56,10 +58,10 @@ if __name__ == "__main__":
         user_manager = LocalUserManager()
         user_relation_manager = LocalUserRelationManager()
     else:
-        user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
+        user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
         user_relation_manager = MySQLUserRelationManager(
-            user_db_config['mysql'], wecom_db_config['mysql'],
-            config['storage']['staff']['table'],
+            agent_db_config, growth_db_config,
+            staff_db_config['table'],
             user_db_config['table'],
             wecom_db_config['table']['staff'],
             wecom_db_config['table']['relation'],

+ 220 - 4
pqai_agent_server/api_server.py

@@ -7,12 +7,17 @@ import werkzeug.exceptions
 from flask import Flask, request, jsonify
 from argparse import ArgumentParser
 
+from sqlalchemy.orm import sessionmaker
+
 from pqai_agent import configs
 
 from pqai_agent import logging_service, chat_service, prompt_templates
 from pqai_agent.agents.message_reply_agent import MessageReplyAgent
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.data_models.service_module import ServiceModule
 from pqai_agent.history_dialogue_service import HistoryDialogueService
 from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
+from pqai_agent.utils.db_utils import create_ai_agent_db_engine
 from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
 from pqai_agent_server.const import AgentApiConst
 from pqai_agent_server.models import MySQLSessionManager
@@ -134,7 +139,7 @@ def get_base_prompt():
     prompt_map = {
         'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
         'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
-        'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
+        'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT_V2,
         'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
         'custom_debugging': '',
     }
@@ -307,6 +312,213 @@ def quit_human_interventions_status():
 
     return wrap_response(200, data=response)
 
+## Agent管理接口
+@app.route("/api/getNativeAgentList", methods=["GET"])
+def get_native_agent_list():
+    """
+    获取所有的Agent列表
+    :return:
+    """
+    page = request.args.get('page', 1)
+    page_size = request.args.get('page_size', 50)
+    create_user = request.args.get('create_user', None)
+    update_user = request.args.get('update_user', None)
+
+    offset = (int(page) - 1) * int(page_size)
+    with app.session_maker() as session:
+        query = session.query(AgentConfiguration) \
+            .filter(AgentConfiguration.is_delete == 0)
+        if create_user:
+            query = query.filter(AgentConfiguration.create_user == create_user)
+        if update_user:
+            query = query.filter(AgentConfiguration.update_user == update_user)
+        query = query.offset(offset).limit(int(page_size))
+        data = query.all()
+    ret_data = [
+        {
+            'id': agent.id,
+            'name': agent.name,
+            'display_name': agent.display_name,
+            'type': agent.type,
+            'execution_model': agent.execution_model,
+            'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+            'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
+        }
+        for agent in data
+    ]
+    return wrap_response(200, data=ret_data)
+
+@app.route("/api/getNativeAgentConfiguration", methods=["GET"])
+def get_native_agent_configuration():
+    """
+    获取指定Agent的配置
+    :return:
+    """
+    agent_id = request.args.get('agent_id')
+    if not agent_id:
+        return wrap_response(404, msg='agent_id is required')
+
+    with app.session_maker() as session:
+        agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
+        if not agent:
+            return wrap_response(404, msg='Agent not found')
+
+        data = {
+            'id': agent.id,
+            'name': agent.name,
+            'display_name': agent.display_name,
+            'type': agent.type,
+            'execution_model': agent.execution_model,
+            'system_prompt': agent.system_prompt,
+            'task_prompt': agent.task_prompt,
+            'tools': agent.tools,
+            'sub_agents': agent.sub_agents,
+            'extra_params': agent.extra_params,
+            'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+            'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
+        }
+        return wrap_response(200, data=data)
+
+@app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
+def save_native_agent_configuration():
+    """
+    保存Agent配置
+    :return:
+    """
+    req_data = request.json
+    agent_id = req_data.get('agent_id', None)
+    name = req_data.get('name')
+    display_name = req_data.get('display_name', None)
+    type_ = req_data.get('type', 0)
+    execution_model = req_data.get('execution_model', None)
+    system_prompt = req_data.get('system_prompt', None)
+    task_prompt = req_data.get('task_prompt', None)
+    tools = req_data.get('tools', [])
+    sub_agents = req_data.get('sub_agents', [])
+    extra_params = req_data.get('extra_params', {})
+
+    if not name:
+        return wrap_response(400, msg='name is required')
+
+    with app.session_maker() as session:
+        if agent_id:
+            agent_id = int(agent_id)
+            agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
+            if not agent:
+                return wrap_response(404, msg='Agent not found')
+            agent.name = name
+            agent.display_name = display_name
+            agent.type = type_
+            agent.execution_model = execution_model
+            agent.system_prompt = system_prompt
+            agent.task_prompt = task_prompt
+            agent.tools = tools  # NOTE(review): read path does json.loads(agent.tools) — confirm this is serialized to a JSON string, not stored as a Python list repr
+            agent.sub_agents = sub_agents
+            agent.extra_params = extra_params
+        else:
+            agent = AgentConfiguration(
+                name=name,
+                display_name=display_name,
+                type=type_,
+                execution_model=execution_model,
+                system_prompt=system_prompt,
+                task_prompt=task_prompt,
+                tools=tools,
+                sub_agents=sub_agents,
+                extra_params=extra_params
+            )
+            session.add(agent)
+
+        session.commit()
+        return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
+
+@app.route("/api/getModuleList", methods=["GET"])
+def get_module_list():
+    """
+    获取所有的模块列表
+    :return:
+    """
+    with app.session_maker() as session:
+        query = session.query(ServiceModule) \
+            .filter(ServiceModule.is_delete == 0)
+        data = query.all()
+    ret_data = [
+        {
+            'id': module.id,
+            'name': module.name,
+            'display_name': module.display_name,
+            'default_agent_type': module.default_agent_type,
+            'default_agent_id': module.default_agent_id,
+            'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+            'update_time': module.update_time.strftime('%Y-%m-%d %H:%M:%S')
+        }
+        for module in data
+    ]
+    return wrap_response(200, data=ret_data)
+
+@app.route("/api/getModuleConfiguration", methods=["GET"])
+def get_module_configuration():
+    """
+    获取指定模块的配置
+    :return:
+    """
+    module_id = request.args.get('module_id')
+    if not module_id:
+        return wrap_response(404, msg='module_id is required')
+
+    with app.session_maker() as session:
+        module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
+        if not module:
+            return wrap_response(404, msg='Module not found')
+
+        data = {
+            'id': module.id,
+            'name': module.name,
+            'display_name': module.display_name,
+            'default_agent_type': module.default_agent_type,
+            'default_agent_id': module.default_agent_id,
+            'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+            'update_time': module.update_time.strftime('%Y-%m-%d %H:%M:%S')
+        }
+        return wrap_response(200, data=data)
+
+@app.route("/api/saveModuleConfiguration", methods=["POST"])
+def save_module_configuration():
+    """
+    保存模块配置
+    :return:
+    """
+    req_data = request.json
+    module_id = req_data.get('module_id', None)
+    name = req_data.get('name')
+    display_name = req_data.get('display_name', None)
+    default_agent_type = req_data.get('default_agent_type', 0)
+    default_agent_id = req_data.get('default_agent_id', None)
+
+    if not name:
+        return wrap_response(400, msg='name is required')
+
+    with app.session_maker() as session:
+        if module_id:
+            module_id = int(module_id)
+            module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
+            if not module:
+                return wrap_response(404, msg='Module not found')
+            module.name = name
+            module.display_name = display_name
+            module.default_agent_type = default_agent_type
+            module.default_agent_id = default_agent_id
+        else:
+            module = ServiceModule(
+                name=name,
+                display_name=display_name,
+                default_agent_type=default_agent_type,
+                default_agent_id=default_agent_id
+            )
+            session.add(module)
+
+        session.commit()
+        return wrap_response(200, msg='Module configuration saved successfully', data={'id': module.id})
 
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 def handle_bad_request(e):
@@ -327,28 +539,32 @@ if __name__ == '__main__':
     logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
 
     # set db config
+    agent_db_config = config['database']['ai_agent']
+    growth_db_config = config['database']['growth']
     user_db_config = config['storage']['user']
     staff_db_config = config['storage']['staff']
     agent_state_db_config = config['storage']['agent_state']
     chat_history_db_config = config['storage']['chat_history']
 
     # init user manager
-    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
+    user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
     app.user_manager = user_manager
 
     # init session manager
     session_manager = MySQLSessionManager(
-        db_config=user_db_config['mysql'],
+        db_config=agent_db_config,
         staff_table=staff_db_config['table'],
         user_table=user_db_config['table'],
         agent_state_table=agent_state_db_config['table'],
         chat_history_table=chat_history_db_config['table']
     )
     app.session_manager = session_manager
+    agent_db_engine = create_ai_agent_db_engine()
+    app.session_maker = sessionmaker(bind=agent_db_engine)
 
     wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(
-        user_db_config['mysql'], wecom_db_config['mysql'],
+        agent_db_config, growth_db_config,
         config['storage']['staff']['table'],
         user_db_config['table'],
         wecom_db_config['table']['staff'],

+ 13 - 27
pqai_agent_server/utils/prompt_util.py

@@ -41,8 +41,7 @@ def compose_openai_chat_messages_no_time(dialogue_history, multimodal=False):
             messages.append({"role": role, "content": f'{entry["content"]}'})
     return messages
 
-
-def run_openai_chat(messages, model_name, **kwargs):
+def create_llm_client(model_name):
     volcengine_models = [
         chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
         chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
@@ -72,6 +71,11 @@ def run_openai_chat(messages, model_name, **kwargs):
         )
     else:
         raise Exception("model not supported")
+    return llm_client
+
+
+def run_openai_chat(messages, model_name, **kwargs):
+    llm_client = create_llm_client(model_name)
     response = llm_client.chat.completions.create(
         messages=messages, model=model_name, **kwargs
     )
@@ -79,36 +83,18 @@ def run_openai_chat(messages, model_name, **kwargs):
     return response
 
 
-def run_extractor_prompt(req_data):
+def run_extractor_prompt(req_data) -> Dict[str, str]:
     prompt = req_data["prompt"]
     user_profile = req_data["user_profile"]
-    staff_profile = req_data["staff_profile"]
     dialogue_history = req_data["dialogue_history"]
     model_name = req_data["model_name"]
-    prompt_context = {
-        "formatted_staff_profile": format_agent_profile(staff_profile),
-        **user_profile,
-        "dialogue_history": UserProfileExtractor.compose_dialogue(dialogue_history),
-    }
-    prompt = prompt.format(**prompt_context)
-    messages = [
-        {"role": "system", "content": "你是一个专业的用户画像分析助手。"},
-        {"role": "user", "content": prompt},
-    ]
-    tools = [UserProfileExtractor.get_extraction_function()]
-    response = run_openai_chat(messages, model_name, tools=tools, temperature=0)
-    tool_calls = response.choices[0].message.tool_calls
-    if tool_calls:
-        function_call = tool_calls[0]
-        if function_call.function.name == "update_user_profile":
-            profile_info = json.loads(function_call.function.arguments)
-            return {k: v for k, v in profile_info.items() if v}
-        else:
-            logger.error("llm does not return update_user_profile")
-            return {}
-    else:
+    llm_client = create_llm_client(model_name)
+    extractor = UserProfileExtractor(model_name=model_name, llm_client=llm_client)
+    profile_to_update = extractor.extract_profile_info_v2(user_profile, dialogue_history, prompt)
+    logger.info(profile_to_update)
+    if not profile_to_update:
         return {}
-
+    return profile_to_update
 
 def run_chat_prompt(req_data):
     prompt = req_data["prompt"]

+ 4 - 2
scripts/disable_user_daily_push.py

@@ -16,10 +16,12 @@ if __name__ == '__main__':
     config = configs.get()
     user_db_config = config['storage']['user']
     staff_db_config = config['storage']['staff']
-    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
+    agent_db_config = config['database']['ai_agent']
+    growth_db_config = config['database']['growth']
+    user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
     wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(
-        user_db_config['mysql'], wecom_db_config['mysql'],
+        agent_db_config, growth_db_config,
         config['storage']['staff']['table'],
         user_db_config['table'],
         wecom_db_config['table']['staff'],

+ 2 - 1
scripts/profile_cleaner.py

@@ -12,7 +12,8 @@ if __name__ == '__main__':
     config = configs.get()
     user_db_config = config['storage']['user']
     staff_db_config = config['storage']['staff']
-    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
+    agent_db_config = config['database']['ai_agent']
+    user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
 
     user_ids_to_clean = ['7881299986081786', '7881303544096524', '7881300732152777', '7881301752098239', '7881299457990953', '7881302872936170',]
     for user_id in user_ids_to_clean:

+ 4 - 3
scripts/resend_lost_message.py

@@ -13,11 +13,12 @@ from pqai_agent.user_manager import MySQLUserRelationManager
 config = configs.get()
 
 def main():
-    wecom_db_config = config['storage']['user_relation']
     user_db_config = config['storage']['user']
-
+    agent_db_config = config['database']['ai_agent']
+    growth_db_config = config['database']['growth']
+    wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(
-        user_db_config['mysql'], wecom_db_config['mysql'],
+        agent_db_config, growth_db_config,
         config['storage']['staff']['table'],
         user_db_config['table'],
         wecom_db_config['table']['staff'],