Browse Source

优化数据处理代码

罗俊辉 1 year ago
parent
commit
439d539405
1 changed file with 5 additions and 4 deletions
  1. 5 4
      main_spider.py

+ 5 - 4
main_spider.py

@@ -54,7 +54,7 @@ class LightGBM(object):
         self.flag = flag
         self.dt = dt
 
-    def read_data(self, path):
+    def read_data(self, path, yc=None):
         """
         Read data from local
         :return:
@@ -62,8 +62,9 @@ class LightGBM(object):
         df = pd.read_json(path)
         df = df.dropna(subset=['label'])
         labels = df['label']
-        temp = sorted(labels)
-        yc = temp[int(len(temp) * 0.7)]
+        if not yc:
+            temp = sorted(labels)
+            yc = temp[int(len(temp) * 0.7)]
         print("阈值", yc)
         labels = [0 if i < yc else 1 for i in labels]
         features = df.drop("label", axis=1)
@@ -153,7 +154,7 @@ class LightGBM(object):
         """
         fw = open("result/summary_{}.txt".format(dt), "a+", encoding="utf-8")
         path = 'data/predict_data/predict_{}.json'.format(dt)
-        x, y = self.read_data(path)
+        x, y = self.read_data(path, yc=6)
         bst = lgb.Booster(model_file=self.model)
         y_pred = bst.predict(x, num_iteration=bst.best_iteration)
         temp = sorted(list(y_pred))