Przeglądaj źródła

generate label for mysql

罗俊辉 1 rok temu
rodzic
commit
f4057b0739
1 zmienionych plików z 20 dodań i 20 usunięć
  1. 20 20
      main_spider.py

+ 20 - 20
main_spider.py

@@ -123,12 +123,12 @@ class LightGBM(object):
         )
         test_data = lgb.Dataset(X_test, label=Y_test, reference=train_data)
         params = {
-            'num_leaves': 31,
-            'learning_rate': 0.00020616904432655601,
-            'feature_fraction': 0.6508847259863764,
-            'bagging_fraction': 0.7536774652478249,
-            'bagging_freq': 6,
-            'min_child_samples': 99,
+            'bagging_fraction': 0.7938866919252519,
+            'bagging_freq': 7,
+            'feature_fraction': 0.9687508340232414,
+            'learning_rate': 0.09711720243493492,
+            'min_child_samples': 89,
+            'num_leaves': 35,
             'num_threads': 16
         }
 
@@ -198,20 +198,20 @@ class LightGBM(object):
 
 
 if __name__ == "__main__":
-    # i = int(input("输入 1 训练, 输入 2 预测:\n"))
-    # if i == 1:
-    #     f = "train"
-    #     dt = "whole"
-    #     L = LightGBM(flag=f, dt=dt)
-    #     L.train_model()
-    # elif i == 2:
-    #     f = "predict"
-    #     dt = int(input("输入日期, 16-21:\n"))
-    #     L = LightGBM(flag=f, dt=dt)
-    #     L.evaluate_model()
-    #     L.feature_importance()
-    L = LightGBM("train", "whole")
-    L.best_params()
+    i = int(input("输入 1 训练, 输入 2 预测:\n"))
+    if i == 1:
+        f = "train"
+        dt = "whole"
+        L = LightGBM(flag=f, dt=dt)
+        L.train_model()
+    elif i == 2:
+        f = "predict"
+        dt = int(input("输入日期, 16-21:\n"))
+        L = LightGBM(flag=f, dt=dt)
+        L.evaluate_model()
+        L.feature_importance()
+    # L = LightGBM("train", "whole")
+    # L.best_params()
     # study = optuna.create_study(direction='maximize')
     # study.optimize(L.bays_params, n_trials=100)
     # print('Number of finished trials:', len(study.trials))