p_data_process.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. """
  2. 生成预测数据, 3月14日 和 3月17日的小时级数据
  3. 3月18日至 3月21日的daily 数据
  4. """
  5. import os
  6. import json
  7. from tqdm import tqdm
  8. from functions import generate_hourly_strings
  9. def generate_hour_data(s, e, flag):
  10. """
  11. 从 temp_data读取小时级别数据作为训练数据和预测数据
  12. :param s: 开始日期
  13. :param e: 结束日期
  14. :param flag: train / pred
  15. :return:
  16. """
  17. target_string_list = generate_hourly_strings(start_date=s, end_date=e)
  18. path = "data/temp_data"
  19. L = []
  20. for file in tqdm(target_string_list):
  21. json_path = os.path.join(path, "hour_" + file + ".json")
  22. with open(json_path, encoding="utf-8") as f:
  23. data = json.loads(f.read())
  24. for obj in data:
  25. L.append(obj)
  26. with open("data/{}_data/{}_{}_{}.json".format(flag, flag, s, e), "w", encoding="utf-8") as f:
  27. f.write(json.dumps(L, ensure_ascii=False))
  28. if __name__ == "__main__":
  29. iii = int(input("请输入标识符,输入 1 生成训练数据, 输入 2 生成预测数据: \n"))
  30. if iii == 1:
  31. f = "train"
  32. elif iii == 2:
  33. f = "pred"
  34. else:
  35. print("输入错误")
  36. start = str(input("请输入开始字符串, 格式为 yyyymmddhh:\n"))
  37. end = str(input("请输入结束字符串, 格式为 yyyymmddhh: \n"))
  38. generate_hour_data(s=start, e=end, flag=f)