罗俊辉 1 سال پیش
والد
کامیت
276b4e17aa
2فایلهای تغییر یافته به همراه3 افزوده شده و 2 حذف شده
  1. 2 2
      main.py
  2. 1 0
      test.py

+ 2 - 2
main.py

@@ -10,8 +10,8 @@ from sklearn.metrics import accuracy_score
 
 with open("whole_data/x_data.json") as f1:
     x_list = json.loads(f1.read())
-    X_train = x_list[:10000]
-    X_test = x_list[10000:]
+    X_train = np.array(x_list[:10000], dtype=object)
+    X_test = np.array(x_list[10000:], dtype=object)
 
 with open("whole_data/y_data.json") as f2:
     y_list = json.loads(f2.read())

+ 1 - 0
test.py

@@ -54,6 +54,7 @@ if __name__ == '__main__':
     for video_obj in tqdm(x_data):
         our_label, features = generate_train_label(video_obj, y_data)
         if our_label:
+
             x_list.append(features)
             y_list.append(our_label)
     print(len(y_list))