|
@@ -46,6 +46,58 @@ logging.basicConfig(
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
+import json
|
|
|
|
+
|
|
|
|
# Handler that streams inference fetch results to a JSON-lines file.
class InferenceFetchHandler(object):
    """Accumulate per-batch inference results and persist them as JSON lines.

    Each call to :meth:`handler` buffers one batch's fetched variables;
    buffered batches are appended to ``output_file`` (one JSON object per
    line) every ``batch_size`` batches and on :meth:`finish`.
    """

    def __init__(self, output_file, batch_size=1000):
        """
        Args:
            output_file: path of the JSON-lines file to write results to.
            batch_size: number of buffered batches that triggers a flush.
        """
        self.output_file = output_file
        self.batch_size = batch_size
        self.current_batch = []  # result dicts buffered since the last flush
        self.total_samples = 0   # running count of samples seen so far

        # Create the output directory if needed. Guard against an empty
        # dirname (output_file in the current directory): the unguarded
        # os.makedirs('') would raise FileNotFoundError. exist_ok avoids a
        # race between the exists() check and the mkdir.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Create or truncate the output file so each run starts clean.
        with open(self.output_file, 'w') as f:
            f.write('')

    def handler(self, fetch_vars):
        """Process one batch of fetched inference variables.

        Args:
            fetch_vars: mapping of variable name -> fetched value. numpy
                arrays are converted to nested lists so the result is
                JSON-serializable; other values are stored as-is.
        """
        result_dict = {}
        for var_name, var_value in fetch_vars.items():
            if isinstance(var_value, np.ndarray):
                result_dict[var_name] = var_value.tolist()
            else:
                result_dict[var_name] = var_value

        self.current_batch.append(result_dict)

        # Count samples using the first fetched variable. Robust against an
        # empty dict (previously IndexError) and an unsized scalar value
        # (previously TypeError from len()).
        if result_dict:
            first_value = next(iter(result_dict.values()))
            try:
                self.total_samples += len(first_value)
            except TypeError:
                self.total_samples += 1

        # Flush to disk once enough batches have accumulated.
        if len(self.current_batch) >= self.batch_size:
            self._write_batch()
            logger.info(f"Saved {self.total_samples} samples to {self.output_file}")

    def _write_batch(self):
        """Append buffered results to the output file, one JSON object per line."""
        with open(self.output_file, 'a') as f:
            for result in self.current_batch:
                f.write(json.dumps(result) + '\n')
        self.current_batch = []

    def finish(self):
        """Flush any results still buffered after the last full batch."""
        if self.current_batch:
            self._write_batch()
            logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
|
|
def parse_args():
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser("PaddleRec train script")
|
|
parser = argparse.ArgumentParser("PaddleRec train script")
|
|
parser.add_argument("-o", "--opt", nargs='*', type=str)
|
|
parser.add_argument("-o", "--opt", nargs='*', type=str)
|
|
@@ -186,44 +238,21 @@ class Main(object):
|
|
if reader_type == "InmemoryDataset":
|
|
if reader_type == "InmemoryDataset":
|
|
self.reader.load_into_memory()
|
|
self.reader.load_into_memory()
|
|
|
|
|
|
- for epoch in range(epochs):
|
|
|
|
- fleet.load_inference_model(
|
|
|
|
- init_model_path,
|
|
|
|
- mode=int(model_mode))
|
|
|
|
- epoch_start_time = time.time()
|
|
|
|
-
|
|
|
|
- if sync_mode == "heter":
|
|
|
|
- self.heter_train_loop(epoch)
|
|
|
|
- elif reader_type == "QueueDataset":
|
|
|
|
- self.dataset_train_loop(epoch)
|
|
|
|
- elif reader_type == "InmemoryDataset":
|
|
|
|
- self.dataset_train_loop(epoch)
|
|
|
|
-
|
|
|
|
- epoch_time = time.time() - epoch_start_time
|
|
|
|
-
|
|
|
|
- if use_auc is True:
|
|
|
|
- global_auc = get_global_auc(paddle.static.global_scope(),
|
|
|
|
- self.model.stat_pos.name,
|
|
|
|
- self.model.stat_neg.name)
|
|
|
|
- self.train_result_dict["auc"].append(global_auc)
|
|
|
|
- set_zero(self.model.stat_pos.name,
|
|
|
|
- paddle.static.global_scope())
|
|
|
|
- set_zero(self.model.stat_neg.name,
|
|
|
|
- paddle.static.global_scope())
|
|
|
|
- set_zero(self.model.batch_stat_pos.name,
|
|
|
|
- paddle.static.global_scope())
|
|
|
|
- set_zero(self.model.batch_stat_neg.name,
|
|
|
|
- paddle.static.global_scope())
|
|
|
|
- logger.info(
|
|
|
|
- "Epoch: {}, using time: {} second, ips: {}/sec. auc: {}".
|
|
|
|
- format(epoch, epoch_time, self.count_method,
|
|
|
|
- global_auc))
|
|
|
|
- else:
|
|
|
|
- logger.info(
|
|
|
|
- "Epoch: {}, using time {} second, ips {}/sec.".format(
|
|
|
|
- epoch, epoch_time, self.count_method))
|
|
|
|
|
|
+ fleet.load_inference_model(
|
|
|
|
+ init_model_path,
|
|
|
|
+ mode=int(model_mode))
|
|
|
|
+ epoch_start_time = time.time()
|
|
|
|
+
|
|
|
|
+ if sync_mode == "heter":
|
|
|
|
+ self.heter_train_loop(epoch)
|
|
|
|
+ elif reader_type == "QueueDataset":
|
|
|
|
+ self.dataset_train_loop(epoch)
|
|
|
|
+ elif reader_type == "InmemoryDataset":
|
|
|
|
+ self.dataset_train_loop(epoch)
|
|
|
|
|
|
- model_dir = "{}/{}".format(save_model_path, epoch)
|
|
|
|
|
|
+ epoch_time = time.time() - epoch_start_time
|
|
|
|
+ logger.info(
|
|
|
|
+ "using time {} second, ips {}/sec.".format(epoch_time, self.count_method))
|
|
|
|
|
|
if reader_type == "InmemoryDataset":
|
|
if reader_type == "InmemoryDataset":
|
|
self.reader.release_memory()
|
|
self.reader.release_memory()
|
|
@@ -259,6 +288,13 @@ class Main(object):
|
|
"dump_fields_path": dump_fields_path,
|
|
"dump_fields_path": dump_fields_path,
|
|
"dump_fields": config.get("runner.dump_fields")
|
|
"dump_fields": config.get("runner.dump_fields")
|
|
})
|
|
})
|
|
|
|
+
|
|
|
|
+ # 设置输出文件路径
|
|
|
|
+ output_dir = config.get("runner.inference_output_dir", "inference_results")
|
|
|
|
+ output_file = os.path.join(output_dir, f"epoch_{epoch}_results.jsonl")
|
|
|
|
+
|
|
|
|
+ # 创建处理器实例
|
|
|
|
+ fetch_handler = InferenceFetchHandler(output_file)
|
|
print(paddle.static.default_main_program()._fleet_opt)
|
|
print(paddle.static.default_main_program()._fleet_opt)
|
|
results = self.exe.infer_from_dataset(
|
|
results = self.exe.infer_from_dataset(
|
|
program=paddle.static.default_main_program(),
|
|
program=paddle.static.default_main_program(),
|
|
@@ -266,8 +302,10 @@ class Main(object):
|
|
fetch_list=fetch_vars,
|
|
fetch_list=fetch_vars,
|
|
fetch_info=fetch_info,
|
|
fetch_info=fetch_info,
|
|
print_period=print_step,
|
|
print_period=print_step,
|
|
- debug=debug)
|
|
|
|
|
|
+ debug=debug,
|
|
|
|
+ fetch_handler=fetch_handler)
|
|
print("results {}".format(results))
|
|
print("results {}".format(results))
|
|
|
|
+ fetch_handler.finish()
|
|
|
|
|
|
|
|
|
|
def heter_train_loop(self, epoch):
|
|
def heter_train_loop(self, epoch):
|