file.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import os
  2. from glob import glob
  3. from pathlib import Path
  4. from typing import Union
  5. from loguru import logger
  6. AUDIO_EXTENSIONS = {
  7. ".mp3",
  8. ".wav",
  9. ".flac",
  10. ".ogg",
  11. ".m4a",
  12. ".wma",
  13. ".aac",
  14. ".aiff",
  15. ".aif",
  16. ".aifc",
  17. }
  18. def list_files(
  19. path: Union[Path, str],
  20. extensions: set[str] = None,
  21. recursive: bool = False,
  22. sort: bool = True,
  23. ) -> list[Path]:
  24. """List files in a directory.
  25. Args:
  26. path (Path): Path to the directory.
  27. extensions (set, optional): Extensions to filter. Defaults to None.
  28. recursive (bool, optional): Whether to search recursively. Defaults to False.
  29. sort (bool, optional): Whether to sort the files. Defaults to True.
  30. Returns:
  31. list: List of files.
  32. """
  33. if isinstance(path, str):
  34. path = Path(path)
  35. if not path.exists():
  36. raise FileNotFoundError(f"Directory {path} does not exist.")
  37. files = [Path(f) for f in glob(str(path / "**/*"), recursive=recursive)]
  38. if extensions is not None:
  39. files = [f for f in files if f.suffix in extensions]
  40. if sort:
  41. files = sorted(files)
  42. return files
  43. def get_latest_checkpoint(path: Path | str) -> Path | None:
  44. # Find the latest checkpoint
  45. ckpt_dir = Path(path)
  46. if ckpt_dir.exists() is False:
  47. return None
  48. ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
  49. if len(ckpts) == 0:
  50. return None
  51. return ckpts[-1]
  52. def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
  53. """
  54. Load a Bert-VITS2 style filelist.
  55. """
  56. files = set()
  57. results = []
  58. count_duplicated, count_not_found = 0, 0
  59. LANGUAGE_TO_LANGUAGES = {
  60. "zh": ["zh", "en"],
  61. "jp": ["jp", "en"],
  62. "en": ["en"],
  63. }
  64. with open(path, "r", encoding="utf-8") as f:
  65. for line in f.readlines():
  66. filename, speaker, language, text = line.strip().split("|")
  67. file = Path(filename)
  68. language = language.strip().lower()
  69. if language == "ja":
  70. language = "jp"
  71. assert language in ["zh", "jp", "en"], f"Invalid language {language}"
  72. languages = LANGUAGE_TO_LANGUAGES[language]
  73. if file in files:
  74. logger.warning(f"Duplicated file: {file}")
  75. count_duplicated += 1
  76. continue
  77. if not file.exists():
  78. logger.warning(f"File not found: {file}")
  79. count_not_found += 1
  80. continue
  81. results.append((file, speaker, languages, text))
  82. if count_duplicated > 0:
  83. logger.warning(f"Total duplicated files: {count_duplicated}")
  84. if count_not_found > 0:
  85. logger.warning(f"Total files not found: {count_not_found}")
  86. return results