file.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import base64
  2. from pathlib import Path
  3. from typing import Union
  4. from loguru import logger
  5. from natsort import natsorted
  # Audio file suffixes, each written lower-case with its leading dot.
  AUDIO_EXTENSIONS = {
      ".mp3", ".wav", ".flac", ".ogg",
      ".m4a", ".wma", ".aac",
      ".aiff", ".aif", ".aifc",
  }
  # Video file suffixes, each written lower-case with its leading dot.
  VIDEO_EXTENSIONS = {".mp4", ".avi"}
  def audio_to_bytes(file_path):
      """Return the raw bytes stored at ``file_path``.

      Args:
          file_path: Location of the file to read.

      Returns:
          The file content as ``bytes``, or ``None`` when ``file_path`` is
          falsy or nothing exists at that location.
      """
      source = Path(file_path) if file_path else None
      if source is None or not source.exists():
          return None

      with open(file_path, "rb") as stream:
          return stream.read()
  def read_ref_text(ref_text):
      """Resolve a reference text given either inline or as a file path.

      Args:
          ref_text: Either the text itself or the path of a UTF-8 text file
              that holds it.

      Returns:
          The content of the file when ``ref_text`` names an existing regular
          file, otherwise ``ref_text`` unchanged.
      """
      path = Path(ref_text)
      try:
          is_file = path.is_file()
      except (OSError, ValueError):
          # Literal text is not necessarily a usable path: a long sentence can
          # exceed the OS file-name limit, and probing it then raises instead
          # of answering "no". Such input is plain text, not a file.
          is_file = False

      if not is_file:
          return ref_text

      with path.open("r", encoding="utf-8") as file:
          return file.read()
  def list_files(
      path: Union[Path, str],
      extensions: Union[set[str], None] = None,
      recursive: bool = False,
      sort: bool = True,
  ) -> list[Path]:
      """List files in a directory.

      Args:
          path: Path to the directory.
          extensions: Suffixes to keep, each including its leading dot
              (e.g. ``{".wav"}``). ``None`` applies no filter and returns
              every regular file.
          recursive: Whether to descend into subdirectories.
          sort: Whether to sort the result with ``natsorted``.

      Returns:
          List of matching paths.

      Raises:
          FileNotFoundError: If ``path`` does not exist.
      """
      path = Path(path)

      if not path.exists():
          raise FileNotFoundError(f"Directory {path} does not exist.")

      # Honour ``recursive``: ``rglob`` walks the whole tree, ``glob`` only the
      # directory itself.
      search = path.rglob if recursive else path.glob

      if extensions is None:
          # No filter requested: keep every regular file, skip directories.
          files = [entry for entry in search("*") if entry.is_file()]
      else:
          files = [file for ext in extensions for file in search(f"*{ext}")]

      if sort:
          files = natsorted(files)

      return files
  57. def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
  58. """
  59. Load a Bert-VITS2 style filelist.
  60. """
  61. files = set()
  62. results = []
  63. count_duplicated, count_not_found = 0, 0
  64. LANGUAGE_TO_LANGUAGES = {
  65. "zh": ["zh", "en"],
  66. "jp": ["jp", "en"],
  67. "en": ["en"],
  68. }
  69. with open(path, "r", encoding="utf-8") as f:
  70. for line in f.readlines():
  71. splits = line.strip().split("|", maxsplit=3)
  72. if len(splits) != 4:
  73. logger.warning(f"Invalid line: {line}")
  74. continue
  75. filename, speaker, language, text = splits
  76. file = Path(filename)
  77. language = language.strip().lower()
  78. if language == "ja":
  79. language = "jp"
  80. assert language in ["zh", "jp", "en"], f"Invalid language {language}"
  81. languages = LANGUAGE_TO_LANGUAGES[language]
  82. if file in files:
  83. logger.warning(f"Duplicated file: {file}")
  84. count_duplicated += 1
  85. continue
  86. if not file.exists():
  87. logger.warning(f"File not found: {file}")
  88. count_not_found += 1
  89. continue
  90. results.append((file, speaker, languages, text))
  91. if count_duplicated > 0:
  92. logger.warning(f"Total duplicated files: {count_duplicated}")
  93. if count_not_found > 0:
  94. logger.warning(f"Total files not found: {count_not_found}")
  95. return results