liqian пре 3 година
родитељ
комит
54b1b4d34f
2 измењених фајлова са 19 додато и 12 уклоњено
  1. 6 5
      config.py
  2. 13 7
      rov_train.py

+ 6 - 5
config.py

@@ -27,7 +27,7 @@ class BaseConfig(object):
     TRAIN_DELTA_DAYS = 30
     # 训练数据表名
     TRAIN_PROJECT = 'usercdm'
-    TRAIN_TABLE = 'rov_feature_add_v1'
+    TRAIN_TABLE = 'rov_feature_add_v1_x'
     # 训练数据文件存放路径
     TRAIN_DATA_FILENAME = 'train_data.pickle'
 
@@ -37,7 +37,7 @@ class BaseConfig(object):
     PREDICT_DELTA_DAYS = 1
     # 预测数据表名
     PREDICT_PROJECT = 'usercdm'
-    PREDICT_TABLE = 'rov_predict_table_add_v1'
+    PREDICT_TABLE = 'rov_predict_table_add_v1_x'
     # 预测数据文件存放路径
     PREDICT_DATA_FILENAME = 'predict_data.pickle'
 
@@ -119,7 +119,8 @@ class DevelopmentConfig(BaseConfig):
     }
 
     # Hologres视频状态存储表名
-    VIDEO_STATUS = 'longvideo_test.dwd_mdm_item_video_stat'
+    # VIDEO_STATUS = 'longvideo_test.dwd_mdm_item_video_stat'
+    VIDEO_STATUS = 'longvideo.dwd_mdm_item_video_stat'
 
     # 从流量池获取视频接口地址
     GET_VIDEOS_FROM_POOL_URL = 'http://testapi-internal.piaoquantv.com/flowpool/video/getAllVideo'
@@ -299,8 +300,8 @@ class ProductionConfig(BaseConfig):
 
 def set_config():
     # 获取环境变量 ROV_OFFLINE_ENV
-    env = os.environ.get('ROV_OFFLINE_ENV')
-    # env = 'dev'
+    # env = os.environ.get('ROV_OFFLINE_ENV')
+    env = 'dev'
     if env is None:
         log_.error('ENV ERROR: is None!')
         return

+ 13 - 7
rov_train.py

@@ -62,6 +62,7 @@ def process_predict_data(filename):
     """
     # 获取数据
     data = read_from_pickle(filename)
+    print(len(data))
 
     # 获取视频id列
     video_ids = data['videoid']
@@ -69,6 +70,7 @@ def process_predict_data(filename):
     video_id_list = [int(video_id) for video_id in video_ids]
     filtered_videos = [str(item) for item in filter_video_status(video_ids=video_id_list)]
     data = data.loc[data['videoid'].isin(filtered_videos)]
+    print(len(data))
 
     video_id_final = data['videoid']
 
@@ -236,7 +238,7 @@ def predict():
                             columns=['video_id', 'rov_score', 'normal_y_', 'y_'],
                             sort_columns=['rov_score'],
                             ascending=False)
-
+    """
     # 上传redis
     key_name = config_.RECALL_KEY_NAME_PREFIX + time.strftime('%Y%m%d')
     redis_helper = RedisHelper()
@@ -253,6 +255,7 @@ def predict():
         log_.info('notify backend success!')
     else:
         log_.error('notify backend fail!')
+    """
 
     # ##### 下线
     # # 更新视频的宽高比数据
@@ -302,6 +305,7 @@ def predict_test():
 
 
 if __name__ == '__main__':
+    """
     log_.info('rov model train start...')
     train_start = time.time()
     train_filename = config_.TRAIN_DATA_FILENAME
@@ -310,14 +314,16 @@ if __name__ == '__main__':
     train(X, Y, features=fea)
     train_end = time.time()
     log_.info('rov model train end, execute time = {}ms'.format((train_end - train_start)*1000))
+    """
 
     log_.info('rov model predict start...')
     predict_start = time.time()
-    if env in ['dev', 'test']:
-        predict_test()
-    elif env in ['pre', 'pro']:
-        predict()
-    else:
-        log_.error('env error')
+    predict()
+    # if env in ['dev', 'test']:
+    #     predict_test()
+    # elif env in ['pre', 'pro']:
+    #     predict()
+    # else:
+    #     log_.error('env error')
     predict_end = time.time()
     log_.info('rov model predict end, execute time = {}ms'.format((predict_end - predict_start)*1000))