Browse Source

优化数据处理代码

罗俊辉 1 year ago
parent
commit
49a41330ad
1 changed file with 7 additions and 1 deletion
  1. 7 1
      main_spider.py

+ 7 - 1
main_spider.py

@@ -6,6 +6,7 @@ import sys
 import json
 import optuna
 import numpy as np
+from odps import DataFrame
 
 from sklearn.preprocessing import LabelEncoder
 
@@ -155,12 +156,14 @@ class LightGBM(object):
         fw = open("result/summary_{}.txt".format(dt), "a+", encoding="utf-8")
         path = 'data/predict_data/predict_{}.json'.format(dt)
         x, y = self.read_data(path, yc=6)
+        true_label_df = DataFrame(y, columns=['ture_label'])
         bst = lgb.Booster(model_file=self.model)
         y_pred = bst.predict(x, num_iteration=bst.best_iteration)
+        pred_label_df = DataFrame(list(y_pred), columns=['pred_label'])
         temp = sorted(list(y_pred))
         yuzhi = temp[int(len(temp) * 0.7) - 1]
         y_pred_binary = [0 if i <= yuzhi else 1 for i in list(y_pred)]
-        # 转换为二进制输出
+
         score_list = []
         for index, item in enumerate(list(y_pred)):
             real_label = y[index]
@@ -177,6 +180,9 @@ class LightGBM(object):
         accuracy = accuracy_score(y, y_pred_binary)
         print(f"Accuracy: {accuracy}")
         fw.close()
+        # 水平合并
+        df_concatenated = pd.concat([x, true_label_df, pred_label_df], axis=1)
+        df_concatenated.to_excel("data/predict_data/spider_predict_result_{}.excel".format(dt), index=False)
 
     def feature_importance(self):
         """