Browse Source

generate label for mysql

罗俊辉 1 year ago
parent
commit
77e0846f98
1 changed files with 32 additions and 36 deletions
  1. 32 36
      main_spider.py

+ 32 - 36
main_spider.py

@@ -13,11 +13,9 @@ sys.path.append(os.getcwd())
 import numpy as np
 import pandas as pd
 import lightgbm as lgb
-from sklearn.preprocessing import LabelEncoder
-from sklearn.metrics import accuracy_score
-
-from sklearn.model_selection import train_test_split, StratifiedKFold
-from sklearn.datasets import load_breast_cancer
+from scipy.stats import randint as sp_randint
+from scipy.stats import uniform as sp_uniform
+from sklearn.model_selection import RandomizedSearchCV, train_test_split
 from sklearn.metrics import roc_auc_score
 from bayes_opt import BayesianOptimization
 
@@ -75,40 +73,38 @@ class LightGBM(object):
             features[key] = self.label_encoder.fit_transform(features[key])
         return features, labels
 
-    def objective(self, trial):
-        x, y = self.read_data()
-        train_size = int(len(x) * self.split_c)
-        X_train, X_test = x[:train_size], x[train_size:]
-        Y_train, Y_test = y[:train_size], y[train_size:]
+    def best_params(self):
+        X, y = self.read_data()
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+        lgbM = lgb.LGBMClassifier(objective='binary')
 
-        dtrain = lgb.Dataset(X_train, label=Y_train)
-        dvalid = lgb.Dataset(X_test, label=Y_test, reference=dtrain)
-
-        param = {
-            'objective': 'binary',  # 根据问题修改,例如'regression'或'multiclass'
-            'metric': 'binary_logloss',  # 根据问题修改,例如'l2'或'multi_logloss'
-            'verbosity': -1,
-            'boosting_type': 'gbdt',
-            'num_leaves': trial.suggest_int('num_leaves', 20, 40),
-            'learning_rate': trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True),
-            'feature_fraction': trial.suggest_float('feature_fraction', 0.6, 0.9),
-            'bagging_fraction': trial.suggest_float('bagging_fraction', 0.6, 0.9),
-            'bagging_freq': trial.suggest_int('bagging_freq', 1, 10),
-            'min_child_samples': trial.suggest_int('min_child_samples', 5, 100),
-            'num_thread': 16
+        # 设置搜索的参数范围
+        param_dist = {
+            'num_leaves': sp_randint(20, 40),
+            'learning_rate': sp_uniform(0.001, 0.1),
+            'feature_fraction': sp_uniform(0.5, 0.9),
+            'bagging_fraction': sp_uniform(0.5, 0.9),
+            'bagging_freq': sp_randint(1, 10),
+            'min_child_samples': sp_randint(5, 100),
         }
 
-        gbm = lgb.train(param, dtrain, valid_sets=[dvalid])
-        preds = gbm.predict(X_test)
-        pred_labels = np.rint(preds)
-        accuracy = accuracy_score(Y_test, pred_labels)
-        return accuracy  # 或其他优化指标
+        # 定义 RandomizedSearchCV
+        rsearch = RandomizedSearchCV(estimator=lgbM, param_distributions=param_dist, n_iter=100, cv=3,
+                                     scoring='roc_auc', random_state=42, verbose=2)
+
+        # 开始搜索
+        rsearch.fit(X_train, y_train)
+
+        # 打印最佳参数和对应的AUC得分
+        print("Best parameters found: ", rsearch.best_params_)
+        print("Best AUC found: ", rsearch.best_score_)
 
-    def tune(self, n_trials=100):
-        study = optuna.create_study(direction='maximize')
-        study.optimize(self.objective, n_trials=n_trials)
-        print('Number of finished trials:', len(study.trials))
-        print('Best trial:', study.best_trial.params)
+        # 使用最佳参数在测试集上的表现
+        best_model = rsearch.best_estimator_
+        y_pred = best_model.predict_proba(X_test)[:, 1]
+        auc = roc_auc_score(y_test, y_pred)
+        print("AUC on test set: ", auc)
 
     def train_model(self):
         """
@@ -214,7 +210,7 @@ if __name__ == "__main__":
     #     L.evaluate_model()
     #     L.feature_importance()
     L = LightGBM("train", "whole")
-    L.tune()
+    L.best_params()
     # study = optuna.create_study(direction='maximize')
     # study.optimize(L.bays_params, n_trials=100)
     # print('Number of finished trials:', len(study.trials))