# parser.py
  1. import itertools
  2. import re
  3. import string
  4. from typing import Optional
  5. from fish_speech.text.chinese import g2p as g2p_chinese
  6. from fish_speech.text.english import g2p as g2p_english
  7. from fish_speech.text.japanese import g2p as g2p_japanese
  8. from fish_speech.text.symbols import (
  9. language_id_map,
  10. language_unicode_range_map,
  11. punctuation,
  12. symbols_to_id,
  13. )
# Dispatch table: upper-case language code -> that language's g2p
# (grapheme-to-phoneme) function.
LANGUAGE_TO_MODULE_MAP = {
    "ZH": g2p_chinese,
    "EN": g2p_english,
    "JP": g2p_japanese,
}

# This file is designed to parse the text, doing some normalization,
# and return an annotated text.
# Example: 1, 2, <JP>3</JP>, 4, <ZH>5</ZH>
# For better compatibility, we also support tree-like structure, like:
# 1, 2, <JP>3, <EN>4</EN></JP>, 5, <ZH>6</ZH>
  24. class Segment:
  25. def __init__(self, text, language=None, phones=None):
  26. self.text = text
  27. self.language = language.upper() if language is not None else None
  28. if phones is None:
  29. phones = LANGUAGE_TO_MODULE_MAP[self.language](self.text)
  30. self.phones = phones
  31. def __repr__(self):
  32. return f"<Segment {self.language}: '{self.text}' -> '{' '.join(self.phones)}'>"
  33. def __str__(self):
  34. return self.text
# Normalization table: CJK / full-width punctuation (and a few special
# symbols) mapped onto plain ASCII counterparts.  NOTE: key order matters —
# REPLACE_SYMBOL_REGEX below builds its alternation in this order, so "..."
# must stay ahead of any key it could be a prefix of.
SYMBOLS_MAPPING = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "~": "-",
    "~": "-",
    "「": "'",
    "」": "'",
    ";": ",",
    ":": ",",
}
  69. REPLACE_SYMBOL_REGEX = re.compile(
  70. "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
  71. )
  72. ALL_KNOWN_UTF8_RANGE = list(
  73. itertools.chain.from_iterable(language_unicode_range_map.values())
  74. )
  75. REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
  76. "[^"
  77. + "".join(
  78. f"{re.escape(chr(start))}-{re.escape(chr(end))}"
  79. for start, end in ALL_KNOWN_UTF8_RANGE
  80. )
  81. + "]"
  82. )
  83. def parse_text_to_segments(text, order=None):
  84. """
  85. Parse the text and return a list of segments.
  86. :param text: The text to be parsed.
  87. :param order: The order of languages. If None, use ["ZH", "JP", "EN"].
  88. :return: A list of segments.
  89. """
  90. if order is None:
  91. order = ["ZH", "JP", "EN"]
  92. order = [language.upper() for language in order]
  93. assert all(language in language_id_map for language in order)
  94. # Clean the text
  95. text = text.strip()
  96. # Replace all chinese symbols with their english counterparts
  97. text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
  98. text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
  99. texts = re.split(r"(<.*?>)", text)
  100. texts = [text for text in texts if text.strip() != ""]
  101. stack = []
  102. segments = []
  103. for text in texts:
  104. if text.startswith("<") and text.endswith(">") and text[1] != "/":
  105. current_language = text[1:-1]
  106. # The following line should be updated later
  107. assert current_language.upper() in language_id_map
  108. stack.append(current_language)
  109. elif text.startswith("</") and text.endswith(">"):
  110. language = stack.pop()
  111. if language != text[2:-1]:
  112. raise ValueError(f"Language mismatch: {language} != {text[2:-1]}")
  113. elif stack:
  114. segments.append(Segment(text, stack[-1]))
  115. else:
  116. segments.extend(parse_unknown_segment(text, order))
  117. return segments
  118. def parse_unknown_segment(text, order):
  119. last_idx, last_language = 0, None
  120. for idx, char in enumerate(text):
  121. if char in punctuation or char in string.digits:
  122. # If the punctuation / number is in the middle of the text,
  123. # we should not split the text.
  124. detected_language = last_language or order[0]
  125. else:
  126. detected_language = None
  127. _order = order.copy()
  128. if last_language is not None:
  129. # Prioritize the last language
  130. _order.remove(last_language)
  131. _order.insert(0, last_language)
  132. for language in _order:
  133. for start, end in language_unicode_range_map[language]:
  134. if start <= ord(char) <= end:
  135. detected_language = language
  136. break
  137. if detected_language is not None:
  138. break
  139. assert (
  140. detected_language is not None
  141. ), f"Incorrect language: {char}, clean before calling this function."
  142. if last_language is None:
  143. last_language = detected_language
  144. if detected_language != last_language:
  145. yield Segment(text[last_idx:idx], last_language)
  146. last_idx = idx
  147. last_language = detected_language
  148. if last_idx != len(text):
  149. yield Segment(text[last_idx:], last_language)
  150. def segments_to_phones(
  151. segments: list[Segment],
  152. ) -> tuple[tuple[Optional[str], str], list[int]]:
  153. phones = []
  154. ids = []
  155. for segment in segments:
  156. for phone in segment.phones:
  157. q0 = (segment.language, phone)
  158. q1 = (None, phone)
  159. if q0 in symbols_to_id:
  160. phones.append(q0)
  161. ids.append(symbols_to_id[q0])
  162. elif q1 in symbols_to_id:
  163. phones.append(q1)
  164. ids.append(symbols_to_id[q1])
  165. else:
  166. raise ValueError(f"Unknown phone: {segment.language} {phone}")
  167. return phones, ids
  168. def g2p(text, order=None):
  169. segments = parse_text_to_segments(text, order=order)
  170. _, phones = segments_to_phones(segments)
  171. return phones
  172. if __name__ == "__main__":
  173. segments = parse_text_to_segments(
  174. "毕业然后复活卡b站推荐bug<zh>加流量。<en>Hugging face, B GM</en>声音很大吗</zh>?那我改一下Ё。 <jp>君の虜になってしまえばきっと</jp>" # noqa: E501
  175. )
  176. print(segments)
  177. segments = parse_text_to_segments(
  178. "毕业然后复活卡b站推荐bug加流量。Hugging face, BGM 声音很大吗?那我改一下Ё。君の虜になってしまえばきっと" # noqa: E501
  179. )
  180. print(segments)