ソースを参照

获取 rov 数据

罗俊辉 1 年前
コミット
39b5d05238
2 ファイル変更、27 行追加、9 行削除
  1. 16 5
      data_process/process_data_for_lightgbm.py
  2. 11 4
      main.py

+ 16 - 5
data_process/process_data_for_lightgbm.py

@@ -172,7 +172,18 @@ class AllProcess(object):
 
     def __init__(self):
         self.client_spider = MySQLClientSpider()
-        self.all_features = ["label", "tag1", "tag2", "tag3", "tag4"]
+        self.all_features = [
+            # "video_title",
+            "rov_label",
+            "channel",
+            "type",
+            # "out_play_cnt",
+            # "out_like_cnt",
+            # "out_share_cnt"
+            "tag1",
+            "tag2",
+            "tag3"
+        ]
 
     def read_all_data(self, flag, dt_time):
         """
@@ -185,11 +196,11 @@ class AllProcess(object):
         three_date_before = dt_time + datetime.timedelta(days=4)
         temp_time = three_date_before.strftime("%Y%m%d")
         if flag == "train":
-            sql = f"""select video_title, label from lightgbm_data where daily_dt_str <= '{temp_time}';"""
+            sql = f"""select video_title, rov_label, channel, type from lightgbm_data where daily_dt_str <= '{temp_time}' and rov_label > 0;"""
             des_path = "/root/luojunhui/alg/data/train_data/all_train_{}.json".format(
                 datetime.datetime.today().strftime("%Y%m%d"))
         elif flag == "predict":
-            sql = f"""select video_title, label from lightgbm_data where daily_dt_str = '{temp_time}';"""
+            sql = f"""select video_title, rov_label, channel, type from lightgbm_data where daily_dt_str = '{temp_time}';"""
             des_path = "/root/luojunhui/alg/data/predict_data/all_predict_{}.json".format(dt_time.strftime("%Y%m%d"))
         else:
             return
@@ -198,10 +209,10 @@ class AllProcess(object):
         for line in tqdm(dt_list):
             title = line[0]
             try:
-                title_tags = list(jieba.analyse.textrank(title, topK=4))
+                title_tags = list(jieba.analyse.textrank(title, topK=3))
                 temp = list(line)
                 if title_tags:
-                    for i in range(4):
+                    for i in range(3):
                         try:
                             temp.append(title_tags[i])
                         except:

+ 11 - 4
main.py

@@ -25,15 +25,22 @@ class LightGBM(object):
     def __init__(self, flag, dt):
         self.label_encoder = LabelEncoder()
         self.my_c = [
+            "channel",
+            "type",
             "tag1",
             "tag2",
-            "tag3",
-            "tag4"
+            "tag3"
+        ]
+        self.str_columns = [
+            "channel",
+            "type",
+            "tag1",
+            "tag2",
+            "tag3"
         ]
-        self.str_columns = ["tag1", "tag2"]
         self.split_c = 0.75
         self.yc = 0.8
-        self.model = "models/lightgbm_0408_all_tags.bin"
+        self.model = "models/lightgbm_0409_all_tags.bin"
         self.flag = flag
         self.dt = dt