Ver Fonte

dssm train

丁云鹏 há 4 meses atrás
pai
commit
d84b73a90c

+ 10 - 8
recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

@@ -67,18 +67,21 @@ class InferenceFetchHandler(FetchHandler):
             f.write('')
     
     def handler(self, fetch_vars):
-        super().handler(res_dict=fetch_vars)
         """处理每批次的推理结果"""
         result_dict = {}
-        # logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
-        for var_name, var_value in fetch_vars.items():
+        sys.stdout.write("\n")
+        for key in fetch_vars:
             # 转换数据类型
-            if isinstance(var_value, np.ndarray):
-                result = var_value.tolist()
+            if type(fetch_vars[key]) is np.ndarray:
+                result = res_dict[key][0]
             else:
                 result = var_value
-            result_dict[var_name] = result
-        
+            result_dict[key] = result
+            
+        for key in fetch_vars:
+            if type(fetch_vars[key]) is np.ndarray:
+                sys.stdout.write(f"{key}[0]: {fetch_vars[key][0]} ")
+
         self.current_batch.append(result_dict)
         
         # # 当累积足够的结果时,写入文件
@@ -301,7 +304,6 @@ class Main(object):
         
         # 创建处理器实例
         fetch_handler = InferenceFetchHandler(var_dict = self.metrics, output_file = output_file)
-        # fetch_handler.set_var_dict(self.metrics)
 
         print(paddle.static.default_main_program()._fleet_opt)
         self.exe.infer_from_dataset(