file.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import os
  2. from pathlib import Path
  3. from typing import Union
  4. from loguru import logger
  5. AUDIO_EXTENSIONS = {
  6. ".mp3",
  7. ".wav",
  8. ".flac",
  9. ".ogg",
  10. ".m4a",
  11. ".wma",
  12. ".aac",
  13. ".aiff",
  14. ".aif",
  15. ".aifc",
  16. }
  17. def list_files(
  18. path: Union[Path, str],
  19. extensions: set[str] = None,
  20. recursive: bool = False,
  21. sort: bool = True,
  22. ) -> list[Path]:
  23. """List files in a directory.
  24. Args:
  25. path (Path): Path to the directory.
  26. extensions (set, optional): Extensions to filter. Defaults to None.
  27. recursive (bool, optional): Whether to search recursively. Defaults to False.
  28. sort (bool, optional): Whether to sort the files. Defaults to True.
  29. Returns:
  30. list: List of files.
  31. """
  32. if isinstance(path, str):
  33. path = Path(path)
  34. if not path.exists():
  35. raise FileNotFoundError(f"Directory {path} does not exist.")
  36. files = (
  37. [
  38. Path(os.path.join(root, filename))
  39. for root, _, filenames in os.walk(path, followlinks=True)
  40. for filename in filenames
  41. if Path(os.path.join(root, filename)).is_file()
  42. ]
  43. if recursive
  44. else [f for f in path.glob("*") if f.is_file()]
  45. )
  46. if extensions is not None:
  47. files = [f for f in files if f.suffix in extensions]
  48. if sort:
  49. files = sorted(files)
  50. return files
  51. def get_latest_checkpoint(path: Path | str) -> Path | None:
  52. # Find the latest checkpoint
  53. ckpt_dir = Path(path)
  54. if ckpt_dir.exists() is False:
  55. return None
  56. ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
  57. if len(ckpts) == 0:
  58. return None
  59. return ckpts[-1]
  60. def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
  61. """
  62. Load a Bert-VITS2 style filelist.
  63. """
  64. files = set()
  65. results = []
  66. count_duplicated, count_not_found = 0, 0
  67. LANGUAGE_TO_LANGUAGES = {
  68. "zh": ["zh", "en"],
  69. "jp": ["jp", "en"],
  70. "en": ["en"],
  71. }
  72. with open(path, "r", encoding="utf-8") as f:
  73. for line in f.readlines():
  74. filename, speaker, language, text = line.strip().split("|")
  75. file = Path(filename)
  76. language = language.strip().lower()
  77. if language == "ja":
  78. language = "jp"
  79. assert language in ["zh", "jp", "en"], f"Invalid language {language}"
  80. languages = LANGUAGE_TO_LANGUAGES[language]
  81. if file in files:
  82. logger.warning(f"Duplicated file: {file}")
  83. count_duplicated += 1
  84. continue
  85. if not file.exists():
  86. logger.warning(f"File not found: {file}")
  87. count_not_found += 1
  88. continue
  89. results.append((file, speaker, languages, text))
  90. if count_duplicated > 0:
  91. logger.warning(f"Total duplicated files: {count_duplicated}")
  92. if count_not_found > 0:
  93. logger.warning(f"Total files not found: {count_not_found}")
  94. return results