file.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import os
  2. from glob import glob
  3. from pathlib import Path
  4. from typing import Union
  5. from loguru import logger
  6. from natsort import natsorted
  7. AUDIO_EXTENSIONS = {
  8. ".mp3",
  9. ".wav",
  10. ".flac",
  11. ".ogg",
  12. ".m4a",
  13. ".wma",
  14. ".aac",
  15. ".aiff",
  16. ".aif",
  17. ".aifc",
  18. }
  19. def list_files(
  20. path: Union[Path, str],
  21. extensions: set[str] = None,
  22. recursive: bool = False,
  23. sort: bool = True,
  24. ) -> list[Path]:
  25. """List files in a directory.
  26. Args:
  27. path (Path): Path to the directory.
  28. extensions (set, optional): Extensions to filter. Defaults to None.
  29. recursive (bool, optional): Whether to search recursively. Defaults to False.
  30. sort (bool, optional): Whether to sort the files. Defaults to True.
  31. Returns:
  32. list: List of files.
  33. """
  34. if isinstance(path, str):
  35. path = Path(path)
  36. if not path.exists():
  37. raise FileNotFoundError(f"Directory {path} does not exist.")
  38. files = [Path(f) for f in glob(str(path / "**/*"), recursive=recursive)]
  39. if extensions is not None:
  40. files = [f for f in files if f.suffix in extensions]
  41. if sort:
  42. files = natsorted(files)
  43. return files
  44. def get_latest_checkpoint(path: Path | str) -> Path | None:
  45. # Find the latest checkpoint
  46. ckpt_dir = Path(path)
  47. if ckpt_dir.exists() is False:
  48. return None
  49. ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
  50. if len(ckpts) == 0:
  51. return None
  52. return ckpts[-1]
  53. def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
  54. """
  55. Load a Bert-VITS2 style filelist.
  56. """
  57. files = set()
  58. results = []
  59. count_duplicated, count_not_found = 0, 0
  60. LANGUAGE_TO_LANGUAGES = {
  61. "zh": ["zh", "en"],
  62. "jp": ["jp", "en"],
  63. "en": ["en"],
  64. }
  65. with open(path, "r", encoding="utf-8") as f:
  66. for line in f.readlines():
  67. splits = line.strip().split("|", maxsplit=3)
  68. if len(splits) != 4:
  69. logger.warning(f"Invalid line: {line}")
  70. continue
  71. filename, speaker, language, text = splits
  72. file = Path(filename)
  73. language = language.strip().lower()
  74. if language == "ja":
  75. language = "jp"
  76. assert language in ["zh", "jp", "en"], f"Invalid language {language}"
  77. languages = LANGUAGE_TO_LANGUAGES[language]
  78. if file in files:
  79. logger.warning(f"Duplicated file: {file}")
  80. count_duplicated += 1
  81. continue
  82. if not file.exists():
  83. logger.warning(f"File not found: {file}")
  84. count_not_found += 1
  85. continue
  86. results.append((file, speaker, languages, text))
  87. if count_duplicated > 0:
  88. logger.warning(f"Total duplicated files: {count_duplicated}")
  89. if count_not_found > 0:
  90. logger.warning(f"Total files not found: {count_not_found}")
  91. return results