Selaa lähdekoodia

generate label for mysql

罗俊辉 1 vuosi sitten
vanhempi
commit
4de16b6b10
1 muutettua tiedostoa jossa 8 lisäystä ja 7 poistoa
  1. 8 7
      main_spider.py

+ 8 - 7
main_spider.py

@@ -5,19 +5,20 @@ import os
 import sys
 import json
 import optuna
+import numpy as np
 
-from sklearn.linear_model import LogisticRegression
+from sklearn.preprocessing import LabelEncoder
 
 sys.path.append(os.getcwd())
 
-import numpy as np
 import pandas as pd
 import lightgbm as lgb
+
 from scipy.stats import randint as sp_randint
 from scipy.stats import uniform as sp_uniform
 from sklearn.model_selection import RandomizedSearchCV, train_test_split
-from sklearn.metrics import roc_auc_score
-from bayes_opt import BayesianOptimization
+import lightgbm as lgb
+from sklearn.metrics import roc_auc_score, accuracy_score
 
 
 class LightGBM(object):
@@ -77,7 +78,7 @@ class LightGBM(object):
         X, y = self.read_data()
         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
 
-        lgbM = lgb.LGBMClassifier(objective='binary')
+        lgbm = lgb.LGBMClassifier(objective='binary')
 
         # 设置搜索的参数范围
         param_dist = {
@@ -90,7 +91,7 @@ class LightGBM(object):
         }
 
         # 定义 RandomizedSearchCV
-        rsearch = RandomizedSearchCV(estimator=lgbM, param_distributions=param_dist, n_iter=100, cv=3,
+        rsearch = RandomizedSearchCV(estimator=lgbm, param_distributions=param_dist, n_iter=100, cv=3,
                                      scoring='roc_auc', random_state=42, verbose=2)
 
         # 开始搜索
@@ -210,7 +211,7 @@ if __name__ == "__main__":
     #     L.evaluate_model()
     #     L.feature_importance()
     L = LightGBM("train", "whole")
-    L.best_params()
+    L.tune()
     # study = optuna.create_study(direction='maximize')
     # study.optimize(L.bays_params, n_trials=100)
     # print('Number of finished trials:', len(study.trials))