
generate label for mysql

罗俊辉 1 year ago
Commit 59ee8ba399
1 changed file with 158 additions and 62 deletions

process_data.py  +158 -62

@@ -5,24 +5,24 @@ process the data to satisfy the lightgbm
 import sys
 import os
 import json
+import asyncio
+import argparse
 from tqdm import tqdm
 import jieba.analyse
 
 sys.path.append(os.getcwd())
 
-from functions import generate_label_date, MysqlClient, MySQLClientSpider
+from functions import generate_label_date, generate_daily_strings, MysqlClient, MySQLClientSpider
 
 
 class DataProcessor(object):
     """
-    Process the data to satisfy the lightGBM
+    Insert information into lightgbm_data
     """
 
-    def __init__(self, flag, c="useful"):
+    def __init__(self, ll="all"):
         self.client = MysqlClient()
         self.client_spider = MySQLClientSpider()
-        self.flag = flag
-        self.c = c
+        self.ll = ll  # "match self.ll" below expects this attribute
 
     def generate_train_label(self, item, y_ori_data, cate):
         """
@@ -66,8 +66,8 @@ class DataProcessor(object):
             "rov_user",
             "str_user"
         ]
-        match self.c:
-            case "useful":
+        match self.ll:
+            case "all":
                 item_features = [item[i] for i in useful_features]
             case "user":
                 if item['type'] == "userupload":
@@ -102,77 +102,173 @@ class DataProcessor(object):
             label = 0
         return label, item_features
 
-    def title_processor(self, video_id):
+    async def producer(self):
         """
-        Fetch the title by video_id, then segment the title and use its keywords as features
-        :param video_id: the video id
-        :return: tag_list [tag, tag, tag, tag......]
+        Generate the data
+        :return: None
         """
-        sql = f"""SELECT title from wx_video where id = {video_id};"""
-        try:
-            title = self.client.select(sql)[0][0]
-            keywords_textrank = jieba.analyse.textrank(title, topK=3)
-            return list(keywords_textrank)
-        except Exception as e:
-            print(video_id, "\t", e)
-            return []
+        # store label, video_title and daily_dt_str in the MySQL database
+        label_path = "data/train_data/daily-label-20240101-20240325.json"
+        with open(label_path, encoding="utf-8") as f:
+            label_data = json.loads(f.read())
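+        # label_data maps a daily date string to {video_id: {"total_return": ...}};
+        # generate_label below reads labels out of this structure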
+
+        async def read_title(client, video_id):
+            """
+            Read the title from MySQL
+            """
+            sql = f"""SELECT title from wx_video where id = {video_id};"""
+            try:
+                title = (await client.select(sql))[0][0]  # parenthesized so the coroutine is awaited before indexing
+                return title
+            except Exception as e:
+                print(video_id, "\t", e)
+                return ""
+
+        def generate_label(video_id, hourly_dt_str, label_info):
+            """
+            Generate the label and daily_dt_str for MySQL
+            :param label_info:
+            :param video_id:
+            :param hourly_dt_str:
+            :return: label, daily_dt_str
+            """
+            label_dt = generate_label_date(hourly_dt_str)
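+            # look up this video's stats on the label date; missing entries fall back to label 0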
+            label_obj = label_info.get(label_dt, {}).get(video_id)
+            if label_obj:
+                label = int(label_obj["total_return"]) if label_obj["total_return"] else 0
+            else:
+                label = 0
+            return label, label_dt
 
-    def cal_lop(self, video_id):
+        async def process_info(item_, label_info):
+            """
+            Insert data into MySql
+            :param item_:
+            :param label_info:
+            """
+            video_id, hour_dt = item_
+            title = await read_title(client=self.client, video_id=video_id)
+            label, dt_daily = generate_label(video_id, hour_dt, label_info)
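+            # NOTE: title and label are interpolated straight into the SQL text, so a title
+            # containing quotes would break the statement; parameterized queries would be safer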
+            insert_sql = f"""INSERT INTO lightgbm_data (video_title, label, daily_dt_str) values ('{title}', '{label}', '{dt_daily}');"""
+            await self.client_spider.update(insert_sql)
+
+        select_sql = "SELECT video_id, hour_dt_str FROM lightgbm_data where label is NULL;"
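+        # fetch every unlabeled row up front, then fan the per-row work out as asyncio tasks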
+        init_data_tuple = self.client_spider.select(select_sql)
+        init_list = list(init_data_tuple)
+        async_tasks = []
+        for item in init_list:
+            async_tasks.append(process_info(item, label_data))
+
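+        # run all inserts concurrently; gather re-raises the first exception from any task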
+        await asyncio.gather(*async_tasks)
+
+
+class SpiderProcess(object):
+    """
+    Process spider data for lightgbm training
+    """
+
+    def __init__(self):
+        self.client_spider = MySQLClientSpider()
+
+    def spider_lop(self, video_id):
         """
-        Read the play and like counts for the video id from the crawler table and compute like / play; mind smoothing and the zero-denominator case
+        lop = like / play, read from the crawler (spider) table
         :param video_id:
-        :return:  lop
+        :return: lop, duration
         """
         sql = f"""SELECT like_cnt, play_cnt, duration from crawler_video where video_id = '{video_id}';"""
         try:
             like_cnt, play_cnt, duration = self.client_spider.select(sql)[0]
-            lop = (like_cnt + 70) / (play_cnt + 1800)
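+            # additive smoothing: the +700 / +18000 priors keep lop finite when play_cnt is 0
+            # and pull low-traffic videos toward a baseline ratio of 700/18000 ≈ 0.039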
+            lop = (like_cnt + 700) / (play_cnt + 18000)
             return lop, duration
         except Exception as e:
             print(video_id, "\t", e)
             return 0, 0
 
-    def producer(self, dt):
+    def spider_data_produce(self):
         """
-        Generate the data
-        :return: None
+        Store spider_duration in the database
+        :return:
         """
-        if self.flag == "train":
-            x_path = "data/train_data/train_2024010100_2024031523.json"
-            y_path = "data/train_data/daily-label-20240101-20240325.json"
-        elif self.flag == "predict":
-            x_path = "data/pred_data/pred_202403{}00_202403{}23.json".format(dt, dt)
-            y_path = "data/train_data/daily-label-20240101-20240325.json"
-        else:
-            return
-        with open(x_path) as f:
-            x_data = json.loads(f.read())
-        with open(y_path) as f:
-            y_data = json.loads(f.read())
-        cate_list = ["total_return"]
-        for c in cate_list:
-            x_list = []
-            y_list = []
-            for video_obj in tqdm(x_data):
-                our_label, features = self.generate_train_label(video_obj, y_data, c)
-                if features:
-                    x_list.append(features)
-                    y_list.append(our_label)
-            with open("data/produce_data/x_data_{}_{}_{}_{}.json".format(c, self.flag, dt, self.c), "w") as f1:
-                f1.write(json.dumps(x_list, ensure_ascii=False))
-
-            with open("data/produce_data/y_data_{}_{}_{}_{}.json".format(c, self.flag, dt, self.c), "w") as f2:
-                f2.write(json.dumps(y_list, ensure_ascii=False))
+        return
+
+
+class UserProcess(object):
+    """
+        User Data Process
+        """
+
+    def __init__(self):
+        self.client = MysqlClient()
+        self.user_features = [
+            "uid",
+            "channel",
+            "user_fans",
+            "user_view_30",
+            "user_share_30",
+            "user_return_30",
+            "user_rov",
+            "user_str",
+            "user_return_videos_30",
+            "user_return_videos_3",
+            "user_return_3",
+            "user_view_3",
+            "user_share_3",
+            "address"
+        ]
+
+    def title_processor(self, video_id):
+        """
+            通过 video_id 去获取title, 然后通过 title 再分词,把关键词作为 feature
+            :param video_id: the video id
+            :return: tag_list [tag, tag, tag, tag......]
+            """
+        sql = f"""SELECT title from wx_video where id = {video_id};"""
+        try:
+            title = self.client.select(sql)[0][0]
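+            # TextRank keyword extraction: the title's top-3 keywords become tag features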
+            keywords_textrank = jieba.analyse.textrank(title, topK=3)
+            return list(keywords_textrank)
+        except Exception as e:
+            print(video_id, "\t", e)
+            return []
+
+    def user_data_process(self):
+        """
+        Store user_return_3, user_view_3, user_share_3,
+        user_return_videos_3, user_return_videos_30
+        and address in the MySQL database
+        :return:
+        """
+        user_path = '/data'
 
 
 if __name__ == "__main__":
-    flag = int(input("please input method train or predict:\n "))
-    if flag == 1:
-        t = "train"
-        D = DataProcessor(flag=t, c="spider")
-        D.producer(dt="whole")
-    else:
-        t = "predict"
-        D = DataProcessor(flag=t, c="spider")
-        for d in range(16, 22):
-            D.producer(d)
+    # D = DataProcessor()
+    # D.producer()
+    # parser = argparse.ArgumentParser()  # create the argument parser object
+    # parser.add_argument("--mode")
+    # parser.add_argument("--category")
+    # parser.add_argument("--dtype", default="whole")
+    # args = parser.parse_args()
+    # mode = args.mode
+    # category = args.category
+    # dtype = args.dtype
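+    # entry point: asynchronously generate label, video_title and daily_dt_str
+    # for lightgbm_data rows whose label is NULL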
+    D = DataProcessor()
+    asyncio.run(D.producer())
+    # if mode == "train":
+    #     print("Loading data and process for training.....")
+    #     D = DataProcessor(flag="train", ll=category)
+    #     D.producer("whole")
+    # elif mode == "predict":
+    #     print("Loading data and process for prediction for each day......")
+    #     D = DataProcessor(flag="predict", ll=category)
+    #     if dtype == "single":
+    #         date_str = str(input("Please enter the date of the prediction"))
+    #         D.producer(date_str)
+    #     elif dtype == "days":
+    #         start_date_str = str(input("Please enter the start date of the prediction"))
+    #         end_date_str = str(input("Please enter the end date of the prediction"))
+    #         dt_list = generate_daily_strings(start_date=start_date_str, end_date=end_date_str)
+    #         for d in dt_list:
+    #             D.producer(d)