Browse Source

预测代码

罗俊辉 1 year ago
parent
commit
16a0e90f35
1 changed files with 8 additions and 4 deletions
  1. 8 4
      main.py

+ 8 - 4
main.py

@@ -137,10 +137,14 @@ class LightGBM(object):
         bst = lgb.Booster(model_file=self.model)
         y_pred = bst.predict(X_test, num_iteration=bst.best_iteration)
         # 转换为二进制输出
-        y_pred_binary = np.where(y_pred > 0.7, 1, 0)
-        # 评估模型
-        accuracy = accuracy_score(Y_test, y_pred_binary)
-        print(f'Accuracy: {accuracy}')
+        for index, item in enumerate(list(y_pred)):
+            real_label = Y_test[index]
+            score = item
+            print(real_label, "\t", score)
+        # y_pred_binary = np.where(y_pred > 0.5, 1, 0)
+        # # 评估模型
+        # accuracy = accuracy_score(Y_test, y_pred_binary)
+        # print(f'Accuracy: {accuracy}')
 
 
 if __name__ == '__main__':