ad_predict.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import json
  2. import random
  3. import gevent
  4. from db_helper import RedisHelper
  5. from config import set_config
  6. from gevent import monkey, pool
  7. monkey.patch_all()
  8. config_ = set_config()
  9. redis_helper = RedisHelper()
  10. def thompson_process(ad_idea_id):
  11. # 获取ad_idea_id对应的Thompson参数
  12. thompson_param = redis_helper.get_data_from_redis(key_name=f"{config_.THOMPSON_PARAM_KEY_PREFIX}{ad_idea_id}")
  13. if thompson_param is None or thompson_param == '':
  14. # 参数不存在,随机生成[0, 1)之间的浮点数
  15. score = random.random()
  16. random_flag = 'random'
  17. else:
  18. # 参数存在
  19. param_alpha, param_beta = json.loads(thompson_param.strip())
  20. param_alpha, param_beta = int(param_alpha), int(param_beta)
  21. if param_alpha + param_beta >= 100:
  22. # ad_idea_id 曝光数 >= 100,生成参数为(param_alpha+1, param_beta+1)的beta分布随机数
  23. score = random.betavariate(alpha=param_alpha+1, beta=param_beta+1)
  24. random_flag = 'beta'
  25. else:
  26. # ad_idea_id 曝光数 < 100,随机生成[0, 1)之间的浮点数
  27. score = random.random()
  28. random_flag = 'random'
  29. thompson_res = [ad_idea_id, score, thompson_param, random_flag]
  30. return thompson_res
  31. def get_ad_idea_id_with_thompson(mid, ad_idea_id_list):
  32. """利用Thompson采样获取此次要展示的广告创意ID"""
  33. # 限制协程最大并发数:20
  34. gevent_pool = pool.Pool(20)
  35. tasks = [gevent_pool.spawn(thompson_process, ad_idea_id) for ad_idea_id in ad_idea_id_list]
  36. gevent.joinall(tasks)
  37. thompson_res_list = [t.get() for t in tasks]
  38. # 按照score排序
  39. thompson_res_rank = sorted(thompson_res_list, key=lambda x: x[1], reverse=True)
  40. rank_res = {
  41. 'mid': mid,
  42. 'ad_idea_id': thompson_res_rank[0][0],
  43. 'score': thompson_res_rank[0][1],
  44. 'thompson_param': thompson_res_rank[0][2],
  45. 'random_flag': thompson_res_rank[0][3],
  46. 'thompson_res_rank': thompson_res_rank
  47. }
  48. return rank_res