# parser.py
  1. import itertools
  2. import re
  3. import string
  4. from typing import Optional
  5. from loguru import logger
  6. from fish_speech.text.chinese import g2p as g2p_chinese
  7. from fish_speech.text.english import g2p as g2p_english
  8. from fish_speech.text.japanese import g2p as g2p_japanese
  9. from fish_speech.text.symbols import (
  10. language_id_map,
  11. language_unicode_range_map,
  12. punctuation,
  13. symbols_to_id,
  14. )
  15. LANGUAGE_TO_MODULE_MAP = {
  16. "ZH": g2p_chinese,
  17. "EN": g2p_english,
  18. "JP": g2p_japanese,
  19. }
# This file is designed to parse the text, do some normalization,
# and return an annotated text.
# Example: 1, 2, <JP>3</JP>, 4, <ZH>5</ZH>
# For better compatibility, we also support a tree-like structure, like:
# 1, 2, <JP>3, <EN>4</EN></JP>, 5, <ZH>6</ZH>
  25. class Segment:
  26. def __init__(self, text, language=None, phones=None):
  27. self.text = text
  28. self.language = language.upper() if language is not None else None
  29. if phones is None:
  30. phones = LANGUAGE_TO_MODULE_MAP[self.language](self.text)
  31. self.phones = phones
  32. def __repr__(self):
  33. return f"<Segment {self.language}: '{self.text}' -> '{' '.join(self.phones)}'>"
  34. def __str__(self):
  35. return self.text
  36. SYMBOLS_MAPPING = {
  37. ":": ",",
  38. ";": ",",
  39. ",": ",",
  40. "。": ".",
  41. "!": "!",
  42. "?": "?",
  43. "\n": ".",
  44. "·": ",",
  45. "、": ",",
  46. "...": "…",
  47. "$": ".",
  48. "“": "'",
  49. "”": "'",
  50. "‘": "'",
  51. "’": "'",
  52. "(": "'",
  53. ")": "'",
  54. "(": "'",
  55. ")": "'",
  56. "《": "'",
  57. "》": "'",
  58. "【": "'",
  59. "】": "'",
  60. "[": "'",
  61. "]": "'",
  62. "—": "-",
  63. "~": "-",
  64. "~": "-",
  65. "・": "-",
  66. "「": "'",
  67. "」": "'",
  68. ";": ",",
  69. ":": ",",
  70. }
  71. REPLACE_SYMBOL_REGEX = re.compile(
  72. "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
  73. )
  74. ALL_KNOWN_UTF8_RANGE = list(
  75. itertools.chain.from_iterable(language_unicode_range_map.values())
  76. )
  77. REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
  78. "[^"
  79. + "".join(
  80. f"{re.escape(chr(start))}-{re.escape(chr(end))}"
  81. for start, end in ALL_KNOWN_UTF8_RANGE
  82. )
  83. + "]"
  84. )
  85. def clean_text(text):
  86. # Clean the text
  87. text = text.strip()
  88. # Replace <p:(.*?)> with <PPP(.*?)PPP>
  89. text = re.sub(r"<p:(.*?)>", r"<PPP\1PPP>", text)
  90. # Replace all chinese symbols with their english counterparts
  91. text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
  92. text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
  93. # Replace <PPP(.*?)PPP> with <p:(.*?)>
  94. text = re.sub(r"<PPP(.*?)PPP>", r"<p:\1>", text)
  95. return text
  96. def parse_text_to_segments(text, order=None):
  97. """
  98. Parse the text and return a list of segments.
  99. :param text: The text to be parsed.
  100. :param order: The order of languages. If None, use ["ZH", "JP", "EN"].
  101. :return: A list of segments.
  102. """
  103. if order is None:
  104. order = ["ZH", "JP", "EN"]
  105. order = [language.upper() for language in order]
  106. assert all(language in language_id_map for language in order)
  107. text = clean_text(text)
  108. texts = re.split(r"(<.*?>)", text)
  109. texts = [text for text in texts if text.strip() != ""]
  110. stack = []
  111. segments = []
  112. for text in texts:
  113. if text.startswith("<") and text.endswith(">") and text[1] != "/":
  114. current_language = text[1:-1]
  115. # The following line should be updated later
  116. assert (
  117. current_language.upper() in language_id_map
  118. ), f"Unknown language: {current_language}"
  119. stack.append(current_language)
  120. elif text.startswith("</") and text.endswith(">"):
  121. language = stack.pop()
  122. if language != text[2:-1]:
  123. raise ValueError(f"Language mismatch: {language} != {text[2:-1]}")
  124. elif stack:
  125. segments.append(Segment(text, stack[-1]))
  126. else:
  127. segments.extend(
  128. [i for i in parse_unknown_segment(text, order) if len(i.phones) > 0]
  129. )
  130. return segments
def parse_unknown_segment(text, order):
    # Split untagged text into runs of a single language, yielding one
    # Segment per run.  `order` gives the priority of languages when a
    # character could belong to several (first entry wins).
    last_idx, last_language = 0, None
    for idx, char in enumerate(text):
        if char == " " or char in punctuation or char in string.digits:
            # If the punctuation / number is in the middle of the text,
            # we should not split the text — attribute it to the current
            # run's language (or the top-priority language at the start).
            detected_language = last_language or order[0]
        else:
            detected_language = None
            _order = order.copy()

            if last_language is not None:
                # Prioritize the last language
                _order.remove(last_language)
                _order.insert(0, last_language)

            # First language whose unicode ranges contain this char wins.
            for language in _order:
                for start, end in language_unicode_range_map[language]:
                    if start <= ord(char) <= end:
                        detected_language = language
                        break

                if detected_language is not None:
                    break

        assert (
            detected_language is not None
        ), f"Incorrect language: {char}, clean before calling this function."

        if last_language is None:
            last_language = detected_language

        if detected_language != last_language:
            # Language changed: flush the run accumulated so far.
            yield Segment(text[last_idx:idx], last_language)
            last_idx = idx
            last_language = detected_language

    # Flush the trailing run (skipped only when text was empty).
    if last_idx != len(text):
        yield Segment(text[last_idx:], last_language)
  163. def segments_to_phones(
  164. segments: list[Segment],
  165. ) -> tuple[tuple[Optional[str], str], list[int]]:
  166. phones = []
  167. ids = []
  168. for segment in segments:
  169. for phone in segment.phones:
  170. if phone.strip() == "":
  171. continue
  172. q0 = (segment.language, phone)
  173. q1 = (None, phone)
  174. if q0 in symbols_to_id:
  175. phones.append(q0)
  176. ids.append(symbols_to_id[q0])
  177. elif q1 in symbols_to_id:
  178. phones.append(q1)
  179. ids.append(symbols_to_id[q1])
  180. else:
  181. logger.warning(
  182. f"Unknown phone: {segment.language}: `{phone}`, ignored."
  183. )
  184. return phones, ids
  185. def g2p(text, order=None):
  186. segments = parse_text_to_segments(text, order=order)
  187. phones, _ = segments_to_phones(segments)
  188. return phones
  189. if __name__ == "__main__":
  190. segments = parse_text_to_segments(
  191. "测试一下 Hugging face, BGM声音很大吗?那我改一下. <jp>世界、こんにちは。</jp>" # noqa: E501
  192. )
  193. print(segments)
  194. segments = parse_text_to_segments(
  195. "测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。" # noqa: E501
  196. )
  197. print(segments)
  198. print(clean_text("测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。<p:123> <p:aH>"))