浏览代码

add thompson_weight

liqian 1 年之前
父节点
当前提交
5170b2b7b9
共有 3 个文件被更改,包括 62 次插入和 3 次删除
  1. 48 1
      ad_predict.py
  2. 9 2
      app.py
  3. 5 0
      config.py

+ 48 - 1
ad_predict.py

@@ -48,7 +48,7 @@ def thompson_process(creative_id):
     return thompson_res
 
 
-def get_creative_id_with_thompson(mid, creative_id_list, gevent_pool):
+def get_creative_id_with_thompson(mid, creative_id_list, gevent_pool, sort_strategy):
     """利用Thompson采样获取此次要展示的广告创意ID"""
     # 限制协程最大并发数:20
     tasks = [gevent_pool.spawn(thompson_process, creative_id) for creative_id in creative_id_list]
@@ -64,6 +64,53 @@ def get_creative_id_with_thompson(mid, creative_id_list, gevent_pool):
         'thompson_param': thompson_res_rank[0][3],
         'betavariate_param': thompson_res_rank[0][4],
         'random_flag': thompson_res_rank[0][5],
+        'sort_strategy': sort_strategy,
+        'thompson_res_rank': thompson_res_rank
+    }
+    return rank_res
+
+
+def get_creative_id_with_thompson_weight(mid, creative_id_list, gevent_pool, sort_strategy):
+    """Pick the creative to show using Thompson sampling re-weighted by CVR rank.
+
+    Each creative is sampled via ``thompson_process`` on the gevent pool. Every
+    creative starts from ``config_.CREATIVE_WEIGHT_INITIAL``; creatives that
+    have a CVR stored in redis are ranked by ascending CVR and receive an
+    extra ``rank_index * config_.WEIGHT_GRADIENT``. The final score is
+    ``thompson_score * (1 + weight / sum_of_all_weights)`` and the creative
+    with the highest final score wins.
+
+    :param mid: identifier copied unchanged into the result.
+    :param creative_id_list: candidate creative ids.
+    :param gevent_pool: pool used to spawn one ``thompson_process`` task per creative.
+    :param sort_strategy: strategy name copied unchanged into the result.
+    :return: dict describing the winning creative (its Thompson fields,
+        ``creative_weight`` and ``weight_score``) plus the full ranking under
+        ``thompson_res_rank``.
+    :raises IndexError: if ``creative_id_list`` is empty (there is no top-ranked entry).
+    """
+    tasks = [gevent_pool.spawn(thompson_process, creative_id) for creative_id in creative_id_list]
+    gevent.joinall(tasks)
+    thompson_res_list = [t.get() for t in tasks]
+
+    # Every creative gets the base weight, whether or not a CVR is available.
+    creative_weight = {creative_id: config_.CREATIVE_WEIGHT_INITIAL for creative_id in creative_id_list}
+    cvr_mapping = {}
+    for creative_id in creative_id_list:
+        cvr = redis_helper.get_data_from_redis(
+            key_name=f"{config_.CREATIVE_CVR_KEY_PREFIX}{creative_id}")
+        if cvr is None:
+            continue
+        try:
+            # Rank CVRs numerically. NOTE(review): redis values presumably come
+            # back as strings, which would otherwise sort lexicographically.
+            cvr_mapping[creative_id] = float(cvr)
+        except (TypeError, ValueError):
+            # Unparseable CVR: treat it like a missing one and keep the base weight.
+            continue
+
+    # Ascending order: a higher CVR gets a larger index and so a larger bonus.
+    cvr_sorted = sorted(cvr_mapping.items(), key=lambda x: x[1])
+    for i, (creative_id, _) in enumerate(cvr_sorted):
+        creative_weight[creative_id] += i * config_.WEIGHT_GRADIENT
+
+    # weight_score = thompson_score * (1 + this creative's share of the total weight)
+    weight_sum = sum(creative_weight.values())
+    thompson_weight_res_list = []
+    for thompson_res in thompson_res_list:
+        creative_id, score = thompson_res[0], thompson_res[1]
+        weight = creative_weight[creative_id]
+        weight_score = score * (1 + weight / weight_sum)
+        # Build a new row instead of calling the non-existent ``.add``.
+        # Assumes thompson_process returns exactly six fields, so the two new
+        # values land at indexes 6 and 7 - TODO confirm.
+        thompson_weight_res_list.append(list(thompson_res) + [weight, weight_score])
+    # Re-rank by the weighted score (index 7), not by the raw weight (index 6).
+    thompson_res_rank = sorted(thompson_weight_res_list, key=lambda x: x[7], reverse=True)
+    rank_res = {
+        'mid': mid,
+        'creative_id': thompson_res_rank[0][0],
+        'score': thompson_res_rank[0][1],
+        'thompson_param_initial': thompson_res_rank[0][2],
+        'thompson_param': thompson_res_rank[0][3],
+        'betavariate_param': thompson_res_rank[0][4],
+        'random_flag': thompson_res_rank[0][5],
+        'creative_weight': thompson_res_rank[0][6],
+        'weight_score': thompson_res_rank[0][7],
+        'sort_strategy': sort_strategy,
         'thompson_res_rank': thompson_res_rank
     }
     return rank_res

+ 9 - 2
app.py

@@ -11,11 +11,12 @@ monkey.patch_all()
 from flask import Flask, request
 from log import Log
 from config import set_config
-from ad_predict import get_creative_id_with_thompson
+from ad_predict import get_creative_id_with_thompson, get_creative_id_with_thompson_weight
 
 app = Flask(__name__)
 log_ = Log()
 config_ = set_config()
+# 限制协程最大并发数:100
 gevent_pool = pool.Pool(100)
 
 # log_.info(f"server start...")
@@ -33,7 +34,13 @@ def get_creative_id():
         request_data = json.loads(request.get_data())
         mid = request_data.get('mid')
         creative_id_list = request_data.get('creativeIdList')
-        thompson_result = get_creative_id_with_thompson(mid=mid, creative_id_list=creative_id_list, gevent_pool=gevent_pool)
+        sort_strategy = request_data.get('sortStrategy', 'thompson')
+        if sort_strategy == 'thompson_weight':
+            thompson_result = get_creative_id_with_thompson_weight(mid=mid, creative_id_list=creative_id_list,
+                                                                   gevent_pool=gevent_pool, sort_strategy=sort_strategy)
+        else:
+            thompson_result = get_creative_id_with_thompson(mid=mid, creative_id_list=creative_id_list,
+                                                            gevent_pool=gevent_pool, sort_strategy=sort_strategy)
         result = {'code': 200, 'message': 'success', 'data': {'mid': mid, 'creativeId': thompson_result['creative_id']}}
         log_message = {
             'requestUri': '/ad/predict/getCreativeId',

+ 5 - 0
config.py

@@ -4,6 +4,11 @@ import os
 class BaseConfig(object):
     # Redis key prefix for the Thompson parameters stored per creativeId; full key format: thompson:param:{creative_id}
     THOMPSON_PARAM_KEY_PREFIX = 'thompson:param:'
+    # Redis key prefix for the CVR stored per creativeId; full key format: creative:cvr:{creativeId}
+    CREATIVE_CVR_KEY_PREFIX = 'creative:cvr:'
+    # Default (initial) weight given to every creative
+    CREATIVE_WEIGHT_INITIAL = 100
+    # Weight added per position in the ascending CVR ranking (applied as rank_index * WEIGHT_GRADIENT)
+    WEIGHT_GRADIENT = 10
 
 
 class TestConfig(BaseConfig):