Pārlūkot izejas kodu

generate label for mysql

罗俊辉 1 gadu atpakaļ
vecāks
revīzija
d9ab389ac6
1 mainītis faili ar 10 papildinājumiem un 9 dzēšanām
  1. 10 9
      main_userupload.py

+ 10 - 9
main_userupload.py

@@ -65,7 +65,7 @@ class LightGBM(object):
         self.flag = flag
         self.dt = dt
 
-    def read_data(self, path):
+    def read_data(self, path, yc=None):
         """
         Read data from local
         :return:
@@ -74,7 +74,8 @@ class LightGBM(object):
         df = df.dropna(subset=['label'])
         labels = df['label']
         temp = sorted(labels)
-        yc = temp[int(len(temp) * 0.8)]
+        if yc is None:
+            yc = temp[int(len(temp) * 0.8)]
         print("阈值", yc)
         labels = [0 if i < yc else 1 for i in labels]
         features = df.drop("label", axis=1)
@@ -141,12 +142,12 @@ class LightGBM(object):
         )
         test_data = lgb.Dataset(X_test, label=Y_test, reference=train_data)
         params = {
-            'bagging_fraction': 0.7938866919252519,
-            'bagging_freq': 7,
-            'feature_fraction': 0.9687508340232414,
-            'learning_rate': 0.09711720243493492,
-            'min_child_samples': 89,
-            'num_leaves': 35,
+            'bagging_fraction': 0.9323330736797192,
+            'bagging_freq': 1,
+            'feature_fraction': 0.8390650729441467,
+            'learning_rate': 0.07595782999760721,
+            'min_child_samples': 93,
+            'num_leaves': 36,
             'num_threads': 16
         }
 
@@ -164,7 +165,7 @@ class LightGBM(object):
         """
         fw = open("result/summary_{}.txt".format(dt), "a+", encoding="utf-8")
         path = 'data/predict_data/predict_{}.json'.format(dt)
-        x, y = self.read_data(path)
+        x, y = self.read_data(path, yc=28)
         bst = lgb.Booster(model_file=self.model)
         y_pred = bst.predict(x, num_iteration=bst.best_iteration)
         temp = sorted(list(y_pred))