| 
					
				 | 
			
			
				@@ -26,6 +26,7 @@ import paddle.distributed.fleet as fleet 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import paddle.distributed.fleet.base.role_maker as role_maker 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import paddle 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from paddle.base.executor import FetchHandler 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import queue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import threading 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import warnings 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -54,9 +55,10 @@ class InferenceFetchHandler(FetchHandler): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def __init__(self, var_dict, output_file, batch_size=1000): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         super().__init__(var_dict=var_dict, period_secs=1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.output_file = output_file 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.batch_size = batch_size 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.current_batch = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.total_samples = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.result_queue = queue.Queue() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.writer_thread = threading.Thread(target=self._writer) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.writer_thread.daemon = True  # 设置为守护线程 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.writer_thread.start() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				          
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # 创建输出目录(如果不存在) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         output_dir = os.path.dirname(output_file) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -69,7 +71,9 @@ class InferenceFetchHandler(FetchHandler): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def handler(self, fetch_vars): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """处理每批次的推理结果""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         result_dict = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        sys.stdout.write("\n") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        super().handler(res_dict=fetch_vars) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for key in fetch_vars: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # 转换数据类型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if type(fetch_vars[key]) is np.ndarray: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -77,31 +81,33 @@ class InferenceFetchHandler(FetchHandler): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 result = var_value 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             result_dict[key] = result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-             
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        for key in fetch_vars: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if type(fetch_vars[key]) is np.ndarray: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                sys.stdout.write(f"{key}[0]: {fetch_vars[key][0]} ") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.current_batch.append(result_dict) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # # 当累积足够的结果时,写入文件 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if len(self.current_batch) >= self.batch_size: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            self._write_batch() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.result_queue.put(result_dict)  # 将结果放入队列 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def _writer(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        batch = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        while True: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                result_dict = self.result_queue.get(timeout=1)  # 非阻塞获取 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                batch.append(result_dict) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if len(batch) >= self.batch_size: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    with open(self.output_file, 'a') as f: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        for result in batch: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            f.write(json.dumps(result) + '\n') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    batch = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            except queue.Empty: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                pass 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def _write_batch(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """将批次结果写入文件""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def _write_batch(self, batch): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         with open(self.output_file, 'a') as f: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            for result in self.current_batch: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for result in batch: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 f.write(json.dumps(result) + '\n') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.current_batch = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def finish(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        logger.info("InferenceFetchHandler finish") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """确保所有剩余结果都被保存""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if self.current_batch: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            self._write_batch() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.done_event.set() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def flush(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        """确保所有结果都被写入文件""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 等待队列中剩余的结果被处理 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.result_queue.join() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 写入最后一批结果 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self._write_batch(self.result_queue.queue) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -314,7 +320,7 @@ class Main(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             print_period=print_step, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             debug=debug, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             fetch_handler=fetch_handler) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        fetch_handler.finish() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fetch_handler.flush() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def heter_train_loop(self, epoch): 
			 |