丁云鹏 4 meses atrás
pai
commit
38baeb0333

+ 3 - 7
recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

@@ -50,13 +50,8 @@ logger = logging.getLogger(__name__)
 
 import json
 
-class InferenceFetchHandler(FetchHandler):
-    def __init__(self, var_dict, period_secs, output_file, batch_size=1000):
-        assert var_dict is not None
-        self.var_dict = var_dict
-        self.period_secs = period_secs
-
-
+class InferenceFetchHandler(object):
+    def __init__(self, output_file, batch_size=1000):
         self.output_file = output_file
         self.batch_size = batch_size
         self.current_batch = []
@@ -322,6 +317,7 @@ class Main(object):
         
         # 创建处理器实例
         fetch_handler = InferenceFetchHandler(output_file)
+        fetch_handler.set_var_dict(fetch_vars)
         print(paddle.static.default_main_program()._fleet_opt)
         self.exe.infer_from_dataset(
             program=paddle.static.default_main_program(),