Browse Source

优化数据处理代码

罗俊辉 1 year ago
parent
commit
0ed4a9ff56
1 changed files with 3 additions and 3 deletions
  1. 3 3
      main_spider.py

+ 3 - 3
main_spider.py

@@ -158,14 +158,14 @@ class LightGBM(object):
         x, y = self.read_data(path, yc=6)
         print(type(x))
         print(type(y))
-        true_label_df = DataFrame([list(y)], columns=['ture_label'])
+        true_label_df = pd.DataFrame(list(y), columns=['ture_label'])
         bst = lgb.Booster(model_file=self.model)
         y_pred = bst.predict(x, num_iteration=bst.best_iteration)
-        pred_score_df = DataFrame([list(y_pred)], columns=['pred_score'])
+        pred_score_df = pd.DataFrame(list(y_pred), columns=['pred_score'])
         temp = sorted(list(y_pred))
         yuzhi = temp[int(len(temp) * 0.7) - 1]
         y_pred_binary = [0 if i <= yuzhi else 1 for i in list(y_pred)]
-        pred_label_df = DataFrame([list(y_pred_binary)], columns=['pred_label'])
+        pred_label_df = pd.DataFrame(list(y_pred_binary), columns=['pred_label'])
         score_list = []
         for index, item in enumerate(list(y_pred)):
             real_label = y[index]