# parser.py
  1. import itertools
  2. import re
  3. import string
  4. from typing import Optional
  5. from fish_speech.text.chinese import g2p as g2p_chinese
  6. from fish_speech.text.english import g2p as g2p_english
  7. from fish_speech.text.japanese import g2p as g2p_japanese
  8. from fish_speech.text.symbols import (
  9. language_id_map,
  10. language_unicode_range_map,
  11. punctuation,
  12. symbols_to_id,
  13. )
  14. LANGUAGE_TO_MODULE_MAP = {
  15. "ZH": g2p_chinese,
  16. "EN": g2p_english,
  17. "JP": g2p_japanese,
  18. }
  19. # This files is designed to parse the text, doing some normalization,
  20. # and return an annotated text.
  21. # Example: 1, 2, <JP>3</JP>, 4, <ZH>5</ZH>
  22. # For better compatibility, we also support tree-like structure, like:
  23. # 1, 2, <JP>3, <EN>4</EN></JP>, 5, <ZH>6</ZH>
  24. class Segment:
  25. def __init__(self, text, language=None, phones=None):
  26. self.text = text
  27. self.language = language.upper() if language is not None else None
  28. if phones is None:
  29. phones = LANGUAGE_TO_MODULE_MAP[self.language](self.text)
  30. self.phones = phones
  31. def __repr__(self):
  32. return f"<Segment {self.language}: '{self.text}' -> '{' '.join(self.phones)}'>"
  33. def __str__(self):
  34. return self.text
  35. SYMBOLS_MAPPING = {
  36. ":": ",",
  37. ";": ",",
  38. ",": ",",
  39. "。": ".",
  40. "!": "!",
  41. "?": "?",
  42. "\n": ".",
  43. "·": ",",
  44. "、": ",",
  45. "...": "…",
  46. "$": ".",
  47. "“": "'",
  48. "”": "'",
  49. "‘": "'",
  50. "’": "'",
  51. "(": "'",
  52. ")": "'",
  53. "(": "'",
  54. ")": "'",
  55. "《": "'",
  56. "》": "'",
  57. "【": "'",
  58. "】": "'",
  59. "[": "'",
  60. "]": "'",
  61. "—": "-",
  62. "~": "-",
  63. "~": "-",
  64. "「": "'",
  65. "」": "'",
  66. ";": ",",
  67. ":": ",",
  68. }
  69. REPLACE_SYMBOL_REGEX = re.compile(
  70. "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
  71. )
  72. ALL_KNOWN_UTF8_RANGE = list(
  73. itertools.chain.from_iterable(language_unicode_range_map.values())
  74. )
  75. REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
  76. "[^"
  77. + "".join(
  78. f"{re.escape(chr(start))}-{re.escape(chr(end))}"
  79. for start, end in ALL_KNOWN_UTF8_RANGE
  80. )
  81. + "]"
  82. )
  83. def clean_text(text):
  84. # Clean the text
  85. text = text.strip()
  86. # Replace all chinese symbols with their english counterparts
  87. text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
  88. text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
  89. return text
  90. def parse_text_to_segments(text, order=None):
  91. """
  92. Parse the text and return a list of segments.
  93. :param text: The text to be parsed.
  94. :param order: The order of languages. If None, use ["ZH", "JP", "EN"].
  95. :return: A list of segments.
  96. """
  97. if order is None:
  98. order = ["ZH", "JP", "EN"]
  99. order = [language.upper() for language in order]
  100. assert all(language in language_id_map for language in order)
  101. text = clean_text(text)
  102. texts = re.split(r"(<.*?>)", text)
  103. texts = [text for text in texts if text.strip() != ""]
  104. stack = []
  105. segments = []
  106. for text in texts:
  107. if text.startswith("<") and text.endswith(">") and text[1] != "/":
  108. current_language = text[1:-1]
  109. # The following line should be updated later
  110. assert (
  111. current_language.upper() in language_id_map
  112. ), f"Unknown language: {current_language}"
  113. stack.append(current_language)
  114. elif text.startswith("</") and text.endswith(">"):
  115. language = stack.pop()
  116. if language != text[2:-1]:
  117. raise ValueError(f"Language mismatch: {language} != {text[2:-1]}")
  118. elif stack:
  119. segments.append(Segment(text, stack[-1]))
  120. else:
  121. segments.extend(
  122. [i for i in parse_unknown_segment(text, order) if len(i.phones) > 0]
  123. )
  124. return segments
def parse_unknown_segment(text, order):
    # Split ``text`` into language-homogeneous Segments by inspecting each
    # character's unicode range. ``order`` gives the priority of languages
    # when a character could belong to more than one range; the language of
    # the previous character is always tried first so runs stay together.
    last_idx, last_language = 0, None
    for idx, char in enumerate(text):
        if char == " " or char in punctuation or char in string.digits:
            # If the punctuation / number is in the middle of the text,
            # we should not split the text: treat it as belonging to the
            # current run (or to the first language in `order` when it
            # appears before any letter was seen).
            detected_language = last_language or order[0]
        else:
            detected_language = None
            _order = order.copy()
            if last_language is not None:
                # Prioritize the last language
                _order.remove(last_language)
                _order.insert(0, last_language)
            for language in _order:
                for start, end in language_unicode_range_map[language]:
                    if start <= ord(char) <= end:
                        detected_language = language
                        break
                # The inner ``break`` only exits the range loop; stop
                # scanning further languages once a match was found.
                if detected_language is not None:
                    break
            assert (
                detected_language is not None
            ), f"Incorrect language: {char}, clean before calling this function."
        if last_language is None:
            last_language = detected_language
        if detected_language != last_language:
            # Language changed: emit the finished run and start a new one
            # at the current character.
            yield Segment(text[last_idx:idx], last_language)
            last_idx = idx
            last_language = detected_language
    # Emit the trailing run (skipped only when ``text`` is empty or the
    # final run was already flushed exactly at the end).
    if last_idx != len(text):
        yield Segment(text[last_idx:], last_language)
  157. def segments_to_phones(
  158. segments: list[Segment],
  159. ) -> tuple[tuple[Optional[str], str], list[int]]:
  160. phones = []
  161. ids = []
  162. for segment in segments:
  163. for phone in segment.phones:
  164. q0 = (segment.language, phone)
  165. q1 = (None, phone)
  166. if q0 in symbols_to_id:
  167. phones.append(q0)
  168. ids.append(symbols_to_id[q0])
  169. elif q1 in symbols_to_id:
  170. phones.append(q1)
  171. ids.append(symbols_to_id[q1])
  172. else:
  173. raise ValueError(f"Unknown phone: {segment.language} {phone}")
  174. return phones, ids
  175. def g2p(text, order=None):
  176. segments = parse_text_to_segments(text, order=order)
  177. phones, _ = segments_to_phones(segments)
  178. return phones
  179. if __name__ == "__main__":
  180. segments = parse_text_to_segments(
  181. "测试一下 Hugging face, BGM声音很大吗?那我改一下. <jp>世界、こんにちは。</jp>" # noqa: E501
  182. )
  183. print(segments)
  184. segments = parse_text_to_segments(
  185. "测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。" # noqa: E501
  186. )
  187. print(segments)