Browse Source

优化数据处理代码

罗俊辉 1 year ago
parent
commit
c3a0ed2ae9
1 changed file with 5 additions and 5 deletions
  1. 5 5
      main_spider.py

+ 5 - 5
main_spider.py

@@ -158,14 +158,14 @@ class LightGBM(object):
         x, y = self.read_data(path, yc=6)
         print(type(x))
         print(type(y))
-        # true_label_df = DataFrame(list(y), columns=['ture_label'])
+        true_label_df = DataFrame([list(y)], columns=['ture_label'])
         bst = lgb.Booster(model_file=self.model)
         y_pred = bst.predict(x, num_iteration=bst.best_iteration)
-        pred_label_df = DataFrame(list(y_pred), columns=['pred_label'])
+        pred_score_df = DataFrame([list(y_pred)], columns=['pred_score'])
         temp = sorted(list(y_pred))
         yuzhi = temp[int(len(temp) * 0.7) - 1]
         y_pred_binary = [0 if i <= yuzhi else 1 for i in list(y_pred)]
-
+        pred_label_df = DataFrame([list(y_pred_binary)], columns=['pred_label'])
         score_list = []
         for index, item in enumerate(list(y_pred)):
             real_label = y[index]
@@ -183,8 +183,8 @@ class LightGBM(object):
         print(f"Accuracy: {accuracy}")
         fw.close()
         # 水平合并
-        df_concatenated = pd.concat([x, true_label_df, pred_label_df], axis=1)
-        df_concatenated.to_excel("data/predict_data/spider_predict_result_{}.excel".format(dt), index=False)
+        df_concatenated = pd.concat([x, true_label_df,pred_score_df, pred_label_df], axis=1)
+        df_concatenated.to_excel("data/predict_data/spider_predict_result_{}.xlsx".format(dt), index=False)
 
     def feature_importance(self):
         """