@@ -9,11 +9,12 @@ from sklearn.model_selection import train_test_split
 from sklearn.metrics import mean_absolute_error, r2_score, mean_absolute_percentage_error
 
 from config import set_config
-from utils import read_from_pickle, write_to_pickle, data_normalization, request_post, filter_video_status
+from utils import read_from_pickle, write_to_pickle, data_normalization, \
+    request_post, filter_video_status, update_video_w_h_rate
 from log import Log
 from db_helper import RedisHelper, MysqlHelper
 
-config_ = set_config()
+config_, env = set_config()
 log_ = Log()
 
 
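Note: set_config() is now unpacked into a config object plus an environment name, and that env value drives the dev/test vs. pre/pro dispatch added in __main__ at the end of this diff. A minimal sketch of the assumed return shape (the attribute values and the ENV variable below are placeholders, not taken from this repo's config.py):

    import os

    class _SketchConfig:
        # a few of the attributes this script reads; values are placeholders
        DATA_DIR_PATH = './data'
        ROV_SCORE_D = 0.001
        RECALL_KEY_NAME_PREFIX = 'recall:rov:'

    def set_config():
        """Assumed shape: returns (config object, env name in {'dev', 'test', 'pre', 'pro'})."""
        env = os.environ.get('ENV', 'dev')  # assumed source of the env flag
        return _SketchConfig(), env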
@@ -124,7 +125,7 @@ def pack_result_to_csv(filename, sort_columns=None, filepath=config_.DATA_DIR_PA
     :param sort_columns: column names to sort by, type-list, defaults to None
     :param filepath: directory where the csv file is stored, defaults to config_.DATA_DIR_PATH
     :param ascending: whether to sort by the specified columns in ascending order, defaults to True
-    :param data: data
+    :param data: data, type-dict
     :return: None
     """
     if not os.path.exists(filepath):
@@ -136,6 +137,26 @@ def pack_result_to_csv(filename, sort_columns=None, filepath=config_.DATA_DIR_PA
     df.to_csv(file, index=False)
 
 
+def pack_list_result_to_csv(filename, data, columns=None, sort_columns=None, filepath=config_.DATA_DIR_PATH, ascending=True):
+    """
+    Pack the data and write it to a csv file; the data is a list of dicts
+    :param filename: csv file name
+    :param data: data, type-list [{}, {}, ...]
+    :param columns: column order
+    :param sort_columns: column names to sort by, type-list, defaults to None
+    :param filepath: directory where the csv file is stored, defaults to config_.DATA_DIR_PATH
+    :param ascending: whether to sort by the specified columns in ascending order, defaults to True
+    :return: None
+    """
+    if not os.path.exists(filepath):
+        os.makedirs(filepath)
+    file = os.path.join(filepath, filename)
+    df = pd.DataFrame(data=data)
+    if sort_columns:
+        df = df.sort_values(by=sort_columns, ascending=ascending)
+    df.to_csv(file, index=False, columns=columns)
+
+
 def predict():
     """Predict"""
     # Read the prediction data and clean it
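Note: the new pack_list_result_to_csv helper above takes a list of row dicts rather than the dict of column arrays that pack_result_to_csv expects; the columns argument is passed straight through to DataFrame.to_csv to select and order the output columns. A usage sketch with made-up rows and file name:

    rows = [
        {'video_id': 101, 'rov_score': 100.0, 'normal_y_': 97.3, 'y_': 1.82, 'y': 1.75},
        {'video_id': 102, 'rov_score': 99.999, 'normal_y_': 95.1, 'y_': 1.64, 'y': 1.70},
    ]
    pack_list_result_to_csv(filename='predict_demo.csv',
                            data=rows,
                            columns=['video_id', 'rov_score', 'normal_y_', 'y_', 'y'],
                            sort_columns=['rov_score'],
                            ascending=False)
    # rows are sorted by rov_score descending and written with the columns in the listed order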
@@ -146,48 +167,74 @@ def predict():
     # Predict
     y_ = model.predict(x)
     log_.info('predict finished!')
+
     # Normalize the results to [0, 100]
     normal_y_ = data_normalization(list(y_))
     log_.info('normalization finished!')
+
+    # Sort by normal_y_ in descending order
+    predict_data = []
+    for i, video_id in enumerate(video_ids):
+        data = {'video_id': video_id, 'normal_y_': normal_y_[i], 'y_': y_[i], 'y': y[i]}
+        predict_data.append(data)
+    predict_data_sorted = sorted(predict_data, key=lambda temp: temp['normal_y_'], reverse=True)
+
+    # Following that order, step down arithmetically from 100 by a fixed difference and use the value as rovScore
+    predict_result = []
+    redis_data = {}
+    json_data = []
+    video_id_list = []
+    for j, item in enumerate(predict_data_sorted):
+        video_id = int(item['video_id'])
+        rov_score = 100 - j * config_.ROV_SCORE_D
+        item['rov_score'] = rov_score
+        predict_result.append(item)
+        redis_data[video_id] = rov_score
+        json_data.append({'videoId': video_id, 'rovScore': rov_score})
+        video_id_list.append(video_id)
+
     # Pack the prediction results and write them to csv
-    predict_data = {'normal_y_': normal_y_, 'y_': y_, 'y': y, 'video_ids': video_ids}
     predict_result_filename = 'predict.csv'
-    pack_result_to_csv(filename=predict_result_filename, sort_columns=['normal_y_'], ascending=False, **predict_data)
+    pack_list_result_to_csv(filename=predict_result_filename,
+                            data=predict_result,
+                            columns=['video_id', 'rov_score', 'normal_y_', 'y_', 'y'],
+                            sort_columns=['rov_score'],
+                            ascending=False)
+
     # Upload to redis
-    redis_data = {}
-    json_data = []
-    for i in range(len(video_ids)):
-        redis_data[video_ids[i]] = normal_y_[i]
-        json_data.append({'videoId': video_ids[i], 'rovScore': normal_y_[i]})
     key_name = config_.RECALL_KEY_NAME_PREFIX + time.strftime('%Y%m%d')
     redis_helper = RedisHelper()
     redis_helper.add_data_with_zset(key_name=key_name, data=redis_data)
     log_.info('data to redis finished!')
+
+    # Clear the data of videos whose ROV was modified
+    redis_helper.del_keys(key_name=config_.UPDATE_ROV_KEY_NAME)
+
     # Notify the backend to update its data
+    log_.info('json_data count = {}'.format(len(json_data)))
     result = request_post(request_url=config_.NOTIFY_BACKEND_UPDATE_ROV_SCORE_URL, request_data={'videos': json_data})
     if result['code'] == 0:
         log_.info('notify backend success!')
     else:
         log_.error('notify backend fail!')
 
+    # Update the videos' width/height ratio data
+    if video_id_list:
+        update_video_w_h_rate(video_ids=video_id_list,
+                              key_name=config_.W_H_RATE_UP_1_VIDEO_LIST_KEY_NAME['rov_recall'])
+        log_.info('update video w_h_rate to redis finished!')
+
 
 def predict_test():
     """Generate data for the test environment"""
     # Get the 40000 most recently published videos in the test environment
-    mysql_info = {
-        'host': 'rm-bp1k5853td1r25g3n690.mysql.rds.aliyuncs.com',
-        'port': 3306,
-        'user': 'wx2016_longvideo',
-        'password': 'wx2016_longvideoP@assword1234',
-        'db': 'longvideo'
-    }
     sql = "SELECT id FROM wx_video ORDER BY id DESC LIMIT 40000;"
-    mysql_helper = MysqlHelper(mysql_info=mysql_info)
+    mysql_helper = MysqlHelper()
     data = mysql_helper.get_data(sql=sql)
     video_ids = [video[0] for video in data]
     # Filter by video status
     filtered_videos = filter_video_status(video_ids)
-    log_.info('filtered_videos nums={}'.format(len(filtered_videos)))
+    log_.info('filtered_videos count = {}'.format(len(filtered_videos)))
     # Randomly generate a number in 0-100 as the score
     redis_data = {}
     json_data = []
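Note: the key change in predict() above is that rovScore is no longer the normalized prediction itself; videos are ranked by normal_y_ in descending order and the j-th video gets 100 - j * config_.ROV_SCORE_D. A self-contained sketch of that scoring step (ROV_SCORE_D = 0.001 is an assumed placeholder for the config constant):

    ROV_SCORE_D = 0.001  # assumed placeholder for config_.ROV_SCORE_D

    def rank_to_rov_score(normal_scores_by_video):
        """Map {video_id: normalized prediction} to {video_id: rovScore} by rank."""
        ranked = sorted(normal_scores_by_video.items(), key=lambda kv: kv[1], reverse=True)
        # the top-ranked video gets 100, each following video is ROV_SCORE_D lower
        return {video_id: 100 - j * ROV_SCORE_D for j, (video_id, _) in enumerate(ranked)}

    print(rank_to_rov_score({101: 97.3, 102: 95.1, 103: 99.8}))
    # -> {103: 100, 101: 99.999, 102: 99.998}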
@@ -195,17 +242,25 @@ def predict_test():
         score = random.uniform(0, 100)
         redis_data[video_id] = score
         json_data.append({'videoId': video_id, 'rovScore': score})
+    log_.info('json_data count = {}'.format(len(json_data)))
     # Upload to Redis
     redis_helper = RedisHelper()
     key_name = config_.RECALL_KEY_NAME_PREFIX + time.strftime('%Y%m%d')
     redis_helper.add_data_with_zset(key_name=key_name, data=redis_data)
     log_.info('test data to redis finished!')
+    # Clear the data of videos whose ROV was modified
+    redis_helper.del_keys(key_name=config_.UPDATE_ROV_KEY_NAME)
     # Notify the backend to update its data
     result = request_post(request_url=config_.NOTIFY_BACKEND_UPDATE_ROV_SCORE_URL, request_data={'videos': json_data})
     if result['code'] == 0:
         log_.info('notify backend success!')
     else:
         log_.error('notify backend fail!')
+    # Update the videos' width/height ratio data
+    if filtered_videos:
+        update_video_w_h_rate(video_ids=filtered_videos,
+                              key_name=config_.W_H_RATE_UP_1_VIDEO_LIST_KEY_NAME['rov_recall'])
+        log_.info('update video w_h_rate to redis finished!')
 
 
 if __name__ == '__main__':
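Note: in both predict() and predict_test() the scores land in a date-stamped Redis sorted set via RedisHelper.add_data_with_zset, after which the key tracking modified-ROV videos is deleted. A hedged sketch of the equivalent raw redis-py calls, assuming add_data_with_zset wraps ZADD and del_keys wraps DEL (the helper internals are not part of this diff; host and key names are placeholders):

    import time

    import redis

    r = redis.Redis(host='localhost', port=6379, db=0)  # placeholder connection

    recall_key = 'recall:rov:' + time.strftime('%Y%m%d')  # stands in for config_.RECALL_KEY_NAME_PREFIX + date
    redis_data = {101: 100, 102: 99.999}                  # {video_id: rovScore}

    # ZADD takes a {member: score} mapping, matching add_data_with_zset(key_name=..., data=...)
    r.zadd(recall_key, mapping={str(video_id): score for video_id, score in redis_data.items()})

    # del_keys(key_name=...) presumably maps to a plain DEL on the tracking key
    r.delete('placeholder:update_rov_videos')             # stands in for config_.UPDATE_ROV_KEY_NAME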
@@ -220,6 +275,11 @@ if __name__ == '__main__':
 
     log_.info('rov model predict start...')
     predict_start = time.time()
-    predict()
+    if env in ['dev', 'test']:
+        predict_test()
+    elif env in ['pre', 'pro']:
+        predict()
+    else:
+        log_.error('env error')
     predict_end = time.time()
     log_.info('rov model predict end, execute time = {}ms'.format((predict_end - predict_start)*1000))