瀏覽代碼

获取 rov 数据

罗俊辉 1 年之前
父節點
當前提交
905d122351
共有 1 個文件被更改,包括 8 次插入 和 3 次删除
  1. 8 3
      main.py

+ 8 - 3
main.py

@@ -53,7 +53,12 @@ class LightGBM(object):
         df = df.dropna(subset=['rov_label'])  # 把 label 为空的删掉
         df = df.dropna(subset=['tag1', 'tag2', 'tag3'], how="all")  # 把 tag 为空的数据也删掉
         labels = df['rov_label']
-        features = df.drop(['label'], axis=1)
+        if not yc:
+            temp = sorted(labels)
+            yc = temp[int(len(temp) * 0.7)]
+        print("阈值", yc)
+        labels = [0 if i < yc else 1 for i in labels]
+        features = df.drop(['rov_label'], axis=1)
         for key in self.str_columns:
             features[key] = self.label_encoder.fit_transform(features[key])
         return features, labels, df
@@ -64,7 +69,7 @@ class LightGBM(object):
         """
         path = "data/train_data/all_train_20240409.json"
         X, y, ori_df = self.read_data(path)
-        print(len(list(y)))
+
         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
 
         lgb_ = lgb.LGBMClassifier(objective='binary')
@@ -108,7 +113,7 @@ class LightGBM(object):
         :return:
         """
         path = "data/train_data/all_train_20240409.json"
-        x, y, ori_df = self.read_data(path)
+        x, y, ori_df = self.read_data(path, yc=0.02)
         train_size = int(len(x) * self.split_c)
         X_train, X_test = x[:train_size], x[train_size:]
         Y_train, Y_test = y[:train_size], y[train_size:]