丁云鹏 committed 4 months ago
parent commit 9fe4c2f208

+ 5 - 6
recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

@@ -74,7 +74,7 @@ class InferenceFetchHandler(FetchHandler):
     def handler(self, fetch_vars):
         """处理每批次的推理结果"""
         result_dict = {}
-        print("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
+        logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
         for var_name, var_value in fetch_vars.items():
             # Convert the data type
             if isinstance(var_value, np.ndarray):
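
The switch from print to logger.info assumes a module-level logger is already configured in static_ps_infer_v2.py. A minimal sketch of such a setup, with an illustrative format string and level rather than the script's confirmed configuration:

    import logging

    # module-level logger used by the logger.info(...) calls above (setup is illustrative)
    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
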
@@ -99,6 +99,7 @@ class InferenceFetchHandler(FetchHandler):
         self.current_batch = []
     
     def finish(self):
+        logger.info("InferenceFetchHandler finish")
         """确保所有剩余结果都被保存"""
         if self.current_batch:
             self._write_batch()
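
For context on why finish() now logs: the executor calls handler() on the fetched variables while the dataset is consumed, and finish() must be invoked afterwards to flush whatever is still buffered in current_batch. A hedged usage sketch; the constructor arguments and output path are hypothetical, only the fetch_handler hook of infer_from_dataset is standard Paddle API:

    # hypothetical constructor arguments, shown only to illustrate the call sequence
    handler = InferenceFetchHandler(var_dict={v.name: v for v in fetch_vars},
                                    output_path="./infer_result.jsonl")
    exe.infer_from_dataset(program=paddle.static.default_main_program(),
                           dataset=dataset,
                           fetch_handler=handler)
    handler.finish()  # flush the tail of the last partial batch and log completion
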
@@ -233,17 +234,15 @@ class Main(object):
         init_model_path = config.get("runner.infer_load_path")
         model_mode = config.get("runner.model_mode", 0)
         client = HangZhouOSSClient("art-recommend")
-        client.get_object_to_file("lqc/64.tar.gz", "64.tar.gz")
-        compress.uncompress_tar("64.tar.gz", init_model_path)
+        oss_object_name = self.config.get("runner.oss_object_name", "dyp/model.tar.gz")
+        client.get_object_to_file(oss_object_name, "model.tar.gz")
+        compress.uncompress_tar("model.tar.gz", init_model_path)
         assert os.path.exists(init_model_path)
 
         #if fleet.is_first_worker():
         #fleet.load_inference_model(init_model_path, mode=int(model_mode))
         #fleet.barrier_worker()
 
-        save_model_path = self.config.get("runner.model_save_path")
-        if save_model_path and (not os.path.exists(save_model_path)):
-            os.makedirs(save_model_path)
 
         reader_type = self.config.get("runner.reader_type", "QueueDataset")
         epochs = int(self.config.get("runner.epochs"))
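
For reference, the runner keys this inference entry point now reads, gathered into a sketch with illustrative values; only the runner.oss_object_name default of dyp/model.tar.gz comes from the diff itself, the rest are placeholders:

    # illustrative values; the real settings live in the project's runner config
    runner_config = {
        "runner.infer_load_path": "./infer_model",     # directory the tarball is unpacked into
        "runner.oss_object_name": "dyp/model.tar.gz",  # OSS key, same default as the trainer below
        "runner.model_mode": 0,
        "runner.reader_type": "QueueDataset",
        "runner.epochs": 1,
    }
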

+ 1 - 1
recommend-model-produce/src/main/python/tools/static_ps_trainer_v2.py

@@ -244,7 +244,7 @@ class Main(object):
 
         if fleet.is_first_worker():
             model_dir = "{}/{}".format(save_model_path, epochs - 1)
-            oss_object_name = self.config.get("runner.oss_object_name", "dyp/dnn_plugin_new.tar.gz")
+            oss_object_name = self.config.get("runner.oss_object_name", "dyp/model.tar.gz")
             # convert to the new format
             # {"model_filename":"", "params_filename":""}: fleet saves one file per parameter, so the model has to be loaded this way
             program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
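
The hunk is cut off here; the surrounding trainer code presumably packs model_dir and uploads it under oss_object_name, which is the object the inference script above downloads. A rough sketch of that flow; compress_tar and put_object_from_file are assumed helper names mirroring the download side, not confirmed repo API:

    # hypothetical upload counterpart; helper names are assumptions, not the repo's confirmed API
    model_dir = "{}/{}".format(save_model_path, epochs - 1)
    oss_object_name = self.config.get("runner.oss_object_name", "dyp/model.tar.gz")

    compress.compress_tar(model_dir, "model.tar.gz")              # pack the final-epoch model
    client = HangZhouOSSClient("art-recommend")
    client.put_object_from_file(oss_object_name, "model.tar.gz")  # same key the infer job reads
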