@@ -185,7 +185,7 @@ class Main(object):
for epoch in range(epochs):
fleet.load_inference_model(
- os.path.join(init_model_path, str(epoch)),
+ init_model_path,
mode=int(model_mode))
epoch_start_time = time.time()