罗俊辉 1 år sedan
förälder
incheckning
7933244d45
1 ändrade filer med 9 tillägg och 5 borttagningar
  1. 9 5
      main.py

+ 9 - 5
main.py

@@ -45,12 +45,13 @@ float_cols = [
     ]
 with open("whole_data/x_data.json") as f1:
     x_list = json.loads(f1.read())
-    X_train = pd.DataFrame(x_list[:86434], columns=my_c)
+    index_t = int(len(x_list) * 0.7)
+    X_train = pd.DataFrame(x_list[:index_t], columns=my_c)
     for key in str_cols:
         X_train[key] = label_encoder.fit_transform(X_train[key])
     for key in float_cols:
         X_train[key] = pd.to_numeric(X_train[key], errors='coerce')
-    X_test = pd.DataFrame(x_list[86434:], columns=my_c)
+    X_test = pd.DataFrame(x_list[index_t:], columns=my_c)
     for key in str_cols:
         X_test[key] = label_encoder.fit_transform(X_test[key])
     for key in float_cols:
@@ -59,9 +60,12 @@ with open("whole_data/x_data.json") as f1:
 
 with open("whole_data/y_data.json") as f2:
     y_list = json.loads(f2.read())
-    y__list = [0 if i <= 56 else 1 for i in y_list]
-    y_train = np.array(y__list[:86434])
-    y_test = np.array(y__list[86434:])
+    index_t = int(len(y_list) * 0.7)
+    temp = sorted(y_list)
+    yuzhi = temp[int(len(temp) * 0.8)-1]
+    y__list = [0 if i <= yuzhi else 1 for i in y_list]
+    y_train = np.array(y__list[:index_t])
+    y_test = np.array(y__list[index_t:])
 
 # 创建LightGBM数据集
 train_data = lgb.Dataset(X_train, label=y_train, categorical_feature=['uid', 'type', 'channel', 'mode'])