Browse Source

预测代码

罗俊辉 1 year ago
parent
commit
1f92e33c9e
1 changed files with 15 additions and 2 deletions
  1. 15 2
      main.py

+ 15 - 2
main.py

@@ -121,7 +121,19 @@ class LightGBM(object):
         评估模型性能
         :return:
         """
-        X_test, Y_test = [], []
+        # 测试数据
+        with open("whole_data/x_data_total_return_prid.json") as f1:
+            x_list = json.loads(f1.read())
+
+        # 测试 label
+        with open("whole_data/y_data_total_return_prid.json") as f2:
+            Y_test = json.loads(f2.read())
+
+        X_test = pd.DataFrame(x_list, columns=self.my_c)
+        for key in self.str_columns:
+            X_test[key] = self.label_encoder.fit_transform(X_test[key])
+        for key in self.float_columns:
+            X_test[key] = pd.to_numeric(X_test[key], errors='coerce')
         bst = lgb.Booster(model_file=self.model)
         y_pred = bst.predict(X_test, num_iteration=bst.best_iteration)
         # 转换为二进制输出
@@ -133,4 +145,5 @@ class LightGBM(object):
 
 if __name__ == '__main__':
     L = LightGBM()
-    L.train_model()
+    # L.train_model()
+    L.evaluate_model()