Pārlūkot izejas kodu

generate label for mysql

罗俊辉 1 gadu atpakaļ
vecāks
revīzija
1c2a480ec7
1 mainītis faili ar 31 papildinājumiem un 18 dzēšanām
  1. 31 18
      process_data.py

+ 31 - 18
process_data.py

@@ -124,12 +124,22 @@ class SpiderProcess(object):
             print(video_id, "\t", e)
             return 0, 0
 
-    def spider_data_produce(self):
+    def spider_data_produce(self, flag, dt_time=None):
         """
         把 spider_duration 存储到数据库中
         :return:
         """
-        select_sql = "SELECT video_id, video_title, label, channel, out_user_id, spider_mode, out_play_cnt, out_like_cnt, out_share_cnt FROM lightgbm_data WHERE type = 'spider' order by daily_dt_str;"
+        if flag == "train":
+            select_sql = "SELECT video_id, video_title, label, channel, out_user_id, spider_mode, out_play_cnt, out_like_cnt, out_share_cnt FROM lightgbm_data WHERE type = 'spider' order by daily_dt_str;"
+            des_path = "data/train_data/spider_train_{}".format(datetime.datetime.today().strftime("%Y%m%d"))
+        elif flag == "predict":
+            dt_time = datetime.datetime.strptime(dt_time, "%Y%m%d")
+            three_date_before = dt_time - datetime.timedelta(days=4)
+            temp_time = three_date_before.strftime("%Y%m%d")
+            select_sql = f"""SELECT video_id, video_title, label, channel, out_user_id, spider_mode, out_play_cnt, out_like_cnt, out_share_cnt FROM lightgbm_data WHERE type = 'spider' and daily_dt_str = '{temp_time}';"""
+            des_path = "data/predict_data/predict_{}.json".format(temp_time)
+        else:
+            return
         data_list = self.client_spider.select(select_sql)
         df = []
         for line in tqdm(data_list):
@@ -157,8 +167,8 @@ class SpiderProcess(object):
                 continue
         df = pd.DataFrame(df, columns=['label', 'channel', 'out_user_id', 'mode', 'out_play_cnt', 'out_like_cnt',
                                        'out_share_cnt', 'lop', 'duration', 'tag1', 'tag2', 'tag3'])
-        df.to_json("data/train_data/spider_data_{}.json".format(datetime.datetime.today().strftime("%y%m%d")),
-                   orient='records')
+
+        df.to_json(des_path, orient='records')
 
 
 class UserProcess(object):
@@ -247,20 +257,23 @@ class UserProcess(object):
 
 
 if __name__ == "__main__":
-    # D = DataProcessor()
-    # D.producer()
-    # parser = argparse.ArgumentParser()  # 新建参数解释器对象
-    # parser.add_argument("--mode")
-    # parser.add_argument("--category")
-    # parser.add_argument("--dtype", default="whole")
-    # args = parser.parse_args()
-    # mode = args.mode
-    # category = args.category
-    # dtype = args.dtype
-    # D = DataProcessor()
-    # D.producer()
-    S = SpiderProcess()
-    S.spider_data_produce()
+    parser = argparse.ArgumentParser()  # 新建参数解释器对象
+    parser.add_argument("--mode")
+    parser.add_argument("--de")
+    parser.add_argument("--dt")
+    args = parser.parse_args()
+    mode = args.mode
+    D = args.de
+    dt = args.dt
+    match D:
+        case "spider":
+            S = SpiderProcess()
+            S.spider_data_produce(flag=mode, dt_time=dt)
+        case "user":
+            U = UserProcess()
+        case "Data":
+            D = DataProcessor()
+
     # if mode == "train":
     #     print("Loading data and process for training.....")
     #     D = DataProcessor(flag="train", ll=category)