浏览代码

Update abtest: fix dataclass and data types

StrayWarrior 4 天之前
父节点
当前提交
1dcddddaad
共有 2 个文件被更改,包括 41 次插入27 次删除
  1. 12 10
      pqai_agent/abtest/client.py
  2. 29 17
      pqai_agent/abtest/models.py

+ 12 - 10
pqai_agent/abtest/client.py

@@ -74,23 +74,23 @@ class ExperimentClient:
 
                 # 获取域的特性(暂无实际用处)
                 list_feature_req = ListFeaturesRequest()
-                list_feature_req.domain_id = domain.id
+                list_feature_req.domain_id = str(domain.id)
                 features_response = self.client.list_features(list_feature_req)
                 for feature_data in features_response.body.features:
                     domain.add_feature(feature_data)
 
                 # 获取域的层
                 list_layer_req = ListLayersRequest()
-                list_layer_req.domain_id = domain.id
+                list_layer_req.domain_id = str(domain.id)
                 layers_response = self.client.list_layers(list_layer_req)
                 for layer_data in layers_response.body.layers:
                     logger.debug(f'[Layer] {layer_data}')
-                    layer = Layer(layer_id=layer_data.layer_id, name=layer_data.name)
+                    layer = Layer(id=int(layer_data.layer_id), name=layer_data.name)
                     project.add_layer(layer)
 
                     # 获取层的实验
                     list_experiment_req = ListExperimentsRequest()
-                    list_experiment_req.layer_id = layer.id
+                    list_experiment_req.layer_id = str(layer.id)
                     # FIXME: magic code
                     list_experiment_req.status = 'Running'
                     experiments_response = self.client.list_experiments(list_experiment_req)
@@ -109,9 +109,9 @@ class ExperimentClient:
 
                         # 获取实验的版本
                         list_exp_ver_req = ListExperimentVersionsRequest()
-                        list_exp_ver_req.experiment_id = experiment.id
+                        list_exp_ver_req.experiment_id = int(experiment.id)
                         versions_response = self.client.list_experiment_versions(list_exp_ver_req)
-                        print(versions_response)
+                        logger.debug(versions_response)
                         for version_data in versions_response.body.experiment_versions:
                             logger.debug(f'[ExperimentVersion] {version_data}')
                             version = ExperimentVersion(exp_version_id=version_data.experiment_version_id,
@@ -129,6 +129,7 @@ class ExperimentClient:
             for domain in project.domains:
                 if domain.is_default_domain:
                     continue
+                # domain.exp_layer_id是domain所属的layer id
                 layer: Layer = project.layer_map.get(domain.exp_layer_id, None)
                 if not layer:
                     continue
@@ -233,10 +234,11 @@ class ExperimentClient:
                 experiment_result.add_experiment_version(version)
                 return
 
-    def _hash_value(self, hash_key):
+    def _hash_value(self, hash_key) -> int:
         import hashlib
-        md5_hash = hashlib.md5(hash_key.encode()).hexdigest()
-        return md5_hash
+        from pqai_agent.abtest.models import FNV
+        md5_hash = hashlib.md5(hash_key.encode()).hexdigest().encode()
+        return FNV.fnv1_64(md5_hash)
 
     def __del__(self):
         if self.running and self.worker_thread:
@@ -276,6 +278,6 @@ if __name__ == '__main__':
                     for version in experiment.experiment_versions:
                         print(f"        Version: {version.id}, Config: {version.config}")
 
-    exp_context = ExperimentContext(uid='x')
+    exp_context = ExperimentContext(uid='123')
     result = experiment_client.match_experiment('PQAgent', exp_context)
     print(result)

+ 29 - 17
pqai_agent/abtest/models.py

@@ -1,8 +1,21 @@
 from typing import List, Dict, Optional, Set
 import json
-from attr import dataclass
+from dataclasses import dataclass, field
 import hashlib
 
+class FNV:
+    INIT64 = int("cbf29ce484222325", 16)
+    PRIME64 = int("100000001b3", 16)
+    MOD64 = 2**64
+
+    @staticmethod
+    def fnv1_64(data: bytes) -> int:
+        hash_value = FNV.INIT64
+        for byte in data:
+            hash_value = (hash_value * FNV.PRIME64) % FNV.MOD64
+            hash_value = hash_value ^ byte
+        return hash_value
+
 class DiversionBucket:
     def match(self, experiment_context):
         raise NotImplementedError("Subclasses must implement this method")
@@ -10,11 +23,15 @@ class DiversionBucket:
 class UidDiversionBucket(DiversionBucket):
     def __init__(self, total_buckets: int, buckets: str):
         self.total_buckets = total_buckets
-        self.buckets = set(map(int, buckets.split(",")))
+        if buckets:
+            self.buckets = set(map(int, buckets.split(",")))
+        else:
+            self.buckets = set()
 
     def match(self, experiment_context):
-        uid_hash = int(hashlib.md5(experiment_context.uid.encode()).hexdigest(), 16)
+        uid_hash = int(experiment_context.uid)
         bucket = uid_hash % self.total_buckets
+        print(f"Matching UID {experiment_context.uid} with hash {uid_hash} to bucket {bucket} in {self.buckets}")
         return bucket in self.buckets
 
 
@@ -45,11 +62,11 @@ class ExperimentContext:
 class Domain:
     def __init__(self, domain_id, name, flow: int, buckets: str, bucket_type: str, debug_crowd_ids=None, is_default_domain=False, exp_layer_id=None,
                  debug_users=""):
-        self.id = domain_id
+        self.id = int(domain_id)
         self.name = name
         self.debug_crowd_ids = debug_crowd_ids
         self.is_default_domain = is_default_domain
-        self.exp_layer_id = exp_layer_id
+        self.exp_layer_id = int(exp_layer_id) if exp_layer_id is not None else None
         self.features = []
         self.layers = []
         self.debug_users = debug_users
@@ -85,17 +102,12 @@ class Domain:
         return False
 
 
+@dataclass
 class Layer:
     id: int
     name: str
-    experiments: List['Experiment']
-    domains: List[Domain]
-
-    def __init__(self, layer_id, name):
-        self.id = layer_id
-        self.name = name
-        self.experiments = []
-        self.domains = []
+    experiments: List['Experiment'] = field(default_factory=list)
+    domains: List[Domain] = field(default_factory=list)
 
     def add_experiment(self, experiment):
         self.experiments.append(experiment)
@@ -113,9 +125,9 @@ class Experiment:
     buckets: str
     filter_condition: str
     bucket_type: str = "Random"
-    debug_user_set: Set[str] = set()
+    debug_user_set: Set[str] = field(default_factory=set)
     diversion_bucket: Optional[DiversionBucket] = None
-    experiment_versions: List['ExperimentVersion'] = []
+    experiment_versions: List['ExperimentVersion'] = field(default_factory=list)
 
     def add_debug_users(self, users: List[str]):
         self.debug_user_set.update(users)
@@ -150,7 +162,7 @@ class Experiment:
 class ExperimentVersion:
     def __init__(self, exp_version_id, flow, buckets: str, exp_version_name=None, debug_users: str = '',
                  config=None, debug_crowd_ids=None):
-        self.id = exp_version_id
+        self.id = int(exp_version_id)
         self.exp_version_name = exp_version_name
         self.config = config
         self.debug_crowd_ids = debug_crowd_ids
@@ -187,7 +199,7 @@ class ExperimentVersion:
 class Project:
     def __init__(self, project_name=None, project_id=None):
         self.project_name = project_name
-        self.id = project_id
+        self.id = int(project_id)
         self.domains = []
         self.layers = []
         self.default_domain : Optional[Domain] = None