file.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import os
  2. from glob import glob
  3. from pathlib import Path
  4. from typing import Union
  5. from loguru import logger
  6. from natsort import natsorted
  7. AUDIO_EXTENSIONS = {
  8. ".mp3",
  9. ".wav",
  10. ".flac",
  11. ".ogg",
  12. ".m4a",
  13. ".wma",
  14. ".aac",
  15. ".aiff",
  16. ".aif",
  17. ".aifc",
  18. }
  19. def list_files(
  20. path: Union[Path, str],
  21. extensions: set[str] = None,
  22. recursive: bool = False,
  23. sort: bool = True,
  24. ) -> list[Path]:
  25. """List files in a directory.
  26. Args:
  27. path (Path): Path to the directory.
  28. extensions (set, optional): Extensions to filter. Defaults to None.
  29. recursive (bool, optional): Whether to search recursively. Defaults to False.
  30. sort (bool, optional): Whether to sort the files. Defaults to True.
  31. Returns:
  32. list: List of files.
  33. """
  34. if isinstance(path, str):
  35. path = Path(path)
  36. if not path.exists():
  37. raise FileNotFoundError(f"Directory {path} does not exist.")
  38. files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
  39. if sort:
  40. files = natsorted(files)
  41. return files
  42. def get_latest_checkpoint(path: Path | str) -> Path | None:
  43. # Find the latest checkpoint
  44. ckpt_dir = Path(path)
  45. if ckpt_dir.exists() is False:
  46. return None
  47. ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
  48. if len(ckpts) == 0:
  49. return None
  50. return ckpts[-1]
  51. def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
  52. """
  53. Load a Bert-VITS2 style filelist.
  54. """
  55. files = set()
  56. results = []
  57. count_duplicated, count_not_found = 0, 0
  58. LANGUAGE_TO_LANGUAGES = {
  59. "zh": ["zh", "en"],
  60. "jp": ["jp", "en"],
  61. "en": ["en"],
  62. }
  63. with open(path, "r", encoding="utf-8") as f:
  64. for line in f.readlines():
  65. splits = line.strip().split("|", maxsplit=3)
  66. if len(splits) != 4:
  67. logger.warning(f"Invalid line: {line}")
  68. continue
  69. filename, speaker, language, text = splits
  70. file = Path(filename)
  71. language = language.strip().lower()
  72. if language == "ja":
  73. language = "jp"
  74. assert language in ["zh", "jp", "en"], f"Invalid language {language}"
  75. languages = LANGUAGE_TO_LANGUAGES[language]
  76. if file in files:
  77. logger.warning(f"Duplicated file: {file}")
  78. count_duplicated += 1
  79. continue
  80. if not file.exists():
  81. logger.warning(f"File not found: {file}")
  82. count_not_found += 1
  83. continue
  84. results.append((file, speaker, languages, text))
  85. if count_duplicated > 0:
  86. logger.warning(f"Total duplicated files: {count_duplicated}")
  87. if count_not_found > 0:
  88. logger.warning(f"Total files not found: {count_not_found}")
  89. return results