Browse Source

generate label for mysql

罗俊辉 1 year ago
parent
commit
72e1b9e7f3
1 changed file with 30 additions and 4 deletions
  1. 30 4
      process_data.py

+ 30 - 4
process_data.py

@@ -132,7 +132,7 @@ class SpiderProcess(object):
         select_sql = "SELECT video_id, video_title, label, channel, out_user_id, spider_mode, out_play_cnt, out_like_cnt, out_share_cnt FROM lightgbm_data WHERE type = 'spider';"
         data_list = self.client_spider.select(select_sql)
         df = []
-        for line in tqdm(data_list):
+        for line in tqdm(data_list[:10]):
             try:
                 temp = list(line)
                 video_id = line[0]
@@ -156,8 +156,9 @@ class SpiderProcess(object):
             except:
                 continue
         df = pd.DataFrame(df, columns=['label', 'channel', 'out_user_id', 'mode', 'out_play_cnt', 'out_like_cnt',
-                                       'out_share_cnt', 'label', 'lop', 'duration', 'tag1', 'tag2', 'tag3'])
-        df.to_json("data/train_data/spider_data_{}.json".format(datetime.datetime.today().strftime("y%m%d")), orient='records')
+                                       'out_share_cnt', 'lop', 'duration', 'tag1', 'tag2', 'tag3'])
+        df.to_json("data/train_data/spider_data_{}.json".format(datetime.datetime.today().strftime("y%m%d")),
+                   orient='records')
 
 
 class UserProcess(object):
@@ -168,6 +169,7 @@ class UserProcess(object):
     def __init__(self):
         self.client_spider = MySQLClientSpider()
         self.user_features = [
+            "label",
             "uid",
             "channel",
             "user_fans",
@@ -181,7 +183,10 @@ class UserProcess(object):
             "user_return_3",
             "user_view_3",
             "user_share_3",
-            "address"
+            "address",
+            "tag1",
+            "tag2",
+            "tag3"
         ]
 
     def userinfo_to_mysql(self, start_date, end_date):
@@ -218,6 +223,27 @@ class UserProcess(object):
         生成user训练数据
         :return:
         """
+        sql = "select title, label, uid, channel, user_fans, user_view_30, user_share_30, user_return_30, user_rov, user_str, user_return_videos_30, user_return_videos_3, user_return_3, user_view_3, user_share_3, address from lighgbm_data where type = 'userupload';"
+        dt_list = self.client_spider.select(sql)
+        df = []
+        for line in dt_list:
+            title = line[0]
+            temp = line
+            title_tags = list(jieba.analyse.textrank(title, topK=3))
+            if title_tags:
+                for i in range(3):
+                    try:
+                        temp.append(title_tags[i])
+                    except:
+                        temp.append(None)
+            else:
+                temp.append(None)
+                temp.append(None)
+                temp.append(None)
+
+            df.append(temp[1:])
+        df = pd.DataFrame(df, columns=self.user_features)
+        df.to_json("data/train_data/user_data.json", orient='records')
 
 
 if __name__ == "__main__":