
generate label for mysql

罗俊辉 1 year ago
commit 35ad3d886b
1 changed file with 34 additions and 104 deletions

+ 34 - 104
process_data.py

@@ -28,84 +28,6 @@ class DataProcessor(object):
         self.client_spider = MySQLClientSpider()
         self.label_data = {}
 
-    def generate_train_label(self, item, y_ori_data, cate):
-        """
-        Generate training data, returned as np.array matrices.
-        :return: x_train, the training data; y_train, the training labels
-        """
-        video_id = item["video_id"]
-        dt = item["dt"]
-        useful_features = [
-            "uid",
-            "type",
-            "channel",
-            "fans",
-            "view_count_user_30days",
-            "share_count_user_30days",
-            "return_count_user_30days",
-            "rov_user",
-            "str_user",
-            "out_user_id",
-            "mode",
-            "out_play_cnt",
-            "out_like_cnt",
-            "out_share_cnt",
-            "out_collection_cnt",
-        ]
-        spider_features = [
-            "channel",
-            "out_user_id",
-            "mode",
-            "out_play_cnt",
-            "out_like_cnt",
-            "out_share_cnt"
-        ]
-        user_features = [
-            "uid",
-            "channel",
-            "fans",
-            "view_count_user_30days",
-            "share_count_user_30days",
-            "return_count_user_30days",
-            "rov_user",
-            "str_user"
-        ]
-        match self.ll:
-            case "all":
-                item_features = [item[i] for i in useful_features]
-            case "user":
-                if item['type'] == "userupload":
-                    item_features = [item[i] for i in user_features]
-                else:
-                    return None, None
-            case "spider":
-                if item['type'] == "spider":
-                    item_features = [item[i] for i in spider_features]
-                    lop, duration = self.cal_lop(video_id)
-                    item_features.append(lop)
-                    item_features.append(duration)
-                else:
-                    return None, None
-        keywords_textrank = self.title_processor(video_id)
-        if keywords_textrank:
-            for i in range(3):
-                try:
-                    item_features.append(keywords_textrank[i])
-                except:
-                    item_features.append(None)
-        else:
-            item_features.append(None)
-            item_features.append(None)
-            item_features.append(None)
-
-        label_dt = generate_label_date(dt)
-        label_obj = y_ori_data.get(label_dt, {}).get(video_id)
-        if label_obj:
-            label = int(label_obj[cate]) if label_obj[cate] else 0
-        else:
-            label = 0
-        return label, item_features
-
     def producer(self):
         """
         Generate data
@@ -169,9 +91,6 @@ class DataProcessor(object):
                 process_info(item)
             except Exception as e:
                 print("操作失败", e)
-        #     time.sleep(0.5)
-        # with ThreadPoolExecutor(max_workers=8) as Pool:
-        #     Pool.map(process_info, init_list)
 
 
 class SpiderProcess(object):
@@ -210,7 +129,7 @@ class SpiderProcess(object):
         Store spider_duration into the database
         :return:
         """
-        select_sql = "SELECT video_id, video_title, channel, out_user_id, spider_mode, out_play_cnt, out_like_cnt, out_share_cnt FROM lightgbm_data WHERE type = 'spider';"
+        select_sql = "SELECT video_id, video_title, channel, out_user_id, spider_mode, out_play_cnt, out_like_cnt, out_share_cnt, label FROM lightgbm_data WHERE type = 'spider';"
         data_list = self.client_spider.select(select_sql)
         df = []
         for line in tqdm(data_list):
@@ -233,21 +152,21 @@ class SpiderProcess(object):
                     temp.append(None)
                     temp.append(None)
 
-                df.append(temp[1:])
+                df.append(temp[2:])
             except:
                 continue
         df = pd.DataFrame(df, columns=['title', 'channel', 'out_user_id', 'mode', 'out_play_cnt', 'out_like_cnt',
                                        'out_share_cnt', 'lop', 'duration', 'tag1', 'tag2', 'tag3'])
-        df.to_excel("data/train_data/spider_data_{}.xlsx".format(datetime.datetime.today().strftime("y%m%d")))
+        df.to_json("data/train_data/spider_data_{}.json".format(datetime.datetime.today().strftime("y%m%d")))
 
 
 class UserProcess(object):
     """
-        User Data Process
-        """
+    User Data Process
+    """
 
     def __init__(self):
-        self.client = MysqlClient()
+        self.client_spider = MySQLClientSpider()
         self.user_features = [
             "uid",
             "channel",
@@ -265,29 +184,40 @@ class UserProcess(object):
             "address"
         ]
 
-    def title_processor(self, video_id):
-        """
-            Fetch the title by video_id, then tokenize the title and use the keywords as features
-            :param video_id: the video id
-            :return: tag_list [tag, tag, tag, tag......]
-            """
-        sql = f"""SELECT title from wx_video where id = {video_id};"""
-        try:
-            title = self.client.select(sql)[0][0]
-            keywords_textrank = jieba.analyse.textrank(title, topK=3)
-            return list(keywords_textrank)
-        except Exception as e:
-            print(video_id, "\t", e)
-            return []
-
-    def user_data_process(self):
+    def userinfo_to_mysql(self, start_date, end_date):
         """
         Store user_return_3, user_view_3, user_share_3,
         user_return_videos_3, user_return_videos_30,
         and address into the mysql database
         :return:
         """
-        user_path = '/data'
+        user_path = 'data/train_data/daily-user-info-{}-{}.json'.format(start_date, end_date)
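+        # the daily user-info file is expected to be keyed by dt and then by video_id (see the lookup below)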
+        with open(user_path) as f:
+            data = json.loads(f.read())
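+        # backfill only the userupload rows whose address has not been filled in yet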
+        sql = "select video_id, hour_dt_str from lighgbm_data where type = 'userupload' and address is NULL;"
+        dt_list = self.client_spider.select(sql)
+        for item in dt_list:
+            video_id, dt = item
+            user_info_obj = data.get(dt, {}).get(video_id)
+            if user_info_obj:
+                try:
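+                    # unpack the per-video user stats from the daily user-info file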
+                    video_id = user_info_obj['video_id']
+                    address = user_info_obj['address']
+                    return_3 = user_info_obj['return_3days']
+                    view_3 = user_info_obj['view_3days']
+                    share_3 = user_info_obj['share_3days']
+                    return_videos_3 = user_info_obj['3day_return_500_videos']
+                    return_videos_30 = user_info_obj['30day_return_2000_videos']
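+                    # write the user stats back onto the matching lightgbm_data row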
+                    update_sql = f"""UPDATE lighgbm_data set address='{address}', user_return_3={return_3}, user_view_3={view_3}, user_share_3={share_3}, user_return_videos_3={return_videos_3}, user_return_videos_30={return_videos_30} where video_id = '{video_id}';"""
+                    self.client_spider.update(update_sql)
+                except Exception as e:
+                    print(video_id, "\t", e)
+
+    def generate_user_data(self):
+        """
+        Generate the user training data
+        :return:
+        """
 
 
 if __name__ == "__main__":