Browse Source

获取 rov 数据

罗俊辉 1 year ago
parent
commit
614fe8d45b
1 changed files with 7 additions and 7 deletions
  1. 7 7
      main.py

+ 7 - 7
main.py

@@ -50,10 +50,10 @@ class LightGBM(object):
         :return:
         """
         df = pd.read_json(path)
-        df = df.dropna(subset=['label'])  # 把 label 为空的删掉
-        df = df.dropna(subset=['tag1', 'tag2'], how="all")  # 把 tag 为空的数据也删掉
-        labels = df['label']
-        features = df.drop(['label', 'tag3', 'tag4'], axis=1)
+        df = df.dropna(subset=['rov_label'])  # 把 label 为空的删掉
+        df = df.dropna(subset=['tag1', 'tag2', 'tag3'], how="all")  # 把 tag 为空的数据也删掉
+        labels = df['rov_label']
+        features = df.drop(['label'], axis=1)
         for key in self.str_columns:
             features[key] = self.label_encoder.fit_transform(features[key])
         return features, labels, df
@@ -62,7 +62,7 @@ class LightGBM(object):
         """
         find best params for lightgbm
         """
-        path = "data/train_data/all_train_20240408.json"
+        path = "data/train_data/all_train_20240409.json"
         X, y, ori_df = self.read_data(path)
         print(len(list(y)))
         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
@@ -107,7 +107,7 @@ class LightGBM(object):
         Load dataset
         :return:
         """
-        path = "data/train_data/all_train_20240408.json"
+        path = "data/train_data/all_train_20240409.json"
         x, y, ori_df = self.read_data(path)
         train_size = int(len(x) * self.split_c)
         X_train, X_test = x[:train_size], x[train_size:]
@@ -115,7 +115,7 @@ class LightGBM(object):
         train_data = lgb.Dataset(
             X_train,
             label=Y_train,
-            categorical_feature=["tag1", "tag2"],
+            categorical_feature=["channel", "type", "tag1", "tag2", "tag3"],
         )
         test_data = lgb.Dataset(X_test, label=Y_test, reference=train_data)
         params = {