parser.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. import itertools
  2. import re
  3. import string
  4. from typing import Optional
  5. from fish_speech.text.chinese import g2p as g2p_chinese
  6. from fish_speech.text.english import g2p as g2p_english
  7. from fish_speech.text.japanese import g2p as g2p_japanese
  8. from fish_speech.text.symbols import (
  9. language_id_map,
  10. language_unicode_range_map,
  11. punctuation,
  12. symbols_to_id,
  13. )
# Dispatch table: upper-cased language code -> that language's g2p function.
# Used by Segment.__init__ when no precomputed phones are supplied.
LANGUAGE_TO_MODULE_MAP = {
    "ZH": g2p_chinese,
    "EN": g2p_english,
    "JP": g2p_japanese,
}
# This file is designed to parse the text, do some normalization,
# and return an annotated text.
# Example: 1, 2, <JP>3</JP>, 4, <ZH>5</ZH>
# For better compatibility, we also support a tree-like structure, like:
# 1, 2, <JP>3, <EN>4</EN></JP>, 5, <ZH>6</ZH>
  24. class Segment:
  25. def __init__(self, text, language=None, phones=None):
  26. self.text = text
  27. self.language = language.upper() if language is not None else None
  28. if phones is None:
  29. phones = LANGUAGE_TO_MODULE_MAP[self.language](self.text)
  30. self.phones = phones
  31. def __repr__(self):
  32. return f"<Segment {self.language}: '{self.text}' -> '{' '.join(self.phones)}'>"
  33. def __str__(self):
  34. return self.text
# Punctuation normalization: map fullwidth / CJK symbols onto ASCII
# counterparts (or the quote/dash family they most resemble).
SYMBOLS_MAPPING = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "~": "-",
    "~": "-",
    "・": "-",
    "「": "'",
    "」": "'",
    ";": ",",
    ":": ",",
}

# Matches any single key of SYMBOLS_MAPPING (alternation of escaped literals).
REPLACE_SYMBOL_REGEX = re.compile(
    "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
)

# Flattened list of (start, end) codepoint ranges over all supported languages.
ALL_KNOWN_UTF8_RANGE = list(
    itertools.chain.from_iterable(language_unicode_range_map.values())
)

# Matches any character that falls outside every known language range;
# such characters are stripped by clean_text.
REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
    "[^"
    + "".join(
        f"{re.escape(chr(start))}-{re.escape(chr(end))}"
        for start, end in ALL_KNOWN_UTF8_RANGE
    )
    + "]"
)
  84. def clean_text(text):
  85. # Clean the text
  86. text = text.strip()
  87. # Replace <p:(.*?)> with <PPP(.*?)PPP>
  88. text = re.sub(r"<p:(.*?)>", r"<PPP\1PPP>", text)
  89. # Replace all chinese symbols with their english counterparts
  90. text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
  91. text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
  92. # Replace <PPP(.*?)PPP> with <p:(.*?)>
  93. text = re.sub(r"<PPP(.*?)PPP>", r"<p:\1>", text)
  94. return text
  95. def parse_text_to_segments(text, order=None):
  96. """
  97. Parse the text and return a list of segments.
  98. :param text: The text to be parsed.
  99. :param order: The order of languages. If None, use ["ZH", "JP", "EN"].
  100. :return: A list of segments.
  101. """
  102. if order is None:
  103. order = ["ZH", "JP", "EN"]
  104. order = [language.upper() for language in order]
  105. assert all(language in language_id_map for language in order)
  106. text = clean_text(text)
  107. texts = re.split(r"(<.*?>)", text)
  108. texts = [text for text in texts if text.strip() != ""]
  109. stack = []
  110. segments = []
  111. for text in texts:
  112. if text.startswith("<") and text.endswith(">") and text[1] != "/":
  113. current_language = text[1:-1]
  114. # The following line should be updated later
  115. assert (
  116. current_language.upper() in language_id_map
  117. ), f"Unknown language: {current_language}"
  118. stack.append(current_language)
  119. elif text.startswith("</") and text.endswith(">"):
  120. language = stack.pop()
  121. if language != text[2:-1]:
  122. raise ValueError(f"Language mismatch: {language} != {text[2:-1]}")
  123. elif stack:
  124. segments.append(Segment(text, stack[-1]))
  125. else:
  126. segments.extend(
  127. [i for i in parse_unknown_segment(text, order) if len(i.phones) > 0]
  128. )
  129. return segments
  130. def parse_unknown_segment(text, order):
  131. last_idx, last_language = 0, None
  132. for idx, char in enumerate(text):
  133. if char == " " or char in punctuation or char in string.digits:
  134. # If the punctuation / number is in the middle of the text,
  135. # we should not split the text.
  136. detected_language = last_language or order[0]
  137. else:
  138. detected_language = None
  139. _order = order.copy()
  140. if last_language is not None:
  141. # Prioritize the last language
  142. _order.remove(last_language)
  143. _order.insert(0, last_language)
  144. for language in _order:
  145. for start, end in language_unicode_range_map[language]:
  146. if start <= ord(char) <= end:
  147. detected_language = language
  148. break
  149. if detected_language is not None:
  150. break
  151. assert (
  152. detected_language is not None
  153. ), f"Incorrect language: {char}, clean before calling this function."
  154. if last_language is None:
  155. last_language = detected_language
  156. if detected_language != last_language:
  157. yield Segment(text[last_idx:idx], last_language)
  158. last_idx = idx
  159. last_language = detected_language
  160. if last_idx != len(text):
  161. yield Segment(text[last_idx:], last_language)
  162. def segments_to_phones(
  163. segments: list[Segment],
  164. ) -> tuple[tuple[Optional[str], str], list[int]]:
  165. phones = []
  166. ids = []
  167. for segment in segments:
  168. for phone in segment.phones:
  169. if phone.strip() == "":
  170. continue
  171. q0 = (segment.language, phone)
  172. q1 = (None, phone)
  173. if q0 in symbols_to_id:
  174. phones.append(q0)
  175. ids.append(symbols_to_id[q0])
  176. elif q1 in symbols_to_id:
  177. phones.append(q1)
  178. ids.append(symbols_to_id[q1])
  179. else:
  180. raise ValueError(f"Unknown phone: {segment.language} - {phone} -")
  181. return phones, ids
  182. def g2p(text, order=None):
  183. segments = parse_text_to_segments(text, order=order)
  184. phones, _ = segments_to_phones(segments)
  185. return phones
  186. if __name__ == "__main__":
  187. segments = parse_text_to_segments(
  188. "测试一下 Hugging face, BGM声音很大吗?那我改一下. <jp>世界、こんにちは。</jp>" # noqa: E501
  189. )
  190. print(segments)
  191. segments = parse_text_to_segments(
  192. "测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。" # noqa: E501
  193. )
  194. print(segments)
  195. print(clean_text("测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。<p:123> <p:aH>"))