瀏覽代碼

update rovScore to equal difference

liqian 3 年之前
父節點
當前提交
b1f71aeda6
共有 2 個文件被更改,包括 49 次插入8 次删除
  1. 3 0
      config.py
  2. 46 8
      rov_train.py

+ 3 - 0
config.py

@@ -57,6 +57,9 @@ class BaseConfig(object):
     # 生效中的置顶视频列表 redis key
     TOP_VIDEO_LIST_KEY_NAME = 'com.weiqu.video.top.item.score.area'
 
+    # rovScore公差
+    ROV_SCORE_D = 0.001
+
 
 class DevelopmentConfig(BaseConfig):
     """开发环境配置"""

+ 46 - 8
rov_train.py

@@ -125,7 +125,26 @@ def pack_result_to_csv(filename, sort_columns=None, filepath=config_.DATA_DIR_PA
     :param sort_columns: 指定排序列名列名,type-list, 默认为None
     :param filepath: csv文件存放路径,默认为config_.DATA_DIR_PATH
     :param ascending: 是否按指定列的数组升序排列,默认为True,即升序排列
-    :param data: 数据
+    :param data: 数据, type-dict
+    :return: None
+    """
+    if not os.path.exists(filepath):
+        os.makedirs(filepath)
+    file = os.path.join(filepath, filename)
+    df = pd.DataFrame(data=data)
+    if sort_columns:
+        df = df.sort_values(by=sort_columns, ascending=ascending)
+    df.to_csv(file, index=False)
+
+
+def pack_list_result_to_csv(filename, data, sort_columns=None, filepath=config_.DATA_DIR_PATH, ascending=True):
+    """
+    打包数据并存入csv, 数据为字典列表
+    :param filename: csv文件名
+    :param data: 数据,type-list [{}, {},...]
+    :param sort_columns: 指定排序列名列名,type-list, 默认为None
+    :param filepath: csv文件存放路径,默认为config_.DATA_DIR_PATH
+    :param ascending: 是否按指定列的数组升序排列,默认为True,即升序排列
     :return: None
     """
     if not os.path.exists(filepath):
@@ -147,25 +166,44 @@ def predict():
     # 预测
     y_ = model.predict(x)
     log_.info('predict finished!')
+
     # 将结果进行归一化到[0, 100]
     normal_y_ = data_normalization(list(y_))
     log_.info('normalization finished!')
+
+    # 按照normal_y_降序排序
+    predict_data = []
+    for i, video_id in enumerate(video_ids):
+        data = {'video_id': video_id, 'normal_y_': normal_y_[i], 'y_': y_[i], 'y': y[i]}
+        predict_data.append(data)
+    predict_data_sorted = sorted(predict_data, key=lambda temp: temp['normal_y_'], reverse=True)
+
+    # 按照排序,从100以固定差值做等差递减,以该值作为rovScore
+    predict_result = []
+    redis_data = {}
+    json_data = []
+    for j, item in enumerate(predict_data_sorted):
+        video_id = int(item['video_id'])
+        rov_score = 100 - j * config_.ROV_SCORE_D
+        item['rov_score'] = rov_score
+        predict_result.append(item)
+        redis_data[video_id] = rov_score
+        json_data.append({'videoId': video_id, 'rovScore': rov_score})
+
     # 打包预测结果存入csv
-    predict_data = {'normal_y_': normal_y_, 'y_': y_, 'y': y, 'video_ids': video_ids}
     predict_result_filename = 'predict.csv'
-    pack_result_to_csv(filename=predict_result_filename, sort_columns=['normal_y_'], ascending=False, **predict_data)
+    pack_list_result_to_csv(filename=predict_result_filename, sort_columns=['rov_score'],
+                            ascending=False, data=predict_result)
+
     # 上传redis
-    redis_data = {}
-    json_data = []
-    for i in range(len(video_ids)):
-        redis_data[video_ids[i]] = normal_y_[i]
-        json_data.append({'videoId': video_ids[i], 'rovScore': normal_y_[i]})
     key_name = config_.RECALL_KEY_NAME_PREFIX + time.strftime('%Y%m%d')
     redis_helper = RedisHelper()
     redis_helper.add_data_with_zset(key_name=key_name, data=redis_data)
     log_.info('data to redis finished!')
+
     # 清空修改ROV的视频数据
     redis_helper.del_keys(key_name=config_.UPDATE_ROV_KEY_NAME)
+
     # 通知后端更新数据
     # result = request_post(request_url=config_.NOTIFY_BACKEND_UPDATE_ROV_SCORE_URL, request_data={'videos': json_data})
     # if result['code'] == 0: