|
@@ -25,6 +25,7 @@ import sys
|
|
|
import paddle.distributed.fleet as fleet
|
|
|
import paddle.distributed.fleet.base.role_maker as role_maker
|
|
|
import paddle
|
|
|
+import paddle.base.core as core
|
|
|
|
|
|
import warnings
|
|
|
import logging
|
|
@@ -335,7 +336,20 @@ if __name__ == "__main__":
|
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
|
+ read_env_flags = [
|
|
|
+ key[len(flag_prefix) :]
|
|
|
+ for key in core.globals().keys()
|
|
|
+ if key.startswith(flag_prefix)
|
|
|
+ ]
|
|
|
+
|
|
|
+ def remove_flag_if_exists(name):
|
|
|
+ if name in read_env_flags:
|
|
|
+ read_env_flags.remove(name)
|
|
|
+
|
|
|
+ core.init_gflags(["--tryfromenv=" + ",".join(read_env_flags)] + ",max_body_size")
|
|
|
print('global_flags: ' + str(list(paddle.base.framework._global_flags().keys())))
|
|
|
+
|
|
|
+
|
|
|
config = parse_args()
|
|
|
os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
|
|
|
benchmark_main = Main(config)
|