@@ -26,6 +26,7 @@ import paddle.distributed.fleet as fleet
 import paddle.distributed.fleet.base.role_maker as role_maker
 import paddle
 from paddle.base.executor import FetchHandler
+import queue
 import threading

 import warnings
@@ -54,9 +55,10 @@ class InferenceFetchHandler(FetchHandler):
     def __init__(self, var_dict, output_file, batch_size=1000):
         super().__init__(var_dict=var_dict, period_secs=1)
         self.output_file = output_file
-        self.batch_size = batch_size
-        self.current_batch = []
-        self.total_samples = 0
+        self.batch_size = batch_size  # still needed: _writer() batches by this size
+        self.result_queue = queue.Queue()
+        self.writer_thread = threading.Thread(target=self._writer)
+        self.writer_thread.daemon = True  # daemon thread, so it never blocks interpreter exit
+        self.writer_thread.start()

         # Create the output directory if it does not exist
         output_dir = os.path.dirname(output_file)
@@ -69,7 +71,9 @@ class InferenceFetchHandler(FetchHandler):
     def handler(self, fetch_vars):
         """Process the inference results of each batch."""
         result_dict = {}
-        sys.stdout.write("\n")
+
+        super().handler(res_dict=fetch_vars)
+
         for key in fetch_vars:
             # Convert the data type
             if type(fetch_vars[key]) is np.ndarray:
@@ -77,31 +81,33 @@ class InferenceFetchHandler(FetchHandler):
             else:
                 result = var_value
             result_dict[key] = result
-
-        for key in fetch_vars:
-            if type(fetch_vars[key]) is np.ndarray:
-                sys.stdout.write(f"{key}[0]: {fetch_vars[key][0]} ")
-
-        self.current_batch.append(result_dict)
-
-        # # Write to the file once enough results have accumulated
-        if len(self.current_batch) >= self.batch_size:
-            self._write_batch()
+        self.result_queue.put(result_dict)  # hand the result to the writer thread
+
+    def _writer(self):
+        batch = []
+        while True:
+            try:
+                # Blocking get with a 1-second timeout (not a non-blocking get)
+                result_dict = self.result_queue.get(timeout=1)
+                batch.append(result_dict)
+            except queue.Empty:
+                pass
+            # Write when a full batch has accumulated, or when the queue has
+            # gone idle while a partial batch is still pending.
+            if len(batch) >= self.batch_size or (batch and self.result_queue.empty()):
+                self._write_batch(batch)
+                for _ in batch:
+                    # Mark items done only after the write, so flush() can rely on join()
+                    self.result_queue.task_done()
+                batch = []

-    def _write_batch(self):
-        """Write the batch results to the file."""
+    def _write_batch(self, batch):
+        """Write a batch of results to the file."""
         with open(self.output_file, 'a') as f:
-            for result in self.current_batch:
+            for result in batch:
                 f.write(json.dumps(result) + '\n')
-        self.current_batch = []
-
-    def finish(self):
-        logger.info("InferenceFetchHandler finish")
-        """Make sure all remaining results are saved."""
-        if self.current_batch:
-            self._write_batch()
-        logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
-        self.done_event.set()
+
+    def flush(self):
+        """Block until every queued result has been written to the file."""
+        # join() returns only once task_done() has been called for each queued
+        # item, which the writer thread does after the corresponding write.
+        self.result_queue.join()
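Note on the flush() fix above: queue.Queue.join() blocks until task_done() has been called once for every item that was put(), so the writer thread must mark items done only after they are actually on disk; the earlier draft never called task_done(), which would have hung flush() forever. A minimal standalone sketch of the join()/task_done() contract, independent of the handler code:

import queue
import threading

q = queue.Queue()

def worker():
    while True:
        item = q.get()              # blocks until an item arrives
        print(f"processed {item}")  # stand-in for the file write
        q.task_done()               # signal completion only after processing

threading.Thread(target=worker, daemon=True).start()

for i in range(5):
    q.put(i)

q.join()  # returns only after task_done() was called for all five items
print("all items flushed")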
@@ -314,7 +320,7 @@ class Main(object):
             print_period=print_step,
             debug=debug,
             fetch_handler=fetch_handler)
-        fetch_handler.finish()
+        fetch_handler.flush()

     def heter_train_loop(self, epoch):