envs.py 8.6 KB


  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from contextlib import closing
  15. import yaml
  16. import copy
  17. import os
  18. import socket
  19. import sys
  20. import six
  21. import traceback
  22. import warnings
  # Flattened configuration ("a.b.c" -> value); filled by set_global_envs and
  # read back through get_global_env / get_global_envs.
  global_envs = dict()
  # Not referenced elsewhere in this file; presumably kept for external
  # importers -- verify before removing.
  global_envs_flatten = dict()
  def flatten_environs(envs, separator="."):
      """Flatten a nested config dict into a single-level dict of strings.

      Nested keys are joined with ``separator`` and every leaf value is
      converted with ``str``. For example,
      ``flatten_environs({"a": {"b": 1}, "c": "x"})`` returns
      ``{"a.b": "1", "c": "x"}``.

      Args:
          envs (dict): the (possibly nested) mapping to flatten.
          separator (str): text placed between nested key components.

      Returns:
          dict: mapping of joined key -> ``str(value)``.

      Raises:
          AssertionError: if ``envs`` is not a dict.
      """
      assert isinstance(envs, dict)
      flatten_dict = {}
      # str on Python 3; (str, unicode) on Python 2.
      text_types = (str, type(u""))

      def as_key(key):
          # YAML allows non-string mapping keys (ints, bools, ...), but
          # str.join only accepts text, so stringify anything else.
          return key if isinstance(key, text_types) else str(key)

      def fatten_env_namespace(namespace_nests, local_envs):
          if not isinstance(local_envs, dict):
              # A top-level leaf: its key path is just the namespace itself.
              flatten_dict[separator.join(namespace_nests)] = str(local_envs)
              return
          for k, v in local_envs.items():
              # Build a fresh list per child; keys are immutable, so no deep
              # copy is needed.
              nests = namespace_nests + [as_key(k)]
              if isinstance(v, dict):
                  fatten_env_namespace(nests, v)
              else:
                  flatten_dict[separator.join(nests)] = str(v)

      for k, v in envs.items():
          fatten_env_namespace([as_key(k)], v)
      return flatten_dict
  def set_runtime_environs(environs):
      """Export every entry of ``environs`` into ``os.environ``.

      Args:
          environs (dict): environment-variable name -> value; each value is
              stored as ``str(value)``.
      """
      for name in environs:
          os.environ[name] = str(environs[name])
  def get_runtime_environ(key):
      """Return environment variable ``key``, or ``None`` when it is unset."""
      return os.environ.get(key)
  def get_trainer():
      """Return the ``train.trainer.trainer`` runtime environment value.

      Returns:
          str or None: the variable's value, or ``None`` when it is unset.
      """
      return get_runtime_environ("train.trainer.trainer")
  def get_fleet_mode():
      """Return the ``fleet_mode`` runtime environment value.

      Returns:
          str or None: the variable's value, or ``None`` when it is unset.
      """
      return get_runtime_environ("fleet_mode")
  def set_global_envs(envs):
      """Flatten a parsed config into the module-level ``global_envs`` dict.

      Steps:
        1. Nested dicts are flattened to dotted keys ("a.b.c"). Items of the
           ``dataset``, ``phase`` and ``runner`` lists are flattened under
           their ``name`` field, e.g. ``dataset.<name>.type``.
        2. Every string value has ``{workspace}`` expanded (workspace_adapter)
           and its path separators normalised (os_path_adapter).
        3. For runners that configure ``save_step_interval`` /
           ``save_step_path`` and whose first phase reads a QueueDataset,
           step-based saving is disabled -- but only when that dataset really
           stays a QueueDataset (Linux under Python 2). The disabling is
           applied to both the ``runner`` dict and its flattened copies.
        4. Unless running on Linux under Python 2, every dataset type is
           forced to "DataLoader".

      Existing ``global_envs`` entries are overwritten, never cleared.

      Args:
          envs (dict): the parsed config. ``envs["runner"]`` is always read;
              ``envs["phase"]`` and ``envs["dataset"]`` are read when steps 3/4
              apply.

      Raises:
          AssertionError: if ``envs`` is not a dict.
          ValueError: if an item of a dataset/phase/runner list has no name.
      """
      assert isinstance(envs, dict)

      def fatten_env_namespace(namespace_nests, local_envs):
          for k, v in local_envs.items():
              if isinstance(v, dict):
                  fatten_env_namespace(namespace_nests + [k], v)
              elif k in ("dataset", "phase", "runner") and isinstance(v, list):
                  # List items are addressed by their "name" field.
                  for i in v:
                      if i.get("name") is None:
                          raise ValueError(
                              "name must be in {} list ".format(k), v)
                      fatten_env_namespace(namespace_nests + [k, i["name"]],
                                           i)
              else:
                  global_envs[".".join(namespace_nests + [k])] = v

      fatten_env_namespace([], envs)

      # Expand "{workspace}" and fix path separators. Only existing keys are
      # reassigned, so updating the dict while iterating over it is safe.
      for name, value in global_envs.items():
          if isinstance(value, str):
              global_envs[name] = os_path_adapter(workspace_adapter(value))

      platform_name = get_platform()
      # Datasets keep their configured QueueDataset type only on Linux under
      # Python 2; everywhere else they are switched to DataLoader below.
      queue_dataset_kept = platform_name == "LINUX" and not six.PY3

      for runner in envs["runner"]:
          if "save_step_interval" in runner or "save_step_path" in runner:
              phase_name = runner["phases"]
              phase = [
                  p for p in envs["phase"] if p["name"] == phase_name[0]
              ]
              dataset_name = phase[0].get("dataset_name")
              dataset = [
                  d for d in envs["dataset"] if d["name"] == dataset_name
              ]
              if (queue_dataset_kept and
                      dataset[0].get("type") == "QueueDataset"):
                  runner["save_step_interval"] = None
                  runner["save_step_path"] = None
                  # global_envs was already flattened from this runner above;
                  # clear the flattened copies too, otherwise get_global_env()
                  # would still return the original save-step settings.
                  runner_prefix = ".".join(["runner", runner["name"]])
                  global_envs[runner_prefix + ".save_step_interval"] = None
                  global_envs[runner_prefix + ".save_step_path"] = None
                  warnings.warn(
                      "QueueDataset can not support save by step, please not config save_step_interval and save_step_path in your yaml"
                  )

      if platform_name != "LINUX":
          for dataset in envs["dataset"]:
              name = ".".join(["dataset", dataset["name"], "type"])
              global_envs[name] = "DataLoader"
      if platform_name == "LINUX" and six.PY3:
          print("QueueDataset can not support PY3, change to DataLoader")
          for dataset in envs["dataset"]:
              name = ".".join(["dataset", dataset["name"], "type"])
              global_envs[name] = "DataLoader"
  def get_global_env(env_name, default_value=None, namespace=None):
      """
      Look up a flattened config value in ``global_envs``.

      Args:
          env_name (str): the dotted key, or its tail when ``namespace`` is
              given.
          default_value: returned when the key is absent.
          namespace (str, optional): prefix joined to ``env_name`` with ".".

      Returns:
          The stored value, or ``default_value`` if the key is missing.
      """
      if namespace is None:
          full_name = env_name
      else:
          full_name = ".".join([namespace, env_name])
      return global_envs.get(full_name, default_value)
  def get_global_envs():
      """Return the module-level ``global_envs`` dict itself (not a copy)."""
      envs_registry = global_envs
      return envs_registry
  def paddlerec_adapter(path):
      """Resolve a ``paddlerec.``-prefixed dotted path to a filesystem path.

      For example, ``"paddlerec.models.rank.dnn"`` becomes
      ``os.path.join($PACKAGE_BASE, "models/rank/dnn")``. Any path without
      the prefix is returned unchanged.

      Args:
          path (str): a dotted package path or an ordinary path.

      Returns:
          str: the resolved path.

      Raises:
          TypeError: if the path is prefixed but PACKAGE_BASE is unset
              (``os.path.join`` receives ``None``).
      """
      prefix = "paddlerec."
      if not path.startswith(prefix):
          return path
      package = get_runtime_environ("PACKAGE_BASE")
      # Strip only the leading prefix: split("paddlerec.")[1] would cut the
      # path short whenever "paddlerec." occurs again later in it.
      l_p = path[len(prefix):].replace(".", "/")
      return os.path.join(package, l_p)
  def os_path_adapter(value):
      """Normalise path separators in ``value`` for the current platform.

      On Windows every "/" becomes "\\"; on any other platform every "\\"
      becomes "/".
      """
      if get_platform() == "WINDOWS":
          old_sep, new_sep = "/", "\\"
      else:
          old_sep, new_sep = "\\", "/"
      return value.replace(old_sep, new_sep)
  def workspace_adapter(value):
      """Expand ``{workspace}`` in ``value`` using the configured workspace.

      The workspace comes from ``global_envs["workspace"]`` (``None`` if it is
      not set) and is handed to ``workspace_adapter_by_specific``.
      """
      return workspace_adapter_by_specific(value, global_envs.get("workspace"))
  def workspace_adapter_by_specific(value, workspace):
      """Replace every ``{workspace}`` placeholder in ``value``.

      ``workspace`` is first passed through ``paddlerec_adapter``, so a
      ``paddlerec.``-prefixed dotted path is turned into a real directory.

      Args:
          value (str): text that may contain ``{workspace}``.
          workspace (str or None): the workspace path. It is only consulted
              when ``value`` actually contains the placeholder.

      Returns:
          str: ``value`` with the placeholder expanded, or ``value`` unchanged
          when there is no placeholder.
      """
      # Nothing to substitute: skip resolving the workspace. Configs without
      # a "workspace" key (workspace is None) then pass through instead of
      # failing inside paddlerec_adapter on every string value.
      if "{workspace}" not in value:
          return value
      workspace = paddlerec_adapter(workspace)
      return value.replace("{workspace}", workspace)
  def reader_adapter():
      """On Windows, set the type of each ``global_envs["dataset"]`` item to DataLoader.

      Does nothing on other platforms, or when ``global_envs`` has no plain
      ``"dataset"`` entry.
      """
      if get_platform() != "WINDOWS":
          return
      # set_global_envs stores datasets under flattened keys
      # ("dataset.<name>.type"), so a plain "dataset" entry may be absent.
      # Guard against None instead of crashing while iterating it.
      datasets = global_envs.get("dataset")
      if datasets is None:
          return
      for dataset in datasets:
          dataset["type"] = "DataLoader"
  def pretty_print_envs(envs, header=None):
      """Render ``envs`` as a two-column, fixed-width text table.

      Args:
          envs (dict): key/value pairs, printed one row each.
          header (sequence, optional): ``(key_title, value_title)``. Defaults
              to ``("paddlerec Global Envs", "Value")``.

      Returns:
          str: the table framed by "=" borders, with a leading and a trailing
          newline.
      """
      spacing = 5
      max_v = 50
      # The key column is at least 45 wide and grows to fit the longest key.
      max_k = max([45] + [len(str(k)) for k in envs])
      h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v)
      l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v)
      length = max_k + max_v + spacing
      border = "=" * length
      line = "-" * length

      if header:
          key_title, value_title = header[0], header[1]
      else:
          key_title, value_title = "paddlerec Global Envs", "Value"

      rows = [border + "\n", h_format.format(key_title, value_title),
              line + "\n"]
      for k, v in envs.items():
          str_v = str(v)
          # Long values of any type (not only str) keep just their tail, so
          # the table stays aligned: "... " + tail is exactly max_v chars.
          if len(str_v) >= max_v:
              str_v = "... " + str_v[-(max_v - 4):]
          rows.append(l_format.format(str(k), " " * spacing, str_v))
      rows.append(border)
      return "\n{}\n".format("".join(rows))
  def lazy_instance_by_package(package, class_name):
      """Import dotted module ``package`` and return its attribute ``class_name``.

      Any failure is reported (traceback plus a "Catch Exception" line) and
      ``None`` is returned instead of raising.

      Args:
          package (str): dotted module path, e.g. "a.b.c".
          class_name (str): attribute to fetch from that module.

      Returns:
          The attribute, or ``None`` if the import or lookup failed.
      """
      try:
          # A non-empty fromlist makes __import__ return the leaf module
          # rather than the top-level package.
          module = __import__(package, globals(), locals(), package.split("."))
          return getattr(module, class_name)
      except Exception as err:
          traceback.print_exc()
          print('Catch Exception:%s' % str(err))
          return None
  def lazy_instance_by_fliename(abs, class_name):
      """Import the module file ``abs`` and return its attribute ``class_name``.

      The file's directory is added to ``sys.path`` once, and the module is
      imported by its basename with the extension stripped. Any failure is
      reported (traceback plus a "Catch Exception" line) and ``None`` is
      returned.

      NOTE(review): a module already in ``sys.modules`` under the same
      basename is returned as-is, even if it came from another directory.

      Args:
          abs (str): path to a Python source file. The parameter shadows the
              ``abs`` builtin; the name is kept for keyword callers.
          class_name (str): attribute to fetch from the imported module.

      Returns:
          The attribute, or ``None`` if the import or lookup failed.
      """
      try:
          dirname = os.path.dirname(abs)
          # Add the directory only once; appending on every call would grow
          # sys.path (and import search time) without bound.
          if dirname not in sys.path:
              sys.path.append(dirname)
          package = os.path.splitext(os.path.basename(abs))[0]
          model_package = __import__(package,
                                     globals(), locals(), package.split("."))
          return getattr(model_package, class_name)
      except Exception as err:
          traceback.print_exc()
          print('Catch Exception:%s' % str(err))
          return None
  def get_platform():
      """Return "LINUX", "DARWIN" or "WINDOWS" for the running OS, else None."""
      import platform
      known = {"Linux": "LINUX", "Darwin": "DARWIN", "Windows": "WINDOWS"}
      system = platform.system()
      if system in known:
          return known[system]
      # Fallback for unusual systems: scan the full platform string. It is
      # not the primary check because on Python 3.8+ it reports macOS as
      # "macOS-...", which contains none of these markers.
      plats = platform.platform()
      if 'Linux' in plats:
          return "LINUX"
      if 'Darwin' in plats:
          return "DARWIN"
      if 'Windows' in plats:
          return "WINDOWS"
      return None
  def find_free_port():
      """Return a TCP port number the OS reports as currently free.

      Binds a throw-away IPv4 stream socket to port 0 on all interfaces,
      reads back the port the OS picked, then closes the socket.
      """
      with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
          sock.bind(('', 0))
          port = sock.getsockname()[1]
      return port
  def load_yaml(config):
      """Parse the YAML file at ``config`` and return the resulting object.

      Uses ``yaml.FullLoader`` when the installed PyYAML provides it
      (PyYAML >= 5.1); older releases fall back to plain ``yaml.load``.
      On Python 3 the file is read as UTF-8.

      NOTE(review): PyYAML releases before 5.4 had code-execution issues in
      FullLoader. Only load trusted config files. ``yaml.SafeLoader`` would
      be safer if configs never need Python-specific tags -- confirm first.

      Args:
          config (str): path to the YAML file.

      Returns:
          The parsed document (presumably a dict for PaddleRec configs).

      Raises:
          ValueError: if ``config`` is not an existing regular file.
      """
      if not os.path.isfile(config):
          raise ValueError("config {} can not be supported".format(config))
      # Feature-test rather than parse yaml.__version__: int() fails on
      # version strings such as "6.0b1", and FullLoader appeared in 5.1.
      loader = getattr(yaml, "FullLoader", None)
      # Python 2's open() takes no encoding argument.
      open_kwargs = {} if six.PY2 else {"encoding": "utf-8"}
      with open(config, 'r', **open_kwargs) as rb:
          text = rb.read()
      if loader is not None:
          return yaml.load(text, Loader=loader)
      return yaml.load(text)