丁云鹏 4 mēneši atpakaļ
vecāks
revīzija
ad2b3f0098

+ 16 - 16
recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

@@ -69,23 +69,23 @@ class InferenceFetchHandler(FetchHandler):
     def handler(self, fetch_vars):
         super().handler(res_dict=fetch_vars)
         """处理每批次的推理结果"""
-        # result_dict = {}
-        # logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
-        # for var_name, var_value in fetch_vars.items():
-        #     # 转换数据类型
-        #     if isinstance(var_value, np.ndarray):
-        #         result = var_value.tolist()
-        #     else:
-        #         result = var_value
-        #     result_dict[var_name] = result
+        result_dict = {}
+        logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
+        for var_name, var_value in fetch_vars.items():
+            # 转换数据类型
+            if isinstance(var_value, np.ndarray):
+                result = var_value.tolist()
+            else:
+                result = var_value
+            result_dict[var_name] = result
         
-        # self.current_batch.append(result_dict)
-        # self.total_samples += len(result_dict.get(list(result_dict.keys())[0], []))
+        self.current_batch.append(result_dict)
+        self.total_samples += len(result_dict.get(list(result_dict.keys())[0], []))
         
-        # # 当累积足够的结果时,写入文件
-        # if len(self.current_batch) >= self.batch_size:
-        #     self._write_batch()
-        #     logger.info(f"Saved {self.total_samples} samples to {self.output_file}")
+        # 当累积足够的结果时,写入文件
+        if len(self.current_batch) >= self.batch_size:
+            self._write_batch()
+            logger.info(f"Saved {self.total_samples} samples to {self.output_file}")
     
     def _write_batch(self):
         """将批次结果写入文件"""
@@ -302,7 +302,7 @@ class Main(object):
         output_file = os.path.join(output_dir, f"epoch_{epoch}_results.jsonl")
         
         # 创建处理器实例
-        fetch_handler = InferenceFetchHandler(var_dict = self.metrics)
+        fetch_handler = InferenceFetchHandler(var_dict = self.metrics, output_file = output_file)
         # fetch_handler.set_var_dict(self.metrics)
 
         print(paddle.static.default_main_program()._fleet_opt)