Ver código fonte

预测代码

罗俊辉 1 ano atrás
pai
commit
10caf2e58e
1 arquivos alterados com 9 adições e 2 exclusões
  1. 9 2
      main.py

+ 9 - 2
main.py

@@ -130,6 +130,7 @@ class LightGBM(object):
         with open("whole_data/y_data_total_return_prid.json") as f2:
             Y_test = json.loads(f2.read())
 
+        Y_test = [0 if i <= 26 else 1 for i in Y_test]
         X_test = pd.DataFrame(x_list, columns=self.my_c)
         for key in self.str_columns:
             X_test[key] = self.label_encoder.fit_transform(X_test[key])
@@ -138,10 +139,16 @@ class LightGBM(object):
         bst = lgb.Booster(model_file=self.model)
         y_pred = bst.predict(X_test, num_iteration=bst.best_iteration)
         # 转换为二进制输出
+        score_list = []
         for index, item in enumerate(list(y_pred)):
             real_label = Y_test[index]
             score = item
             print(real_label, "\t", score)
+            score_list.append(score)
+        print("预测样本总量: {}".format(len(score_list)))
+        data_series = pd.Series(score_list)
+        print("统计 score 信息")
+        print(data_series.describe())
         # y_pred_binary = np.where(y_pred > 0.5, 1, 0)
         # # 评估模型
         # accuracy = accuracy_score(Y_test, y_pred_binary)
@@ -150,5 +157,5 @@ class LightGBM(object):
 
 if __name__ == '__main__':
     L = LightGBM()
-    L.train_model()
-    # L.evaluate_model()
+    # L.train_model()
+    L.evaluate_model()