often 5 mēneši atpakaļ
vecāks
revīzija
a8cd264cef

+ 15 - 6
recommend-model-produce/src/main/python/tools/utils/static_ps/reader_helper_hdfs.py

@@ -78,17 +78,26 @@ def get_infer_reader(input_var, config):
         return reader_instance.get_reader(), file_list
 
 
def get_file_list(data_path, config, file_extensions=('.gz',)):
    """Collect the data files one directory level below ``data_path``.

    Lists the sub-directories of ``data_path`` through ``hdfs_client``,
    gathers the files found inside each of them, keeps only those whose
    name ends with one of ``file_extensions``, optionally shards the list
    across workers, and prefixes every path with the configured filesystem
    base URL.

    Args:
        data_path: Directory handed to ``hdfs_client.ls_dir``.
        config: Mapping-like runner configuration; only the
            ``"runner.split_file_list"`` key is read (via ``.get``).
        file_extensions: Iterable of accepted file-name suffixes. A single
            string is treated as one suffix. A falsy value (``None`` or an
            empty collection) disables filtering.

    Returns:
        list[str]: ``configs["fs.default.name"]`` concatenated with each
        selected file path.
    """
    # Normalise to a tuple: str.endswith accepts a tuple of suffixes, and an
    # immutable default avoids the shared-mutable-default pitfall.
    if isinstance(file_extensions, str):
        file_extensions = (file_extensions,)
    suffixes = tuple(file_extensions) if file_extensions else ()

    # NOTE(review): files that sit directly under data_path are ignored here;
    # only files inside its immediate sub-directories are collected. Confirm
    # this is intended.
    sub_dirs, _ = hdfs_client.ls_dir(data_path)

    all_files = []
    for sub_dir in sub_dirs:
        # Presumably ls_dir returns paths that can be listed again as-is.
        _, files = hdfs_client.ls_dir(sub_dir)
        # Extension filter (skipped entirely when no suffixes are given).
        all_files.extend(
            path for path in files
            if not suffixes or path.endswith(suffixes)
        )

    print(sub_dirs, all_files)
    # If the config asks for it, keep only this worker's shard of the files.
    if config.get("runner.split_file_list"):
        logger.info("Split file list for worker {}".format(dist.get_rank()))
        all_files = fleet.util.get_file_shard(all_files)
    logger.info("File list: {}".format(all_files))

    # NOTE(review): `configs` (plural) is not the `config` parameter; it is
    # presumably a module-level HDFS settings dict -- confirm.
    base_url = f'{configs["fs.default.name"]}'
    full_paths = [base_url + path for path in all_files]

    return full_paths