# models.py
# Data models for the experiment framework: diversion buckets, domains,
# layers, experiments, experiment versions, projects and experiment results.
import hashlib
import json
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Set

from pqai_agent.logging_service import logger
  6. class FNV:
  7. INIT64 = int("cbf29ce484222325", 16)
  8. PRIME64 = int("100000001b3", 16)
  9. MOD64 = 2**64
  10. @staticmethod
  11. def fnv1_64(data: bytes) -> int:
  12. hash_value = FNV.INIT64
  13. for byte in data:
  14. hash_value = (hash_value * FNV.PRIME64) % FNV.MOD64
  15. hash_value = hash_value ^ byte
  16. return hash_value
  17. class DiversionBucket:
  18. def match(self, experiment_context):
  19. raise NotImplementedError("Subclasses must implement this method")
  20. class UidDiversionBucket(DiversionBucket):
  21. def __init__(self, total_buckets: int, buckets: str):
  22. self.total_buckets = total_buckets
  23. if buckets:
  24. self.buckets = set(map(int, buckets.split(",")))
  25. else:
  26. self.buckets = set()
  27. def match(self, experiment_context):
  28. uid_hash = int(experiment_context.uid)
  29. bucket = uid_hash % self.total_buckets
  30. # print(f"Matching UID {experiment_context.uid} with hash {uid_hash} to bucket {bucket} in {self.buckets}")
  31. return bucket in self.buckets
  32. class FilterDiversionBucket(DiversionBucket):
  33. def __init__(self, filter_condition: str):
  34. self.filter_condition = filter_condition
  35. def match(self, experiment_context):
  36. raise NotImplementedError("not implemented")
  37. class Feature:
  38. def __init__(self, params=None):
  39. self.params = params
  40. def init(self):
  41. # Initialize feature-specific logic
  42. pass
  43. class ExperimentContext:
  44. def __init__(self, uid=None, filter_params=None):
  45. self.uid = uid
  46. self.filter_params = filter_params or {}
  47. def __str__(self):
  48. return f"ExperimentContext(uid={self.uid}, filter_params={self.filter_params})"
  49. class Domain:
  50. def __init__(self, domain_id, name, flow: int, buckets: str, bucket_type: str, debug_crowd_ids=None, is_default_domain=False, exp_layer_id=None,
  51. debug_users=""):
  52. self.id = int(domain_id)
  53. self.name = name
  54. self.debug_crowd_ids = debug_crowd_ids
  55. self.is_default_domain = is_default_domain
  56. self.exp_layer_id = int(exp_layer_id) if exp_layer_id is not None else None
  57. self.features = []
  58. self.layers = []
  59. self.debug_users = debug_users
  60. self.flow = flow
  61. self.buckets = buckets
  62. self.diversion_bucket = None
  63. self.bucket_type = bucket_type
  64. self.debug_user_set = set()
  65. def add_debug_users(self, users: List[str]):
  66. self.debug_user_set.update(users)
  67. def match_debug_users(self, experiment_context):
  68. return experiment_context.uid in self.debug_user_set
  69. def add_feature(self, feature: Feature):
  70. self.features.append(feature)
  71. def add_layer(self, layer):
  72. self.layers.append(layer)
  73. def init(self):
  74. self.debug_user_set.update(self.debug_users.split(","))
  75. self.diversion_bucket = UidDiversionBucket(100, self.buckets)
  76. def match(self, experiment_context):
  77. if self.flow == 0:
  78. return False
  79. elif self.flow == 100:
  80. return True
  81. if self.diversion_bucket:
  82. return self.diversion_bucket.match(experiment_context)
  83. return False
  84. @dataclass
  85. class Layer:
  86. id: int
  87. name: str
  88. experiments: List['Experiment'] = field(default_factory=list)
  89. domains: List[Domain] = field(default_factory=list)
  90. def add_experiment(self, experiment):
  91. self.experiments.append(experiment)
  92. def add_domain(self, domain):
  93. self.domains.append(domain)
  94. @dataclass
  95. class Experiment:
  96. id: int
  97. flow: int
  98. crowd_ids: List[str]
  99. debug_users: str
  100. buckets: str
  101. filter_condition: str
  102. bucket_type: str = "Random"
  103. debug_user_set: Set[str] = field(default_factory=set)
  104. diversion_bucket: Optional[DiversionBucket] = None
  105. experiment_versions: List['ExperimentVersion'] = field(default_factory=list)
  106. def add_debug_users(self, users: List[str]):
  107. self.debug_user_set.update(users)
  108. def match_debug_users(self, experiment_context):
  109. return experiment_context.uid in self.debug_user_set
  110. def add_experiment_version(self, version):
  111. self.experiment_versions.append(version)
  112. def match(self, experiment_context: ExperimentContext) -> bool:
  113. if self.bucket_type == "Random":
  114. if self.flow == 0:
  115. return False
  116. elif self.flow == 100:
  117. return True
  118. if self.diversion_bucket:
  119. return self.diversion_bucket.match(experiment_context)
  120. return False
  121. def init(self):
  122. # 初始化 debug_user_map
  123. if self.debug_users:
  124. self.debug_user_set.update(self.debug_users.split(","))
  125. # 初始化 diversion_bucket
  126. if self.bucket_type == "Random": # ExpBucketTypeRand
  127. self.diversion_bucket = UidDiversionBucket(100, self.buckets)
  128. elif self.bucket_type == "Condition" and self.filter_condition: # ExpBucketTypeCond
  129. self.diversion_bucket = FilterDiversionBucket(self.filter_condition)
  130. class ExperimentVersion:
  131. def __init__(self, exp_version_id, flow, buckets: str, exp_id: int, exp_version_name=None,
  132. debug_users: str = '', config=None, debug_crowd_ids=None):
  133. self.id = int(exp_version_id)
  134. self.exp_version_name = exp_version_name
  135. self.exp_id = int(exp_id)
  136. self.config = config
  137. self.debug_crowd_ids = debug_crowd_ids
  138. self.debug_users = debug_users
  139. self.params = {}
  140. self.flow = flow
  141. self.buckets = buckets
  142. self.debug_user_set = set()
  143. self.diversion_bucket = None
  144. def add_debug_users(self, users: List[str]):
  145. self.debug_user_set.update(users)
  146. def match_debug_users(self, experiment_context):
  147. return experiment_context.uid in self.debug_user_set
  148. def match(self, experiment_context: ExperimentContext):
  149. if self.flow == 0:
  150. return False
  151. elif self.flow == 100:
  152. return True
  153. if self.diversion_bucket:
  154. return self.diversion_bucket.match(experiment_context)
  155. return False
  156. def init(self):
  157. self.debug_user_set.update(self.debug_users.split(","))
  158. self.diversion_bucket = UidDiversionBucket(100, self.buckets)
  159. params = json.loads(self.config)
  160. for kv in params:
  161. self.params[kv['key']] = kv['value']
  162. class Project:
  163. def __init__(self, name=None, project_id=None):
  164. self.name = name
  165. self.id = int(project_id)
  166. self.domains = []
  167. self.layers = []
  168. self.default_domain : Optional[Domain] = None
  169. self.layer_map = {}
  170. self.domain_map = {}
  171. def add_domain(self, domain):
  172. self.domains.append(domain)
  173. self.domain_map[domain.id] = domain
  174. def add_layer(self, layer):
  175. self.layers.append(layer)
  176. self.layer_map[layer.id] = layer
  177. def set_default_domain(self, domain: Domain):
  178. self.default_domain = domain
  179. class ExperimentResult:
  180. def __init__(self, project=None, experiment_context=None):
  181. self.project = project
  182. if project:
  183. self.project_name = project.name
  184. else:
  185. self.project_name = None
  186. self.experiment_context = experiment_context
  187. self.params = {}
  188. self.experiment_versions: List[ExperimentVersion] = []
  189. self.exp_id = ""
  190. def add_params(self, params: Dict[str, str]):
  191. for key, value in params.items():
  192. if key in self.params:
  193. logger.warning(f"Duplicate key '{key}' in params, overwriting value: {self.params[key]} with {value}")
  194. self.params[key] = value
  195. def add_experiment_version(self, version):
  196. self.experiment_versions.append(version)
  197. def init(self):
  198. buf = []
  199. if self.project:
  200. buf.append(f"ER{self.project.id}")
  201. if self.experiment_versions:
  202. for experiment_version in self.experiment_versions:
  203. buf.append(f"_E{experiment_version.exp_id}")
  204. buf.append(f"#EV{experiment_version.id}")
  205. self.exp_id = "".join(buf)
  206. def __str__(self):
  207. return f"ExperimentResult(project={self.project_name}, params={self.params}, experiment_context={self.experiment_context}, experiment_versions={self.experiment_versions})"