file.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import os
  2. from pathlib import Path
  3. from typing import Union
  4. from loguru import logger
  5. from natsort import natsorted
  6. AUDIO_EXTENSIONS = {
  7. ".mp3",
  8. ".wav",
  9. ".flac",
  10. ".ogg",
  11. ".m4a",
  12. ".wma",
  13. ".aac",
  14. ".aiff",
  15. ".aif",
  16. ".aifc",
  17. }
  18. VIDEO_EXTENSIONS = {
  19. ".mp4",
  20. ".avi",
  21. }
  22. def get_latest_checkpoint(path: Path | str) -> Path | None:
  23. # Find the latest checkpoint
  24. ckpt_dir = Path(path)
  25. if ckpt_dir.exists() is False:
  26. return None
  27. ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
  28. if len(ckpts) == 0:
  29. return None
  30. return ckpts[-1]
  31. def audio_to_bytes(file_path):
  32. if not file_path or not Path(file_path).exists():
  33. return None
  34. with open(file_path, "rb") as wav_file:
  35. wav = wav_file.read()
  36. return wav
  37. def read_ref_text(ref_text):
  38. path = Path(ref_text)
  39. if path.exists() and path.is_file():
  40. with path.open("r", encoding="utf-8") as file:
  41. return file.read()
  42. return ref_text
  43. def list_files(
  44. path: Union[Path, str],
  45. extensions: set[str] = set(),
  46. recursive: bool = False,
  47. sort: bool = True,
  48. ) -> list[Path]:
  49. """List files in a directory.
  50. Args:
  51. path (Path): Path to the directory.
  52. extensions (set, optional): Extensions to filter. Defaults to None.
  53. recursive (bool, optional): Whether to search recursively. Defaults to False.
  54. sort (bool, optional): Whether to sort the files. Defaults to True.
  55. Returns:
  56. list: List of files.
  57. """
  58. if isinstance(path, str):
  59. path = Path(path)
  60. if not path.exists():
  61. raise FileNotFoundError(f"Directory {path} does not exist.")
  62. files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
  63. if sort:
  64. files = natsorted(files)
  65. return files
  66. def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
  67. """
  68. Load a Bert-VITS2 style filelist.
  69. """
  70. files = set()
  71. results = []
  72. count_duplicated, count_not_found = 0, 0
  73. LANGUAGE_TO_LANGUAGES = {
  74. "zh": ["zh", "en"],
  75. "jp": ["jp", "en"],
  76. "en": ["en"],
  77. }
  78. with open(path, "r", encoding="utf-8") as f:
  79. for line in f.readlines():
  80. splits = line.strip().split("|", maxsplit=3)
  81. if len(splits) != 4:
  82. logger.warning(f"Invalid line: {line}")
  83. continue
  84. filename, speaker, language, text = splits
  85. file = Path(filename)
  86. language = language.strip().lower()
  87. if language == "ja":
  88. language = "jp"
  89. assert language in ["zh", "jp", "en"], f"Invalid language {language}"
  90. languages = LANGUAGE_TO_LANGUAGES[language]
  91. if file in files:
  92. logger.warning(f"Duplicated file: {file}")
  93. count_duplicated += 1
  94. continue
  95. if not file.exists():
  96. logger.warning(f"File not found: {file}")
  97. count_not_found += 1
  98. continue
  99. results.append((file, speaker, languages, text))
  100. if count_duplicated > 0:
  101. logger.warning(f"Total duplicated files: {count_duplicated}")
  102. if count_not_found > 0:
  103. logger.warning(f"Total files not found: {count_not_found}")
  104. return results