Browse Source

更新初始化代码

罗俊辉 1 year ago
parent
commit
239d7800fa
1 changed files with 18 additions and 18 deletions
  1. 18 18
      process_data.py

+ 18 - 18
process_data.py

@@ -11,7 +11,7 @@ sys.path.append(os.getcwd())
 from functions import generate_label_date
 
 
-def generate_train_label(item, y_ori_data):
+def generate_train_label(item, y_ori_data, cate):
     """
     生成训练数据,用 np.array矩阵的方式返回,
     :return: x_train, 训练数据, y_train, 训练 label
@@ -39,33 +39,33 @@ def generate_train_label(item, y_ori_data):
     label_dt = generate_label_date(dt)
     label_obj = y_ori_data.get(label_dt, {}).get(video_id)
     if label_obj:
-        label = int(label_obj['total_return']) if label_obj['total_return'] else 0
+        label = int(label_obj[cate]) if label_obj[cate] else 0
     else:
         label = 0
     return label, item_features
 
 
 if __name__ == '__main__':
-    x_path = 'data/train_january.json'
-    y_path = 'data/jan_feb_label.json'
+    x_path = 'data/hour_train.json'
+    y_path = 'data/daily-label-20240101-20240320.json'
 
     with open(x_path) as f:
         x_data = json.loads(f.read())
 
     with open(y_path) as f:
         y_data = json.loads(f.read())
-    x_list = []
-    y_list = []
-    for video_obj in tqdm(x_data):
-        print(video_obj)
-        our_label, features = generate_train_label(video_obj, y_data)
-        # if our_label:
+    cate_list = ['total_return', '3day_up_level', 'total_view', 'total_share']
+    for c in cate_list:
+        x_list = []
+        y_list = []
+        for video_obj in tqdm(x_data):
+            print(video_obj)
+            our_label, features = generate_train_label(video_obj, y_data, c)
+            x_list.append(features)
+            y_list.append(our_label)
+        # print(len(y_list))
+        with open("whole_data/x_data_{}.json".format(c), "w") as f1:
+            f1.write(json.dumps(x_list, ensure_ascii=False))
 
-        x_list.append(features)
-        y_list.append(our_label)
-    # print(len(y_list))
-    with open("whole_data/x_data.json", "w") as f1:
-        f1.write(json.dumps(x_list, ensure_ascii=False))
-
-    with open("whole_data/y_data.json", "w") as f2:
-        f2.write(json.dumps(y_list, ensure_ascii=False))
+        with open("whole_data/y_data_{}.json".format(c), "w") as f2:
+            f2.write(json.dumps(y_list, ensure_ascii=False))