فهرست منبع

generate label for mysql

罗俊辉 1 سال پیش
والد
کامیت
2a34be833a
1 فایل تغییر یافته به همراه 11 افزوده شده و 8 حذف شده
  1. 11 8
      process_data.py

+ 11 - 8
process_data.py

@@ -24,6 +24,7 @@ class DataProcessor(object):
     def __init__(self, ):
         self.client = MysqlClient()
         self.client_spider = MySQLClientSpider()
+        self.label_data = {}
 
     def generate_train_label(self, item, y_ori_data, cate):
         """
@@ -111,7 +112,7 @@ class DataProcessor(object):
         # 把 label, video_title, daily_dt_str, 存储到 mysql 数据库中去
         label_path = "data/train_data/daily-label-20240101-20240325.json"
         with open(label_path, encoding="utf-8") as f:
-            label_data = json.loads(f.read())
+            self.label_data = json.loads(f.read())
 
         def read_title(client, video_id):
             """
@@ -148,7 +149,9 @@ class DataProcessor(object):
             :param item_:
             """
             video_id, hour_dt = item_
-            label_info = label_data
+            label_info = self.label_data
+            if not label_info:
+                print(label_info)
             title = read_title(client=self.client, video_id=video_id)
             label, dt_daily = generate_label(video_id, hour_dt, label_info)
             insert_sql = f"""UPDATE lightgbm_data 
@@ -157,14 +160,14 @@ class DataProcessor(object):
             ;"""
             self.client_spider.update(insert_sql)
 
-        select_sql = "SELECT video_id, hour_dt_str FROM lightgbm_data where label is NULL and hour_dt_str < '20240327';"
+        select_sql = "SELECT video_id, hour_dt_str FROM lightgbm_data where label = 0 and hour_dt_str < '20240327';"
         init_data_tuple = self.client_spider.select(select_sql)
         init_list = list(init_data_tuple)
-        for item in init_list:
-            # print(item)
-            process_info(item)
-        # with ThreadPoolExecutor(max_workers=10) as Pool:
-        #     Pool.map(process_info, init_list)
+        # for item in init_list:
+        #     # print(item)
+        #     process_info(item)
+        with ThreadPoolExecutor(max_workers=10) as Pool:
+            Pool.map(process_info, init_list)
 
 
 class SpiderProcess(object):