丁云鹏 5 mesi fa
parent
commit
83f761f0b8

+ 9 - 4
recommend-model-produce/src/main/python/tools/static_ps_trainer_v2.py

@@ -248,11 +248,16 @@ class Main(object):
 
             # trans to new format
             # {"model_filename":"", "params_filename":""} fleet每个参数一个文件,需要同这种方式加载
-            paddle.static.load_inference_model(os.path.join(model_dir, "dnn_plugin"), model_filename=None, params_filename=None})
+            paddle.static.load_inference_model(
+                os.path.join(model_dir, "dnn_plugin"),
+                self.exe,
+                model_filename=None,
+                params_filename=None)
+            
             paddle.static.save_inference_model(
-                    os.path.join(model_dir, "dnn_plugin_new"),
-                    [feed.name for feed in self.inference_feed_var],
-                    [self.inference_target_var], self.exe)
+                os.path.join(model_dir, "dnn_plugin_new"),
+                [feed.name for feed in self.inference_feed_var],
+                [self.inference_target_var], self.exe)
 
             compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
             client = HangZhouOSSClient("art-recommend")