|
@@ -74,7 +74,7 @@ class InferenceFetchHandler(FetchHandler):
|
|
|
def handler(self, fetch_vars):
|
|
|
"""处理每批次的推理结果"""
|
|
|
result_dict = {}
|
|
|
- print("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
|
|
|
+ logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
|
|
|
for var_name, var_value in fetch_vars.items():
|
|
|
# 转换数据类型
|
|
|
if isinstance(var_value, np.ndarray):
|
|
@@ -99,6 +99,7 @@ class InferenceFetchHandler(FetchHandler):
|
|
|
self.current_batch = []
|
|
|
|
|
|
def finish(self):
|
|
|
"""确保所有剩余结果都被保存"""
|
|
|
+ logger.info("InferenceFetchHandler finish")
|
|
|
if self.current_batch:
|
|
|
self._write_batch()
|
|
@@ -233,17 +234,15 @@ class Main(object):
|
|
|
init_model_path = config.get("runner.infer_load_path")
|
|
|
model_mode = config.get("runner.model_mode", 0)
|
|
|
client = HangZhouOSSClient("art-recommend")
|
|
|
- client.get_object_to_file("lqc/64.tar.gz", "64.tar.gz")
|
|
|
- compress.uncompress_tar("64.tar.gz", init_model_path)
|
|
|
+ oss_object_name = self.config.get("runner.oss_object_name", "dyp/model.tar.gz")
|
|
|
+ client.get_object_to_file(oss_object_name, "model.tar.gz")
|
|
|
+ compress.uncompress_tar("model.tar.gz", init_model_path)
|
|
|
assert os.path.exists(init_model_path)
|
|
|
|
|
|
#if fleet.is_first_worker():
|
|
|
#fleet.load_inference_model(init_model_path, mode=int(model_mode))
|
|
|
#fleet.barrier_worker()
|
|
|
|
|
|
- save_model_path = self.config.get("runner.model_save_path")
|
|
|
- if save_model_path and (not os.path.exists(save_model_path)):
|
|
|
- os.makedirs(save_model_path)
|
|
|
|
|
|
reader_type = self.config.get("runner.reader_type", "QueueDataset")
|
|
|
epochs = int(self.config.get("runner.epochs"))
|