|
@@ -0,0 +1,281 @@
|
|
|
+# Python: experiment_client.py
|
|
|
+import threading
|
|
|
+from typing import List, Dict
|
|
|
+from alibabacloud_paiabtest20240119.client import Client
|
|
|
+from pqai_agent.abtest.models import Project, Domain, Layer, Experiment, ExperimentVersion, \
|
|
|
+ ExperimentContext, ExperimentResult
|
|
|
+from alibabacloud_paiabtest20240119.models import ListProjectsRequest, ListProjectsResponseBodyProjects, \
|
|
|
+ ListDomainsRequest, ListFeaturesRequest, ListLayersRequest, ListExperimentsRequest, ListExperimentVersionsRequest
|
|
|
+from pqai_agent.logging_service import logger
|
|
|
+
|
|
|
+class ExperimentClient:
|
|
|
+ def __init__(self, client: Client):
|
|
|
+ self.client = client
|
|
|
+ self.project_map = {}
|
|
|
+ self.running = False
|
|
|
+ self.worker_thread = None
|
|
|
+
|
|
|
+ def start(self):
|
|
|
+ self.running = True
|
|
|
+ self.worker_thread = threading.Thread(target=self._worker_loop)
|
|
|
+ self.worker_thread.start()
|
|
|
+
|
|
|
+ def shutdown(self, blocking=False):
|
|
|
+ self.running = False
|
|
|
+ if self.worker_thread:
|
|
|
+ if blocking:
|
|
|
+ self.worker_thread.join()
|
|
|
+ else:
|
|
|
+ self.worker_thread = None
|
|
|
+
|
|
|
+ def _worker_loop(self):
|
|
|
+ while self.running:
|
|
|
+ # Sleep or wait for a condition to avoid busy waiting
|
|
|
+ threading.Event().wait(60)
|
|
|
+ try:
|
|
|
+ self.load_experiment_data()
|
|
|
+ logger.debug("Experiment data loaded successfully.")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error loading experiment data: {e}")
|
|
|
+ logger.info("ExperimentClient worker thread exit.")
|
|
|
+
|
|
|
+ def load_experiment_data(self):
|
|
|
+ project_map = {}
|
|
|
+
|
|
|
+ # 获取所有项目
|
|
|
+ list_project_req = ListProjectsRequest()
|
|
|
+ list_project_req.all = True
|
|
|
+ projects_response = self.client.list_projects(list_project_req)
|
|
|
+ projects: List[ListProjectsResponseBodyProjects] = projects_response.body.projects
|
|
|
+
|
|
|
+ for project_data in projects:
|
|
|
+ project = Project(project_name=project_data.name, project_id=project_data.project_id)
|
|
|
+ logger.debug(f"[Project] {project_data}")
|
|
|
+
|
|
|
+ # 获取项目的域
|
|
|
+ list_domain_req = ListDomainsRequest()
|
|
|
+ list_domain_req.project_id = project.id
|
|
|
+ domains_response = self.client.list_domains(list_domain_req)
|
|
|
+
|
|
|
+ for domain_data in domains_response.body.domains:
|
|
|
+ domain = Domain(domain_id=domain_data.domain_id,
|
|
|
+ name=domain_data.name,
|
|
|
+ flow=domain_data.flow,
|
|
|
+ buckets=domain_data.buckets,
|
|
|
+ bucket_type=domain_data.bucket_type,
|
|
|
+ is_default_domain=domain_data.is_default_domain,
|
|
|
+ exp_layer_id=domain_data.layer_id,
|
|
|
+ debug_users=domain_data.debug_users)
|
|
|
+ logger.debug(f"[Domain] {domain_data}")
|
|
|
+ if domain.is_default_domain:
|
|
|
+ project.set_default_domain(domain)
|
|
|
+ domain.init()
|
|
|
+ project.add_domain(domain)
|
|
|
+
|
|
|
+ # 获取域的特性(暂无实际用处)
|
|
|
+ list_feature_req = ListFeaturesRequest()
|
|
|
+ list_feature_req.domain_id = domain.id
|
|
|
+ features_response = self.client.list_features(list_feature_req)
|
|
|
+ for feature_data in features_response.body.features:
|
|
|
+ domain.add_feature(feature_data)
|
|
|
+
|
|
|
+ # 获取域的层
|
|
|
+ list_layer_req = ListLayersRequest()
|
|
|
+ list_layer_req.domain_id = domain.id
|
|
|
+ layers_response = self.client.list_layers(list_layer_req)
|
|
|
+ for layer_data in layers_response.body.layers:
|
|
|
+ logger.debug(f'[Layer] {layer_data}')
|
|
|
+ layer = Layer(layer_id=layer_data.layer_id, name=layer_data.name)
|
|
|
+ project.add_layer(layer)
|
|
|
+
|
|
|
+ # 获取层的实验
|
|
|
+ list_experiment_req = ListExperimentsRequest()
|
|
|
+ list_experiment_req.layer_id = layer.id
|
|
|
+ # FIXME: magic code
|
|
|
+ list_experiment_req.status = 'Running'
|
|
|
+ experiments_response = self.client.list_experiments(list_experiment_req)
|
|
|
+
|
|
|
+ for experiment_data in experiments_response.body.experiments:
|
|
|
+ logger.debug(f'[Experiment] {experiment_data}')
|
|
|
+ # FIXME: Java SDK中有特殊处理
|
|
|
+ crowd_ids = experiment_data.crowd_ids if experiment_data.crowd_ids else ""
|
|
|
+ experiment = Experiment(id=int(experiment_data.experiment_id), bucket_type=experiment_data.bucket_type,
|
|
|
+ flow=experiment_data.flow, buckets=experiment_data.buckets,
|
|
|
+ crowd_ids=crowd_ids.split(","),
|
|
|
+ debug_users=experiment_data.debug_users,
|
|
|
+ filter_condition=experiment_data.condition
|
|
|
+ )
|
|
|
+ experiment.init()
|
|
|
+
|
|
|
+ # 获取实验的版本
|
|
|
+ list_exp_ver_req = ListExperimentVersionsRequest()
|
|
|
+ list_exp_ver_req.experiment_id = experiment.id
|
|
|
+ versions_response = self.client.list_experiment_versions(list_exp_ver_req)
|
|
|
+ print(versions_response)
|
|
|
+ for version_data in versions_response.body.experiment_versions:
|
|
|
+ logger.debug(f'[ExperimentVersion] {version_data}')
|
|
|
+ version = ExperimentVersion(exp_version_id=version_data.experiment_version_id,
|
|
|
+ flow=int(version_data.flow),
|
|
|
+ buckets=version_data.buckets,
|
|
|
+ debug_users=version_data.debug_users,
|
|
|
+ exp_version_name=version_data.name,
|
|
|
+ config=version_data.config)
|
|
|
+ version.init()
|
|
|
+ experiment.add_experiment_version(version)
|
|
|
+ layer.add_experiment(experiment)
|
|
|
+ domain.add_layer(layer)
|
|
|
+
|
|
|
+ # 建立layer-domain的反向映射,从而形成嵌套结构
|
|
|
+ for domain in project.domains:
|
|
|
+ if domain.is_default_domain:
|
|
|
+ continue
|
|
|
+ layer: Layer = project.layer_map.get(domain.exp_layer_id, None)
|
|
|
+ if not layer:
|
|
|
+ continue
|
|
|
+ layer.add_domain(domain)
|
|
|
+
|
|
|
+ project_map[project.project_name] = project
|
|
|
+
|
|
|
+ self.project_map = project_map
|
|
|
+
|
|
|
+ def match_experiment(self, project_name, experiment_context) -> ExperimentResult:
|
|
|
+ if project_name not in self.project_map:
|
|
|
+ return ExperimentResult(project_name=project_name, experiment_context=experiment_context)
|
|
|
+
|
|
|
+ project = self.project_map[project_name]
|
|
|
+ experiment_result = ExperimentResult(project=project, experiment_context=experiment_context)
|
|
|
+
|
|
|
+ self._match_domain(project.default_domain, experiment_result)
|
|
|
+ experiment_result.init()
|
|
|
+ return experiment_result
|
|
|
+
|
|
|
+ def _match_domain(self, domain: Domain, experiment_result: ExperimentResult):
|
|
|
+ if not domain:
|
|
|
+ return
|
|
|
+
|
|
|
+ for feature in domain.features:
|
|
|
+ if feature.match(experiment_result.experiment_context):
|
|
|
+ experiment_result.add_params(feature.params)
|
|
|
+
|
|
|
+ for layer in domain.layers:
|
|
|
+ self._match_layer(layer, experiment_result)
|
|
|
+
|
|
|
+ def _match_layer(self, layer, experiment_result):
|
|
|
+ if not layer:
|
|
|
+ return
|
|
|
+
|
|
|
+ for experiment in layer.experiments:
|
|
|
+ if experiment.match_debug_users(experiment_result.experiment_context):
|
|
|
+ logger.debug(f"Matched debug user for experiment: {experiment.id}")
|
|
|
+ self._match_experiment(experiment, experiment_result)
|
|
|
+ return
|
|
|
+
|
|
|
+ for domain in layer.domains:
|
|
|
+ if domain.match_debug_users(experiment_result.experiment_context):
|
|
|
+ logger.debug(f"Matched debug user for domain: {domain.id}")
|
|
|
+ self._match_domain(domain, experiment_result)
|
|
|
+
|
|
|
+ hash_key = f"{experiment_result.experiment_context.uid}_LAYER{layer.id}"
|
|
|
+ hash_value = self._hash_value(hash_key)
|
|
|
+
|
|
|
+ exp_context = ExperimentContext(uid=hash_value,
|
|
|
+ filter_params=experiment_result.experiment_context.filter_params)
|
|
|
+
|
|
|
+ matched_experiments = [exp for exp in layer.experiments if exp.match(exp_context)]
|
|
|
+
|
|
|
+ if len(matched_experiments) == 1:
|
|
|
+ self._match_experiment(matched_experiments[0], experiment_result)
|
|
|
+ elif len(matched_experiments) > 1:
|
|
|
+ for experiment in matched_experiments:
|
|
|
+ if experiment.bucket_type == "Condition":
|
|
|
+ self._match_experiment(experiment, experiment_result)
|
|
|
+ return
|
|
|
+ logger.warning(f"Warning: Multiple experiments matched under layer {layer.id}.")
|
|
|
+ self._match_experiment(matched_experiments[0], experiment_result)
|
|
|
+
|
|
|
+ matched_domains = []
|
|
|
+ for domain in layer.domains:
|
|
|
+ if domain.match(exp_context):
|
|
|
+ logger.debug(f"Matched domain {domain.id} for uid {experiment_result.experiment_context.uid}.")
|
|
|
+ matched_domains.append(domain)
|
|
|
+ if len(matched_domains) == 1:
|
|
|
+ self._match_domain(matched_domains[0], experiment_result)
|
|
|
+ return
|
|
|
+ elif len(matched_domains) > 1:
|
|
|
+ for domain in matched_domains:
|
|
|
+ if domain.bucket_type == "Condition":
|
|
|
+ self._match_domain(domain, experiment_result)
|
|
|
+ return
|
|
|
+ logger.warning(f"Warning: Multiple domains matched under layer {layer.id}, using the first one.")
|
|
|
+ self._match_domain(matched_domains[0], experiment_result)
|
|
|
+ return
|
|
|
+
|
|
|
+ def _match_experiment(self, experiment: Experiment, experiment_result: ExperimentResult):
|
|
|
+ if not experiment:
|
|
|
+ return
|
|
|
+
|
|
|
+ for version in experiment.experiment_versions:
|
|
|
+ if version.match_debug_users(experiment_result.experiment_context):
|
|
|
+ logger.debug(f"Matched debug user for experiment version: {version.id}")
|
|
|
+ experiment_result.add_params(version.params)
|
|
|
+ experiment_result.add_experiment_version(version)
|
|
|
+ return
|
|
|
+
|
|
|
+ hash_key = f"{experiment_result.experiment_context.uid}_EXPERIMENT{experiment.id}"
|
|
|
+ hash_value = self._hash_value(hash_key)
|
|
|
+
|
|
|
+ exp_context = ExperimentContext(uid=hash_value,
|
|
|
+ filter_params=experiment_result.experiment_context.filter_params)
|
|
|
+
|
|
|
+ for version in experiment.experiment_versions:
|
|
|
+ if version.match(exp_context):
|
|
|
+ experiment_result.add_params(version.params)
|
|
|
+ experiment_result.add_experiment_version(version)
|
|
|
+ return
|
|
|
+
|
|
|
+ def _hash_value(self, hash_key):
|
|
|
+ import hashlib
|
|
|
+ md5_hash = hashlib.md5(hash_key.encode()).hexdigest()
|
|
|
+ return md5_hash
|
|
|
+
|
|
|
+ def __del__(self):
|
|
|
+ if self.running and self.worker_thread:
|
|
|
+ self.shutdown()
|
|
|
+
|
|
|
+g_client = None
|
|
|
+
|
|
|
+def get_client():
|
|
|
+ global g_client
|
|
|
+ if not g_client:
|
|
|
+ ak_id = 'LTAI5tFGqgC8f3mh1fRCrAEy'
|
|
|
+ ak_secret = 'XhOjK9XmTYRhVAtf6yii4s4kZwWzvV'
|
|
|
+ region = 'cn-hangzhou'
|
|
|
+ from alibabacloud_tea_openapi.models import Config
|
|
|
+ endpoint = f"paiabtest.{region}.aliyuncs.com"
|
|
|
+ conf = Config(access_key_id=ak_id, access_key_secret=ak_secret, region_id=region,
|
|
|
+ endpoint=endpoint, type="access_key")
|
|
|
+ api_client = Client(conf)
|
|
|
+ g_client = ExperimentClient(api_client)
|
|
|
+ g_client.load_experiment_data()
|
|
|
+ g_client.start()
|
|
|
+ return g_client
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ from pqai_agent.logging_service import setup_root_logger
|
|
|
+ setup_root_logger(level='DEBUG')
|
|
|
+ experiment_client = get_client()
|
|
|
+
|
|
|
+ for project_name, project in experiment_client.project_map.items():
|
|
|
+ print(f"Project: {project_name}, ID: {project.id}")
|
|
|
+ for domain in project.domains:
|
|
|
+ print(f" Domain: {domain.id}, Default: {domain.is_default_domain}")
|
|
|
+ for layer in domain.layers:
|
|
|
+ print(f" Layer: {layer.id}")
|
|
|
+ for experiment in layer.experiments:
|
|
|
+ print(f" Experiment: {experiment.id}")
|
|
|
+ for version in experiment.experiment_versions:
|
|
|
+ print(f" Version: {version.id}, Config: {version.config}")
|
|
|
+
|
|
|
+ exp_context = ExperimentContext(uid='x')
|
|
|
+ result = experiment_client.match_experiment('PQAgent', exp_context)
|
|
|
+ print(result)
|