Browse Source

generate label for mysql

罗俊辉 1 year ago
parent
commit
7ff6390140
1 changed file with 22 additions and 20 deletions
  1. 22 20
      main_spider.py

+ 22 - 20
main_spider.py

@@ -78,7 +78,7 @@ class LightGBM(object):
             'metric': 'binary_logloss',
             'verbosity': -1,
             'boosting_type': 'gbdt',
-            'num_leaves': trial.suggest_int('num_leaves', 10, 40),
+            'num_leaves': trial.suggest_int('num_leaves', 20, 40),
             'learning_rate': trial.suggest_loguniform('learning_rate', 1e-8, 1.0),
             'feature_fraction': trial.suggest_uniform('feature_fraction', 0.4, 1.0),
             'bagging_fraction': trial.suggest_uniform('bagging_fraction', 0.4, 1.0),
@@ -107,8 +107,10 @@ class LightGBM(object):
         Load dataset
         :return:
         """
-        X_train, X_test = self.generate_x_data()
-        Y_train, Y_test = self.generate_y_data()
+        x, y = self.read_data()
+        train_size = int(len(x) * self.split_c)
+        X_train, X_test = x[:train_size], x[train_size:]
+        Y_train, Y_test = y[:train_size], y[train_size:]
         train_data = lgb.Dataset(
             X_train,
             label=Y_train,
@@ -192,20 +194,20 @@ class LightGBM(object):
 
 
 if __name__ == "__main__":
-    # i = int(input("输入 1 训练, 输入 2 预测:\n"))
-    # if i == 1:
-    #     f = "train"
-    #     dt = "whole"
-    #     L = LightGBM(flag=f, dt=dt)
-    #     L.train_model()
-    # elif i == 2:
-    #     f = "predict"
-    #     dt = int(input("输入日期, 16-21:\n"))
-    #     L = LightGBM(flag=f, dt=dt)
-    #     L.evaluate_model()
-    #     L.feature_importance()
-    L = LightGBM("train", "whole")
-    study = optuna.create_study(direction='maximize')
-    study.optimize(L.bays_params, n_trials=100)
-    print('Number of finished trials:', len(study.trials))
-    print('Best trial:', study.best_trial.params)
+    i = int(input("输入 1 训练, 输入 2 预测:\n"))
+    if i == 1:
+        f = "train"
+        dt = "whole"
+        L = LightGBM(flag=f, dt=dt)
+        L.train_model()
+    elif i == 2:
+        f = "predict"
+        dt = int(input("输入日期, 16-21:\n"))
+        L = LightGBM(flag=f, dt=dt)
+        L.evaluate_model()
+        L.feature_importance()
+    # L = LightGBM("train", "whole")
+    # study = optuna.create_study(direction='maximize')
+    # study.optimize(L.bays_params, n_trials=100)
+    # print('Number of finished trials:', len(study.trials))
+    # print('Best trial:', study.best_trial.params)