parser.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import itertools
  2. import re
  3. import string
  4. from typing import Optional
  5. from fish_speech.text.chinese import g2p as g2p_chinese
  6. from fish_speech.text.english import g2p as g2p_english
  7. from fish_speech.text.japanese import g2p as g2p_japanese
  8. from fish_speech.text.symbols import (
  9. language_id_map,
  10. language_unicode_range_map,
  11. punctuation,
  12. symbols_to_id,
  13. )
# Maps an upper-case language code to the g2p function that turns text in
# that language into phones (used lazily by Segment below).
LANGUAGE_TO_MODULE_MAP = {
    "ZH": g2p_chinese,
    "EN": g2p_english,
    "JP": g2p_japanese,
}

# This file is designed to parse the text, doing some normalization,
# and return an annotated text.
# Example: 1, 2, <JP>3</JP>, 4, <ZH>5</ZH>
# For better compatibility, we also support tree-like structure, like:
# 1, 2, <JP>3, <EN>4</EN></JP>, 5, <ZH>6</ZH>
  24. class Segment:
  25. def __init__(self, text, language=None, phones=None):
  26. self.text = text
  27. self.language = language.upper() if language is not None else None
  28. if phones is None:
  29. phones = LANGUAGE_TO_MODULE_MAP[self.language](self.text)
  30. self.phones = phones
  31. def __repr__(self):
  32. return f"<Segment {self.language}: '{self.text}' -> '{' '.join(self.phones)}'>"
  33. def __str__(self):
  34. return self.text
# Normalization table: CJK / full-width punctuation mapped to the ASCII
# punctuation the symbol set knows about.  All quote and bracket variants
# collapse to a plain apostrophe; dashes and tildes collapse to "-".
SYMBOLS_MAPPING = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "~": "-",
    "~": "-",
    "・": "-",
    "「": "'",
    "」": "'",
    ";": ",",
    ":": ",",
}

# One alternation over every key above; clean_text uses it for a
# single-pass replacement.
REPLACE_SYMBOL_REGEX = re.compile(
    "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
)

# Every (start, end) codepoint range claimed by any supported language.
ALL_KNOWN_UTF8_RANGE = list(
    itertools.chain.from_iterable(language_unicode_range_map.values())
)

# Negated character class: matches any character OUTSIDE all known
# language ranges; clean_text strips such characters entirely.
REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
    "[^"
    + "".join(
        f"{re.escape(chr(start))}-{re.escape(chr(end))}"
        for start, end in ALL_KNOWN_UTF8_RANGE
    )
    + "]"
)
  84. def clean_text(text):
  85. # Clean the text
  86. text = text.strip()
  87. # Replace all chinese symbols with their english counterparts
  88. text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
  89. text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
  90. return text
  91. def parse_text_to_segments(text, order=None):
  92. """
  93. Parse the text and return a list of segments.
  94. :param text: The text to be parsed.
  95. :param order: The order of languages. If None, use ["ZH", "JP", "EN"].
  96. :return: A list of segments.
  97. """
  98. if order is None:
  99. order = ["ZH", "JP", "EN"]
  100. order = [language.upper() for language in order]
  101. assert all(language in language_id_map for language in order)
  102. text = clean_text(text)
  103. texts = re.split(r"(<.*?>)", text)
  104. texts = [text for text in texts if text.strip() != ""]
  105. stack = []
  106. segments = []
  107. for text in texts:
  108. if text.startswith("<") and text.endswith(">") and text[1] != "/":
  109. current_language = text[1:-1]
  110. # The following line should be updated later
  111. assert (
  112. current_language.upper() in language_id_map
  113. ), f"Unknown language: {current_language}"
  114. stack.append(current_language)
  115. elif text.startswith("</") and text.endswith(">"):
  116. language = stack.pop()
  117. if language != text[2:-1]:
  118. raise ValueError(f"Language mismatch: {language} != {text[2:-1]}")
  119. elif stack:
  120. segments.append(Segment(text, stack[-1]))
  121. else:
  122. segments.extend(
  123. [i for i in parse_unknown_segment(text, order) if len(i.phones) > 0]
  124. )
  125. return segments
def parse_unknown_segment(text, order):
    """Yield Segments for untagged *text*, splitting where the detected
    language changes.

    Detection is per character against language_unicode_range_map; *order*
    is the priority for ambiguous characters, with the previously detected
    language always tried first so runs stay together.
    """
    last_idx, last_language = 0, None

    for idx, char in enumerate(text):
        if char == " " or char in punctuation or char in string.digits:
            # If the punctuation / number is in the middle of the text,
            # we should not split the text.
            # (At the very start, before anything was detected, fall back to
            # the highest-priority language in *order*.)
            detected_language = last_language or order[0]
        else:
            detected_language = None
            _order = order.copy()

            if last_language is not None:
                # Prioritize the last language
                _order.remove(last_language)
                _order.insert(0, last_language)

            # First language whose unicode ranges contain this char wins.
            for language in _order:
                for start, end in language_unicode_range_map[language]:
                    if start <= ord(char) <= end:
                        detected_language = language
                        break

                if detected_language is not None:
                    break

        assert (
            detected_language is not None
        ), f"Incorrect language: {char}, clean before calling this function."

        if last_language is None:
            last_language = detected_language

        if detected_language != last_language:
            # Language switched: flush the run accumulated so far.
            yield Segment(text[last_idx:idx], last_language)
            last_idx = idx
            last_language = detected_language

    # Flush the trailing run, if any text remains.
    if last_idx != len(text):
        yield Segment(text[last_idx:], last_language)
  158. def segments_to_phones(
  159. segments: list[Segment],
  160. ) -> tuple[tuple[Optional[str], str], list[int]]:
  161. phones = []
  162. ids = []
  163. for segment in segments:
  164. for phone in segment.phones:
  165. if phone.strip() == "":
  166. continue
  167. q0 = (segment.language, phone)
  168. q1 = (None, phone)
  169. if q0 in symbols_to_id:
  170. phones.append(q0)
  171. ids.append(symbols_to_id[q0])
  172. elif q1 in symbols_to_id:
  173. phones.append(q1)
  174. ids.append(symbols_to_id[q1])
  175. else:
  176. raise ValueError(f"Unknown phone: {segment.language} - {phone} -")
  177. return phones, ids
  178. def g2p(text, order=None):
  179. segments = parse_text_to_segments(text, order=order)
  180. phones, _ = segments_to_phones(segments)
  181. return phones
  182. if __name__ == "__main__":
  183. segments = parse_text_to_segments(
  184. "测试一下 Hugging face, BGM声音很大吗?那我改一下. <jp>世界、こんにちは。</jp>" # noqa: E501
  185. )
  186. print(segments)
  187. segments = parse_text_to_segments(
  188. "测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。" # noqa: E501
  189. )
  190. print(segments)