ad_predict.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import json
  2. import random
  3. import time
  4. import gevent
  5. from db_helper import RedisHelper
  6. from config import set_config
  7. from log import Log
  8. from gevent import monkey, pool
  9. monkey.patch_all()
  10. config_ = set_config()
  11. log_ = Log()
  12. redis_helper = RedisHelper()
  13. def thompson_process(creative_id):
  14. # 获取creative_id对应的Thompson参数
  15. thompson_param_initial = redis_helper.get_data_from_redis(key_name=f"{config_.THOMPSON_PARAM_KEY_PREFIX}{creative_id}")
  16. if thompson_param_initial is None or thompson_param_initial == '':
  17. # 参数不存在,获取默认参数
  18. thompson_param = redis_helper.get_data_from_redis(key_name=f"{config_.THOMPSON_PARAM_KEY_PREFIX}-1")
  19. param_alpha, param_beta = json.loads(thompson_param.strip())
  20. param_alpha, param_beta = int(param_alpha), int(param_beta)
  21. random_flag = 'initial_random'
  22. else:
  23. # 参数存在
  24. param_alpha, param_beta = json.loads(thompson_param_initial.strip())
  25. param_alpha, param_beta = int(param_alpha), int(param_beta)
  26. if param_alpha + param_beta >= 100:
  27. # ad_idea_id 曝光数 >= 100,生成参数为(param_alpha+1, param_beta+1)的beta分布随机数
  28. thompson_param = thompson_param_initial
  29. random_flag = 'beta'
  30. else:
  31. # ad_idea_id 曝光数 < 100,获取默认参数
  32. thompson_param = redis_helper.get_data_from_redis(key_name=f"{config_.THOMPSON_PARAM_KEY_PREFIX}-1")
  33. param_alpha, param_beta = json.loads(thompson_param.strip())
  34. param_alpha, param_beta = int(param_alpha), int(param_beta)
  35. random_flag = 'initial_random'
  36. # 生成参数为(param_alpha+1, param_beta+1)的beta分布随机数
  37. score = random.betavariate(alpha=param_alpha + 1, beta=param_beta + 1)
  38. thompson_res = [creative_id, score, thompson_param_initial, thompson_param, random_flag]
  39. return thompson_res
  40. def get_creative_id_with_thompson(mid, creative_id_list, gevent_pool):
  41. """利用Thompson采样获取此次要展示的广告创意ID"""
  42. # 限制协程最大并发数:20
  43. tasks = [gevent_pool.spawn(thompson_process, creative_id) for creative_id in creative_id_list]
  44. gevent.joinall(tasks)
  45. thompson_res_list = [t.get() for t in tasks]
  46. # 按照score排序
  47. thompson_res_rank = sorted(thompson_res_list, key=lambda x: x[1], reverse=True)
  48. rank_res = {
  49. 'mid': mid,
  50. 'creative_id': thompson_res_rank[0][0],
  51. 'score': thompson_res_rank[0][1],
  52. 'thompson_param_initial': thompson_res_rank[0][2],
  53. 'thompson_param': thompson_res_rank[0][3],
  54. 'random_flag': thompson_res_rank[0][4],
  55. 'thompson_res_rank': thompson_res_rank
  56. }
  57. return rank_res