Browse Source

generate label for mysql

罗俊辉 1 year ago
parent
commit
8964973c37
1 changed file with 11 additions and 13 deletions
  1. 11 13
      process_data.py

+ 11 - 13
process_data.py

@@ -9,6 +9,7 @@ import asyncio
 import argparse
 from tqdm import tqdm
 import jieba.analyse
+from concurrent.futures.thread import ThreadPoolExecutor
 
 sys.path.append(os.getcwd())
 
@@ -102,7 +103,7 @@ class DataProcessor(object):
             label = 0
         return label, item_features
 
-    async def producer(self):
+    def producer(self):
         """
         生成数据
         :return:none
@@ -112,13 +113,13 @@ class DataProcessor(object):
         with open(label_path, encoding="utf-8") as f:
             label_data = json.loads(f.read())
 
-        async def read_title(client, video_id):
+        def read_title(client, video_id):
             """
             read_title_from mysql
             """
             sql = f"""SELECT title from wx_video where id = {video_id};"""
             try:
-                title = await client.select(sql)[0][0]
+                title = client.select(sql)[0][0]
                 return title
             except Exception as e:
                 print(video_id, "\t", e)
@@ -140,26 +141,23 @@ class DataProcessor(object):
                 label = 0
             return label, label_dt
 
-        async def process_info(item_, label_info):
+        def process_info(item_):
             """
             Insert data into MySql
             :param item_:
-            :param label_info:
             """
             video_id, hour_dt = item_
-            title = await read_title(client=self.client, video_id=video_id)
+            label_info = label_data
+            title = read_title(client=self.client, video_id=video_id)
             label, dt_daily = generate_label(video_id, hour_dt, label_info)
             insert_sql = f"""INSERT INTO lightgbm_data (video_title, label, daily_dt_str) values ('{title}', '{label}', '{dt_daily}';"""
-            await self.client_spider.update(insert_sql)
+            self.client_spider.update(insert_sql)
 
         select_sql = "SELECT video_id, hour_dt_str FROM lightgbm_data where label is NULL and hour_dt_str < '20240327';"
         init_data_tuple = self.client_spider.select(select_sql)
         init_list = list(init_data_tuple)
-        async_tasks = []
-        for item in init_list:
-            async_tasks.append(process_info(item, label_data))
-
-        await asyncio.gather(*async_tasks)
+        with ThreadPoolExecutor(max_workers=10) as Pool:
+            Pool.map(process_info, init_list)
 
 
 class SpiderProcess(object):
@@ -254,7 +252,7 @@ if __name__ == "__main__":
     # category = args.category
     # dtype = args.dtype
     D = DataProcessor()
-    asyncio.run(D.producer())
+    D.producer()
     # if mode == "train":
     #     print("Loading data and process for training.....")
     #     D = DataProcessor(flag="train", ll=category)