소스 검색

dssm train

丁云鹏 4 달 전
부모
커밋
e22c45eaba
1개의 변경된 파일에 33개의 추가 그리고 27개의 삭제(줄)
  1. 33 27
      recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

+ 33 - 27
recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

@@ -26,6 +26,7 @@ import paddle.distributed.fleet as fleet
 import paddle.distributed.fleet.base.role_maker as role_maker
 import paddle
 from paddle.base.executor import FetchHandler
+import queue
 import threading
 
 import warnings
@@ -54,9 +55,10 @@ class InferenceFetchHandler(FetchHandler):
     def __init__(self, var_dict, output_file, batch_size=1000):
         super().__init__(var_dict=var_dict, period_secs=1)
         self.output_file = output_file
-        self.batch_size = batch_size
-        self.current_batch = []
-        self.total_samples = 0
+        self.result_queue = queue.Queue()
+        self.writer_thread = threading.Thread(target=self._writer)
+        self.writer_thread.daemon = True  # 设置为守护线程
+        self.writer_thread.start()
         
         # 创建输出目录(如果不存在)
         output_dir = os.path.dirname(output_file)
@@ -69,7 +71,9 @@ class InferenceFetchHandler(FetchHandler):
     def handler(self, fetch_vars):
         """处理每批次的推理结果"""
         result_dict = {}
-        sys.stdout.write("\n")
+
+        super().handler(res_dict=fetch_vars)
+
         for key in fetch_vars:
             # 转换数据类型
             if type(fetch_vars[key]) is np.ndarray:
@@ -77,31 +81,33 @@ class InferenceFetchHandler(FetchHandler):
             else:
                 result = var_value
             result_dict[key] = result
-            
-        for key in fetch_vars:
-            if type(fetch_vars[key]) is np.ndarray:
-                sys.stdout.write(f"{key}[0]: {fetch_vars[key][0]} ")
-
-        self.current_batch.append(result_dict)
-        
-        # # 当累积足够的结果时,写入文件
-        if len(self.current_batch) >= self.batch_size:
-            self._write_batch()
+        self.result_queue.put(result_dict)  # 将结果放入队列
+    
+    def _writer(self):
+        batch = []
+        while True:
+            try:
+                result_dict = self.result_queue.get(timeout=1)  # 非阻塞获取
+                batch.append(result_dict)
+                if len(batch) >= self.batch_size:
+                    with open(self.output_file, 'a') as f:
+                        for result in batch:
+                            f.write(json.dumps(result) + '\n')
+                    batch = []
+            except queue.Empty:
+                pass
     
-    def _write_batch(self):
-        """将批次结果写入文件"""
+    def _write_batch(self, batch):
         with open(self.output_file, 'a') as f:
-            for result in self.current_batch:
+            for result in batch:
                 f.write(json.dumps(result) + '\n')
-        self.current_batch = []
-    
-    def finish(self):
-        logger.info("InferenceFetchHandler finish")
-        """确保所有剩余结果都被保存"""
-        if self.current_batch:
-            self._write_batch()
-            logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
-        self.done_event.set()
+
+    def flush(self):
+        """确保所有结果都被写入文件"""
+        # 等待队列中剩余的结果被处理
+        self.result_queue.join()
+        # 写入最后一批结果
+        self._write_batch(self.result_queue.queue)
     
 
 
@@ -314,7 +320,7 @@ class Main(object):
             print_period=print_step,
             debug=debug,
             fetch_handler=fetch_handler)
-        fetch_handler.finish()
+        fetch_handler.flush()
 
 
     def heter_train_loop(self, epoch):