# spliter.py — sentence/segment splitter for TTS text preprocessing.
  1. import re
  2. import string
  3. from fish_speech.text.clean import clean_text
  4. def utf_8_len(text):
  5. return len(text.encode("utf-8"))
  6. def break_text(texts, length, splits: set):
  7. for text in texts:
  8. if utf_8_len(text) <= length:
  9. yield text
  10. continue
  11. curr = ""
  12. for char in text:
  13. curr += char
  14. if char in splits:
  15. yield curr
  16. curr = ""
  17. if curr:
  18. yield curr
  19. def break_text_by_length(texts, length):
  20. for text in texts:
  21. if utf_8_len(text) <= length:
  22. yield text
  23. continue
  24. curr = ""
  25. for char in text:
  26. curr += char
  27. if utf_8_len(curr) >= length:
  28. yield curr
  29. curr = ""
  30. if curr:
  31. yield curr
  32. def add_cleaned(curr, segments):
  33. curr = curr.strip()
  34. if curr and not all(c.isspace() or c in string.punctuation for c in curr):
  35. segments.append(curr)
  36. def protect_float(text):
  37. # Turns 3.14 into <3_f_14> to prevent splitting
  38. return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
  39. def unprotect_float(text):
  40. # Turns <3_f_14> into 3.14
  41. return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
  42. def split_text(text, length):
  43. text = clean_text(text)
  44. # Break the text into pieces with following rules:
  45. # 1. Split the text at ".", "!", "?" if text is NOT a float
  46. # 2. If the text is longer than length, split at ","
  47. # 3. If the text is still longer than length, split at " "
  48. # 4. If the text is still longer than length, split at any character to length
  49. texts = [text]
  50. texts = map(protect_float, texts)
  51. texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
  52. texts = map(unprotect_float, texts)
  53. texts = break_text(texts, length, {",", ","})
  54. texts = break_text(texts, length, {" "})
  55. texts = list(break_text_by_length(texts, length))
  56. # Then, merge the texts into segments with length <= length
  57. segments = []
  58. curr = ""
  59. for text in texts:
  60. if utf_8_len(curr) + utf_8_len(text) <= length:
  61. curr += text
  62. else:
  63. add_cleaned(curr, segments)
  64. curr = text
  65. if curr:
  66. add_cleaned(curr, segments)
  67. return segments
  68. if __name__ == "__main__":
  69. # Test the split_text function
  70. text = "This is a test sentence. This is another test sentence. And a third one."
  71. assert split_text(text, 50) == [
  72. "This is a test sentence.",
  73. "This is another test sentence. And a third one.",
  74. ]
  75. assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
  76. assert split_text(" ", 10) == []
  77. assert split_text("a", 10) == ["a"]
  78. text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
  79. assert split_text(text, 50) == [
  80. "This is a test sentence with only commas,",
  81. "and no dots, and no exclamation marks,",
  82. "and no question marks, and no newlines.",
  83. ]
  84. text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
  85. # First half split at " ", second half split at ","
  86. assert split_text(text, 50) == [
  87. "This is a test sentence This is a test sentence",
  88. "This is a test sentence. This is a test sentence,",
  89. "This is a test sentence, This is a test sentence.",
  90. ]
  91. text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
  92. assert split_text(text, 50) == [
  93. "这是一段很长的中文文本,",
  94. "而且没有句号,也没有感叹号,",
  95. "也没有问号,也没有换行符.",
  96. ]