| 
					
				 | 
			
			
				@@ -46,6 +46,58 @@ logging.basicConfig( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 logger = logging.getLogger(__name__) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# 创建推理结果处理器类 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class InferenceFetchHandler(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __init__(self, output_file, batch_size=1000): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.output_file = output_file 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.batch_size = batch_size 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.current_batch = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.total_samples = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 创建输出目录(如果不存在) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        output_dir = os.path.dirname(output_file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if not os.path.exists(output_dir): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            os.makedirs(output_dir) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 创建或清空输出文件 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        with open(self.output_file, 'w') as f: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            f.write('') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def handler(self, fetch_vars): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        """处理每批次的推理结果""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        result_dict = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for var_name, var_value in fetch_vars.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # 转换数据类型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if isinstance(var_value, np.ndarray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                result = var_value.tolist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                result = var_value 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            result_dict[var_name] = result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.current_batch.append(result_dict) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.total_samples += len(result_dict.get(list(result_dict.keys())[0], [])) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 当累积足够的结果时,写入文件 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if len(self.current_batch) >= self.batch_size: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self._write_batch() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.info(f"Saved {self.total_samples} samples to {self.output_file}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def _write_batch(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        """将批次结果写入文件""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        with open(self.output_file, 'a') as f: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for result in self.current_batch: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                f.write(json.dumps(result) + '\n') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.current_batch = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def finish(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        """确保所有剩余结果都被保存""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if self.current_batch: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self._write_batch() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def parse_args(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     parser = argparse.ArgumentParser("PaddleRec train script") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     parser.add_argument("-o", "--opt", nargs='*', type=str) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -186,44 +238,21 @@ class Main(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if reader_type == "InmemoryDataset": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             self.reader.load_into_memory() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        for epoch in range(epochs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            fleet.load_inference_model( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                init_model_path, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                mode=int(model_mode)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            epoch_start_time = time.time() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if sync_mode == "heter": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                self.heter_train_loop(epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            elif reader_type == "QueueDataset": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                self.dataset_train_loop(epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            elif reader_type == "InmemoryDataset": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                self.dataset_train_loop(epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            epoch_time = time.time() - epoch_start_time 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if use_auc is True: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                global_auc = get_global_auc(paddle.static.global_scope(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                                            self.model.stat_pos.name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                                            self.model.stat_neg.name) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                self.train_result_dict["auc"].append(global_auc) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                set_zero(self.model.stat_pos.name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                         paddle.static.global_scope()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                set_zero(self.model.stat_neg.name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                         paddle.static.global_scope()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                set_zero(self.model.batch_stat_pos.name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                         paddle.static.global_scope()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                set_zero(self.model.batch_stat_neg.name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                         paddle.static.global_scope()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                logger.info( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "Epoch: {}, using time: {} second, ips: {}/sec. auc: {}". 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    format(epoch, epoch_time, self.count_method, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                           global_auc)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                logger.info( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "Epoch: {}, using time {} second, ips  {}/sec.".format( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        epoch, epoch_time, self.count_method)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fleet.load_inference_model( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            init_model_path, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            mode=int(model_mode)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        epoch_start_time = time.time() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if sync_mode == "heter": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.heter_train_loop(epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif reader_type == "QueueDataset": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.dataset_train_loop(epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif reader_type == "InmemoryDataset": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.dataset_train_loop(epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            model_dir = "{}/{}".format(save_model_path, epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        epoch_time = time.time() - epoch_start_time 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.info( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "using time {} second, ips  {}/sec.".format(epoch_time, self.count_method)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if reader_type == "InmemoryDataset": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             self.reader.release_memory() 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -259,6 +288,13 @@ class Main(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 "dump_fields_path": dump_fields_path, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 "dump_fields": config.get("runner.dump_fields") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             }) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+             
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 设置输出文件路径 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        output_dir = config.get("runner.inference_output_dir", "inference_results") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        output_file = os.path.join(output_dir, f"epoch_{epoch}_results.jsonl") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 创建处理器实例 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fetch_handler = InferenceFetchHandler(output_file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         print(paddle.static.default_main_program()._fleet_opt) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         results = self.exe.infer_from_dataset( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             program=paddle.static.default_main_program(), 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -266,8 +302,10 @@ class Main(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             fetch_list=fetch_vars, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             fetch_info=fetch_info, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             print_period=print_step, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            debug=debug) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            debug=debug, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            fetch_handler=fetch_handler) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         print("results {}".format(results)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fetch_handler.finish() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def heter_train_loop(self, epoch): 
			 |