often пре 5 месеци
родитељ
комит
7ccf080886
1 измењен фајл са 76 додато и 38 уклоњено
  1. 76 38
      recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

+ 76 - 38
recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

@@ -46,6 +46,58 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
import json


class InferenceFetchHandler(object):
    """Buffers per-batch inference fetch results and persists them as JSON lines.

    Each call to ``handler`` converts the fetched variables to
    JSON-serializable values and appends them to an in-memory buffer; once
    ``batch_size`` batches have accumulated they are flushed to
    ``output_file``. Call ``finish`` after inference completes to flush any
    remainder.
    """

    def __init__(self, output_file, batch_size=1000):
        """
        Args:
            output_file: path of the JSON-lines file results are written to.
            batch_size: number of buffered batches that triggers a flush.
        """
        self.output_file = output_file
        self.batch_size = batch_size
        self.current_batch = []
        self.total_samples = 0
        # Same module-level logger name the rest of the file uses.
        self._logger = logging.getLogger(__name__)

        # Create the output directory if needed. dirname() returns '' for a
        # bare filename, and os.makedirs('') raises — so guard on truthiness.
        # exist_ok avoids a race when several workers start simultaneously.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Create (or truncate) the output file.
        with open(self.output_file, 'w') as f:
            f.write('')

    def handler(self, fetch_vars):
        """Buffer one batch of fetched variables, flushing when full.

        Args:
            fetch_vars: mapping of variable name -> np.ndarray or plain value.
        """
        result_dict = {}
        for var_name, var_value in fetch_vars.items():
            # ndarrays are not JSON-serializable; convert to nested lists.
            if isinstance(var_value, np.ndarray):
                result_dict[var_name] = var_value.tolist()
            else:
                result_dict[var_name] = var_value

        self.current_batch.append(result_dict)
        self.total_samples += self._sample_count(result_dict)

        # Flush once enough results have accumulated.
        if len(self.current_batch) >= self.batch_size:
            self._write_batch()
            # Lazy %-style args: no string formatting unless the record is emitted.
            self._logger.info("Saved %s samples to %s",
                              self.total_samples, self.output_file)

    @staticmethod
    def _sample_count(result_dict):
        """Best-effort sample count for one batch: length of the first fetched
        value, 1 for an unsized scalar, 0 for an empty fetch dict."""
        first = next(iter(result_dict.values()), None)
        if first is None:
            return 0
        try:
            return len(first)
        except TypeError:
            # Scalar fetch value (e.g. a plain float) — count it as one sample.
            return 1

    def _write_batch(self):
        """Append the buffered results to the output file, one JSON object per line."""
        with open(self.output_file, 'a') as f:
            for result in self.current_batch:
                f.write(json.dumps(result) + '\n')
        self.current_batch = []

    def finish(self):
        """Flush any remaining buffered results; call once after inference ends."""
        if self.current_batch:
            self._write_batch()
            self._logger.info("Final save: total %s samples saved to %s",
                              self.total_samples, self.output_file)
+
 def parse_args():
     parser = argparse.ArgumentParser("PaddleRec train script")
     parser.add_argument("-o", "--opt", nargs='*', type=str)
@@ -186,44 +238,21 @@ class Main(object):
         if reader_type == "InmemoryDataset":
             self.reader.load_into_memory()
 
-        for epoch in range(epochs):
-            fleet.load_inference_model(
-                init_model_path,
-                mode=int(model_mode))
-            epoch_start_time = time.time()
-
-            if sync_mode == "heter":
-                self.heter_train_loop(epoch)
-            elif reader_type == "QueueDataset":
-                self.dataset_train_loop(epoch)
-            elif reader_type == "InmemoryDataset":
-                self.dataset_train_loop(epoch)
-
-            epoch_time = time.time() - epoch_start_time
-
-            if use_auc is True:
-                global_auc = get_global_auc(paddle.static.global_scope(),
-                                            self.model.stat_pos.name,
-                                            self.model.stat_neg.name)
-                self.train_result_dict["auc"].append(global_auc)
-                set_zero(self.model.stat_pos.name,
-                         paddle.static.global_scope())
-                set_zero(self.model.stat_neg.name,
-                         paddle.static.global_scope())
-                set_zero(self.model.batch_stat_pos.name,
-                         paddle.static.global_scope())
-                set_zero(self.model.batch_stat_neg.name,
-                         paddle.static.global_scope())
-                logger.info(
-                    "Epoch: {}, using time: {} second, ips: {}/sec. auc: {}".
-                    format(epoch, epoch_time, self.count_method,
-                           global_auc))
-            else:
-                logger.info(
-                    "Epoch: {}, using time {} second, ips  {}/sec.".format(
-                        epoch, epoch_time, self.count_method))
+        fleet.load_inference_model(
+            init_model_path,
+            mode=int(model_mode))
+        epoch_start_time = time.time()
+
+        if sync_mode == "heter":
+            self.heter_train_loop(epoch)
+        elif reader_type == "QueueDataset":
+            self.dataset_train_loop(epoch)
+        elif reader_type == "InmemoryDataset":
+            self.dataset_train_loop(epoch)
 
-            model_dir = "{}/{}".format(save_model_path, epoch)
+        epoch_time = time.time() - epoch_start_time
+        logger.info(
+            "using time {} second, ips  {}/sec.".format(epoch_time, self.count_method))
 
         if reader_type == "InmemoryDataset":
             self.reader.release_memory()
@@ -259,6 +288,13 @@ class Main(object):
                 "dump_fields_path": dump_fields_path,
                 "dump_fields": config.get("runner.dump_fields")
             })
+            
+        # 设置输出文件路径
+        output_dir = config.get("runner.inference_output_dir", "inference_results")
+        output_file = os.path.join(output_dir, f"epoch_{epoch}_results.jsonl")
+        
+        # 创建处理器实例
+        fetch_handler = InferenceFetchHandler(output_file)
         print(paddle.static.default_main_program()._fleet_opt)
         results = self.exe.infer_from_dataset(
             program=paddle.static.default_main_program(),
@@ -266,8 +302,10 @@ class Main(object):
             fetch_list=fetch_vars,
             fetch_info=fetch_info,
             print_period=print_step,
-            debug=debug)
+            debug=debug,
+            fetch_handler=fetch_handler)
         print("results {}".format(results))
+        fetch_handler.finish()
 
 
     def heter_train_loop(self, epoch):