罗俊辉 1 рік тому
батько
коміт
4722101176
1 змінений файл: 13 додано та 3 видалено
  1. 13 3
      process_data.py

+ 13 - 3
process_data.py

@@ -238,7 +238,17 @@ class UserProcess(object):
         生成user训练数据
         :return:
         """
-        sql = "select title, label, uid, channel, user_fans, user_view_30, user_share_30, user_return_30, user_rov, user_str, user_return_videos_30, user_return_videos_3, user_return_3, user_view_3, user_share_3, address from lighgbm_data where type = 'userupload';"
+        dt_time = datetime.datetime.strptime(dt_time, "%Y%m%d")
+        three_date_before = dt_time + datetime.timedelta(days=4)
+        temp_time = three_date_before.strftime("%Y%m%d")
+        if flag == "train":
+            sql = "select title, label, uid, channel, user_fans, user_view_30, user_share_30, user_return_30, user_rov, user_str, user_return_videos_30, user_return_videos_3, user_return_3, user_view_3, user_share_3, address from lighgbm_data where type = 'userupload' and daily_dt_str >= '20240305';"
+            des_path = "data/train_data/spider_train_{}".format(datetime.datetime.today().strftime("%Y%m%d"))
+        elif flag == "predict":
+            sql = f"""select title, label, uid, channel, user_fans, user_view_30, user_share_30, user_return_30, user_rov, user_str, user_return_videos_30, user_return_videos_3, user_return_3, user_view_3, user_share_3, address from lighgbm_data where type = 'userupload' and daily_dt_str = '{temp_time}';"""
+            des_path = "data/predict_data/predict_{}.json".format(dt_time.strftime("%Y%m%d"))
+        else:
+            return
         dt_list = self.client_spider.select(sql)
         df = []
         for line in dt_list:
@@ -258,7 +268,7 @@ class UserProcess(object):
 
             df.append(temp[1:])
         df = pd.DataFrame(df, columns=self.user_features)
-        df.to_json("data/train_data/user_data.json", orient='records')
+        df.to_json(des_path, orient='records')
 
 
 if __name__ == "__main__":
@@ -281,7 +291,7 @@ if __name__ == "__main__":
                 ed = str(input("输入结束日期,格式为 YYYYmmdd"))
                 U.userinfo_to_mysql(start_date=sd, end_date=ed)
             elif mode == "train":
-                U.generate_user_data("train")
+                U.generate_user_data(flag=mode, dt_time=dt)
             else:
                 print("Error")
         case "Data":