ad_predict.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import json
  2. import random
  3. import time
  4. import math
  5. import gevent
  6. from db_helper import RedisHelper
  7. from config import set_config
  8. from log import Log
  9. from gevent import monkey, pool
  10. monkey.patch_all()
  11. config_ = set_config()
  12. log_ = Log()
  13. redis_helper = RedisHelper()
  14. def thompson_process(creative_id):
  15. # 获取creative_id对应的Thompson参数
  16. thompson_param_initial = redis_helper.get_data_from_redis(key_name=f"{config_.THOMPSON_PARAM_KEY_PREFIX}{creative_id}")
  17. if thompson_param_initial is None or thompson_param_initial == '':
  18. # 参数不存在,获取默认参数
  19. thompson_param = redis_helper.get_data_from_redis(key_name=f"{config_.THOMPSON_PARAM_KEY_PREFIX}-1")
  20. param_alpha, param_beta = json.loads(thompson_param.strip())
  21. param_alpha, param_beta = int(param_alpha), int(param_beta)
  22. random_flag = 'initial_random'
  23. else:
  24. # 参数存在
  25. param_alpha, param_beta = json.loads(thompson_param_initial.strip())
  26. param_alpha, param_beta = int(param_alpha), int(param_beta)
  27. if param_alpha + param_beta >= 100:
  28. # ad_idea_id 曝光数 >= 100,生成参数为(param_alpha+1, param_beta+1)的beta分布随机数
  29. thompson_param = thompson_param_initial
  30. random_flag = 'beta'
  31. else:
  32. # ad_idea_id 曝光数 < 100,获取默认参数
  33. thompson_param = redis_helper.get_data_from_redis(key_name=f"{config_.THOMPSON_PARAM_KEY_PREFIX}-1")
  34. param_alpha, param_beta = json.loads(thompson_param.strip())
  35. param_alpha, param_beta = int(param_alpha), int(param_beta)
  36. random_flag = 'under_view_initial_random'
  37. # 生成参数为(param_alpha+1, param_beta+1)的beta分布随机数
  38. alpha = math.log(param_alpha + 1) + 1
  39. beta = math.log(param_beta + 1) + 1
  40. score = random.betavariate(alpha=alpha, beta=beta)
  41. betavariate_param = [alpha, beta]
  42. thompson_res = [creative_id, score, thompson_param_initial, thompson_param, betavariate_param, random_flag]
  43. return thompson_res
  44. def get_creative_id_with_thompson(mid, creative_id_list, gevent_pool, sort_strategy):
  45. """利用Thompson采样获取此次要展示的广告创意ID"""
  46. # 限制协程最大并发数:20
  47. tasks = [gevent_pool.spawn(thompson_process, creative_id) for creative_id in creative_id_list]
  48. gevent.joinall(tasks)
  49. thompson_res_list = [t.get() for t in tasks]
  50. # 按照score排序
  51. thompson_res_rank = sorted(thompson_res_list, key=lambda x: x[1], reverse=True)
  52. rank_res = {
  53. 'mid': mid,
  54. 'creative_id': thompson_res_rank[0][0],
  55. 'score': thompson_res_rank[0][1],
  56. 'thompson_param_initial': thompson_res_rank[0][2],
  57. 'thompson_param': thompson_res_rank[0][3],
  58. 'betavariate_param': thompson_res_rank[0][4],
  59. 'random_flag': thompson_res_rank[0][5],
  60. 'sort_strategy': sort_strategy,
  61. 'thompson_res_rank': thompson_res_rank
  62. }
  63. return rank_res
  64. def get_creative_id_with_thompson_weight(mid, creative_id_list, gevent_pool, sort_strategy):
  65. """利用Thompson采样+cvr加权 获取此次要展示的广告创意ID"""
  66. tasks = [gevent_pool.spawn(thompson_process, creative_id) for creative_id in creative_id_list]
  67. gevent.joinall(tasks)
  68. thompson_res_list = [t.get() for t in tasks]
  69. # 获取creative_id对应cvr, 给定对应权重
  70. # st_1 = time.time()
  71. cvr_mapping = {}
  72. creative_weight = {}
  73. key_list = [f"{config_.CREATIVE_CVR_KEY_PREFIX}{creative_id}" for creative_id in creative_id_list]
  74. cvr_list = []
  75. name_list = []
  76. for i in range(len(key_list)):
  77. if i % 20 == 0 and i != 0:
  78. print(len(name_list))
  79. cvr_res = redis_helper.get_batch_key(name_list=name_list)
  80. cvr_list.extend(cvr_res)
  81. name_list = [key_list[i]]
  82. else:
  83. name_list.append(key_list[i])
  84. if len(name_list) > 0:
  85. cvr_res = redis_helper.get_batch_key(name_list=name_list)
  86. cvr_list.extend(cvr_res)
  87. for i, creative_id in enumerate(creative_id_list):
  88. creative_weight[creative_id] = config_.CREATIVE_WEIGHT_INITIAL
  89. cvr = cvr_list[i]
  90. if cvr is None:
  91. continue
  92. try:
  93. cvr_mapping[creative_id] = float(cvr)
  94. except:
  95. continue
  96. # log_.info(f"st1: {(time.time() - st_1) * 1000}ms")
  97. # st_2 = time.time()
  98. cvr_sorted = sorted(cvr_mapping.items(), key=lambda x: x[1], reverse=False)
  99. for i, item in enumerate(cvr_sorted):
  100. creative_id = item[0]
  101. creative_weight[creative_id] += (i * config_.WEIGHT_GRADIENT)
  102. # log_.info(f"st2: {(time.time() - st_2) * 1000}ms")
  103. # 对有cvr的creative进行加权
  104. # st_3 = time.time()
  105. thompson_weight_res_list = []
  106. weight_sum = sum([weight for _, weight in creative_weight.items()])
  107. for thompson_res in thompson_res_list:
  108. creative_id, score = thompson_res[0], thompson_res[1]
  109. weight = creative_weight[creative_id]
  110. if weight > config_.CREATIVE_WEIGHT_INITIAL:
  111. weight_score = score * (1 + weight / weight_sum)
  112. else:
  113. weight_score = score
  114. thompson_weight_res = thompson_res + [creative_weight[creative_id], weight_score]
  115. thompson_weight_res_list.append(thompson_weight_res)
  116. # log_.info(f"st3: {(time.time() - st_3) * 1000}ms")
  117. # 重新排序
  118. thompson_res_rank = sorted(thompson_weight_res_list, key=lambda x: x[7], reverse=True)
  119. rank_res = {
  120. 'mid': mid,
  121. 'creative_id': thompson_res_rank[0][0],
  122. 'score': thompson_res_rank[0][1],
  123. 'thompson_param_initial': thompson_res_rank[0][2],
  124. 'thompson_param': thompson_res_rank[0][3],
  125. 'betavariate_param': thompson_res_rank[0][4],
  126. 'random_flag': thompson_res_rank[0][5],
  127. 'creative_weight': thompson_res_rank[0][6],
  128. 'weight_score': thompson_res_rank[0][7],
  129. 'sort_strategy': sort_strategy,
  130. 'thompson_res_rank': thompson_res_rank
  131. }
  132. return rank_res