file.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import os
  2. from pathlib import Path
  3. from typing import Union
  # Audio file suffixes: all lowercase, each including its leading dot.
  AUDIO_EXTENSIONS = {
      ".aac", ".aif", ".aifc", ".aiff", ".flac",
      ".m4a", ".mp3", ".ogg", ".wav", ".wma",
  }
  16. def list_files(
  17. path: Union[Path, str],
  18. extensions: set[str] = None,
  19. recursive: bool = False,
  20. sort: bool = True,
  21. ) -> list[Path]:
  22. """List files in a directory.
  23. Args:
  24. path (Path): Path to the directory.
  25. extensions (set, optional): Extensions to filter. Defaults to None.
  26. recursive (bool, optional): Whether to search recursively. Defaults to False.
  27. sort (bool, optional): Whether to sort the files. Defaults to True.
  28. Returns:
  29. list: List of files.
  30. """
  31. if isinstance(path, str):
  32. path = Path(path)
  33. if not path.exists():
  34. raise FileNotFoundError(f"Directory {path} does not exist.")
  35. files = (
  36. [
  37. Path(os.path.join(root, filename))
  38. for root, _, filenames in os.walk(path, followlinks=True)
  39. for filename in filenames
  40. if Path(os.path.join(root, filename)).is_file()
  41. ]
  42. if recursive
  43. else [f for f in path.glob("*") if f.is_file()]
  44. )
  45. if extensions is not None:
  46. files = [f for f in files if f.suffix in extensions]
  47. if sort:
  48. files = sorted(files)
  49. return files
  50. def get_latest_checkpoint(path: Path | str) -> Path | None:
  51. # Find the latest checkpoint
  52. ckpt_dir = Path(path)
  53. if ckpt_dir.exists() is False:
  54. return None
  55. ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
  56. if len(ckpts) == 0:
  57. return None
  58. return ckpts[-1]