Browse Source

分析代码

罗俊辉 1 year ago
parent
commit
70b1cffa60
1 changed files with 35 additions and 14 deletions
  1. 35 14
      p_data_process.py

+ 35 - 14
p_data_process.py

@@ -5,20 +5,41 @@
 
 import os
 import json
+from tqdm import tqdm
 from functions import generate_hourly_strings
 
-target_string_list = generate_hourly_strings(
-    start_date="2024031900", end_date="2024031923"
-)
-path = "temp_data/data"
-L = []
-for file in target_string_list:
-    json_path = os.path.join(path, file + ".json")
-    print(json_path)
-    with open(json_path, encoding="utf-8") as f:
-        data = json.loads(f.read())
-    for obj in data:
-        L.append(obj)
 
-with open("prid_data/train_0319.json", "w", encoding="utf-8") as f:
-    f.write(json.dumps(L, ensure_ascii=False))
+def generate_hour_data(s, e, flag):
+    """
+    从 temp_data读取小时级别数据作为训练数据和预测数据
+    :param s: 开始日期
+    :param e: 结束日期
+    :param flag: train / pred
+    :return:
+    """
+    target_string_list = generate_hourly_strings(start_date=s, end_date=e)
+    path = "data/temp_data"
+    L = []
+    for file in tqdm(target_string_list):
+        json_path = os.path.join(path, "hour_" + file + ".json")
+        with open(json_path, encoding="utf-8") as f:
+            data = json.loads(f.read())
+        for obj in data:
+            L.append(obj)
+
+    with open("{}_data/{}_{}_{}.json".format(flag, flag, s, e), "w", encoding="utf-8") as f:
+        f.write(json.dumps(L, ensure_ascii=False))
+
+
+if __name__ == "__main__":
+    iii = int(input("请输入标识符,输入 1 生成训练数据, 输入 2 生成预测数据: \n"))
+    if iii == 1:
+        f = "train"
+    elif iii == 2:
+        f = "pred"
+    else:
+        print("输入错误")
+    start = str(input("请输入开始字符串, 格式为 yyyymmddhh:\n"))
+    end = str(input("请输入结束字符串, 格式为 yymmddhh: \n"))
+    generate_hour_data(s=start, e=end, flag=f)
+