| 
					
				 | 
			
			
				@@ -78,17 +78,26 @@ def get_infer_reader(input_var, config): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return reader_instance.get_reader(), file_list 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def get_file_list(data_path, config): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    dirs,file_list = hdfs_client.ls_dir(data_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    print(dirs,file_list) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def get_file_list(data_path, config, file_extensions=['.gz']): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    all_files = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    sub_dirs,file_list = hdfs_client.ls_dir(data_path)     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for sub_dir in sub_dirs: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        _, files = hdfs_client.ls_dir(sub_dir) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for file in files: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # 扩展名过滤 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if file_extensions and not any(file.endswith(ext) for ext in file_extensions): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                continue     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            all_files.append(file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print(dirs,all_files) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # 如果配置中指定了分割文件列表 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if config.get("runner.split_file_list"): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         logger.info("Split file list for worker {}".format(dist.get_rank())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        file_list = fleet.util.get_file_shard(file_list) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    logger.info("File list: {}".format(file_list)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        all_files = fleet.util.get_file_shard(all_files) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    logger.info("File list: {}".format(all_files)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     base_url = f'{configs["fs.default.name"]}' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    full_paths = [base_url + file for file in file_list] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    full_paths = [base_url + file for file in all_files] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return full_paths 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |