罗俊辉 hai 1 ano
pai
achega
33739acb48
Modificáronse 2 ficheiros con 23 adicións e 8 borrados
  1. 15 5
      main.py
  2. 8 3
      process_data.py

+ 15 - 5
main.py

@@ -98,7 +98,7 @@ class LightGBM(object):
         Generate data for feature engineering
         :return:
         """
-        with open("data/produce_data/x_data_total_return_{}.json".format(self.flag,)) as f1:
+        with open("data/produce_data/x_data_total_return_{}_{}.json".format(self.flag, self.dt)) as f1:
             x_list = json.loads(f1.read())
         index_t = int(len(x_list) * self.split_c)
         X_train = pd.DataFrame(x_list[:index_t], columns=self.my_c)
@@ -118,7 +118,7 @@ class LightGBM(object):
         Generate data for label
         :return:
         """
-        with open("produce_data/y_data_total_return_train.json") as f2:
+        with open("produce_data/y_data_total_return_{}_{}.json".format(self.flag, self.dt)) as f2:
             y_list = json.loads(f2.read())
         index_t = int(len(y_list) * self.split_c)
         temp = sorted(y_list)
@@ -217,11 +217,21 @@ class LightGBM(object):
 
 
 if __name__ == "__main__":
-    L = LightGBM()
+    i = int(input("输入 1 训练, 输入 2 预测:\n"))
+    if i == 1:
+        f = "train"
+        dt = "whole"
+        L = LightGBM(flag=f, dt=dt)
+        L.train_model()
+    elif i == 2:
+        f = "predict"
+        dt = int(input("输入日期, 16-21:\n"))
+        L = LightGBM(flag=f, dt=dt)
+        L.evaluate_model()
     # study = optuna.create_study(direction='maximize')
     # study.optimize(L.bays_params, n_trials=100)
     # print('Number of finished trials:', len(study.trials))
     # print('Best trial:', study.best_trial.params)
     # L.train_model()
-    L.evaluate_model()
-    L.feature_importance()
+    # L.evaluate_model()
+    # L.feature_importance()

+ 8 - 3
process_data.py

@@ -114,9 +114,14 @@ class DataProcessor(object):
 
 
 if __name__ == "__main__":
-    flag = str(input("please input method train or predict"))
-    D = DataProcessor(flag=flag)
-    if flag == "predict":
+    flag = int(input("please input method train or predict:\n "))
+    if flag == 1:
+        t = "train"
+        D = DataProcessor(flag=t)
+    else:
+        t = "predict"
+
+    if flag == 2:
         for d in range(16, 22):
             D.producer(d)
     else: