|
@@ -78,17 +78,26 @@ def get_infer_reader(input_var, config):
|
|
return reader_instance.get_reader(), file_list
|
|
return reader_instance.get_reader(), file_list
|
|
|
|
|
|
|
|
|
|
def get_file_list(data_path, config, file_extensions=None):
    """Collect data files from an HDFS directory tree and return their full URLs.

    Lists the immediate sub-directories of ``data_path``, gathers the files
    inside each one, filters them by extension, optionally shards the list
    across distributed workers, and prefixes every path with the HDFS base URL.

    Args:
        data_path: HDFS directory whose sub-directories contain the data files.
        config: Run configuration; ``runner.split_file_list`` (truthy) enables
            per-worker sharding of the file list via ``fleet.util.get_file_shard``.
        file_extensions: Extensions to keep (e.g. ``['.gz']``). Defaults to
            ``['.gz']``; pass an empty list/None-equivalent falsy value to
            disable filtering.

    Returns:
        list[str]: Full paths, each ``configs["fs.default.name"] + file``.
    """
    # Avoid a mutable default argument; [] or other falsy value disables filtering.
    if file_extensions is None:
        file_extensions = ['.gz']

    all_files = []
    # Top level: only sub-directories are scanned for files; files directly
    # under data_path (second element of ls_dir) are intentionally ignored.
    sub_dirs, _ = hdfs_client.ls_dir(data_path)
    for sub_dir in sub_dirs:
        _, files = hdfs_client.ls_dir(sub_dir)
        for file in files:
            # Extension filter: skip files not matching any wanted suffix.
            if file_extensions and not any(file.endswith(ext) for ext in file_extensions):
                continue
            all_files.append(file)

    # Fixed NameError: the old variable `dirs` was renamed to `sub_dirs`.
    print(sub_dirs, all_files)

    # If the config asks for it, shard the file list across distributed workers.
    if config.get("runner.split_file_list"):
        logger.info("Split file list for worker {}".format(dist.get_rank()))
        all_files = fleet.util.get_file_shard(all_files)
        logger.info("File list: {}".format(all_files))

    # NOTE(review): `configs` (module-level?) differs from the `config`
    # parameter — presumably the global filesystem config; confirm upstream.
    base_url = f'{configs["fs.default.name"]}'
    full_paths = [base_url + file for file in all_files]
    return full_paths