Explorar el Código

[Breaking] Add new text-splitter, new length based on bytes

Lengyue hace 1 año
padre
commit
dbd3b18bf2
Se han modificado 6 ficheros con 138 adiciones y 38 borrados
  1. 2 1
      fish_speech/text/__init__.py
  2. 0 1
      fish_speech/text/clean.py
  3. 130 0
      fish_speech/text/spliter.py
  4. 2 2
      tools/api.py
  5. 2 32
      tools/llama/generate.py
  6. 2 2
      tools/webui.py

+ 2 - 1
fish_speech/text/__init__.py

@@ -1,3 +1,4 @@
 from .clean import clean_text
+from .spliter import split_text
 
-__all__ = ["clean_text"]
+__all__ = ["clean_text", "split_text"]

+ 0 - 1
fish_speech/text/clean.py

@@ -1,6 +1,5 @@
 import itertools
 import re
-import string
 
 LANGUAGE_UNICODE_RANGE_MAP = {
     "ZH": [(0x4E00, 0x9FFF)],

+ 130 - 0
fish_speech/text/spliter.py

@@ -0,0 +1,130 @@
+import re
+import string
+
+from fish_speech.text.clean import clean_text
+
+
def utf_8_len(text):
    """Return the size of *text* in UTF-8 bytes (not Unicode characters)."""
    return len(bytes(text, "utf-8"))
+
+
def break_text(texts, length, splits: set):
    """Re-chunk a stream of strings at delimiter characters.

    Pieces whose UTF-8 byte size already fits within ``length`` pass
    through untouched. Longer pieces are cut immediately *after* each
    character found in ``splits``, so delimiters stay attached to the
    chunk they terminate.
    """
    for piece in texts:
        # Small enough already -- emit as-is.
        if len(piece.encode("utf-8")) <= length:
            yield piece
            continue

        buffer = []
        for ch in piece:
            buffer.append(ch)
            if ch in splits:
                yield "".join(buffer)
                buffer = []

        # Whatever trails the final delimiter is still a chunk.
        if buffer:
            yield "".join(buffer)
+
+
def break_text_by_length(texts, length):
    """Hard-cut any piece whose UTF-8 size exceeds *length* bytes.

    Characters accumulate until the running chunk reaches at least
    ``length`` bytes, then the chunk is flushed. A multi-byte character
    may push a chunk slightly past the limit before the flush (same as
    checking the size *after* appending).
    """
    for piece in texts:
        if len(piece.encode("utf-8")) <= length:
            yield piece
            continue

        chunk, size = [], 0
        for ch in piece:
            # Per-character byte counts sum exactly to the chunk's
            # UTF-8 size, so no re-encoding of the whole chunk needed.
            chunk.append(ch)
            size += len(ch.encode("utf-8"))
            if size >= length:
                yield "".join(chunk)
                chunk, size = [], 0

        if chunk:
            yield "".join(chunk)
+
+
def add_cleaned(curr, segments):
    """Strip *curr* and append it to *segments* in place.

    Segments that are empty, or made up entirely of whitespace and
    ASCII punctuation, are dropped.
    NOTE(review): ``string.punctuation`` is ASCII-only, so a segment of
    purely full-width/CJK punctuation would still be kept -- confirm
    this is intended.
    """
    stripped = curr.strip()
    # De Morgan of the original "not all(space-or-punct)" check.
    has_content = any(
        not (c.isspace() or c in string.punctuation) for c in stripped
    )
    if stripped and has_content:
        segments.append(stripped)
+
+
def protect_float(text):
    """Mask decimal numbers so sentence splitting cannot cut at the dot.

    ``3.14`` becomes ``<3_f_14>``; reversed by :func:`unprotect_float`.
    """
    return re.sub(r"(?P<int>\d+)\.(?P<frac>\d+)", r"<\g<int>_f_\g<frac>>", text)
+
+
def unprotect_float(text):
    """Restore floats masked by :func:`protect_float` (``<3_f_14>`` -> ``3.14``)."""
    return re.sub(r"<(?P<int>\d+)_f_(?P<frac>\d+)>", r"\g<int>.\g<frac>", text)
+
+
def split_text(text, length):
    """Split *text* into segments of at most *length* UTF-8 bytes each.

    The text is cleaned first, then cut at progressively weaker
    boundaries -- sentence enders, commas, spaces -- and only hard-cut
    mid-word when nothing else keeps a piece under the byte budget.
    Finally, adjacent pieces are greedily merged back together while
    they still fit.
    """
    text = clean_text(text)

    # Stage 1: sentence boundaries. Floats are masked beforehand so
    # the "." inside e.g. 3.14 is not treated as a sentence end.
    pieces = map(protect_float, [text])
    pieces = break_text(pieces, length, {".", "!", "?"})
    pieces = map(unprotect_float, pieces)

    # Stages 2-4: commas, then spaces, then hard byte-length cuts,
    # each applied only to pieces still over the budget.
    pieces = break_text(pieces, length, {","})
    pieces = break_text(pieces, length, {" "})
    pieces = break_text_by_length(pieces, length)

    # Greedily merge consecutive pieces back up to the byte budget.
    segments = []
    buffer = ""
    for piece in pieces:
        if utf_8_len(buffer) + utf_8_len(piece) <= length:
            buffer += piece
        else:
            add_cleaned(buffer, segments)
            buffer = piece

    # Flush the trailing buffer (add_cleaned ignores empty input).
    add_cleaned(buffer, segments)

    return segments
+
+
if __name__ == "__main__":
    # Lightweight smoke tests for split_text; run this module directly
    # to execute them. Budgets below are in UTF-8 bytes, not characters.

    text = "This is a test sentence. This is another test sentence. And a third one."

    # Sentences merge greedily back up to the 50-byte budget.
    assert split_text(text, 50) == [
        "This is a test sentence.",
        "This is another test sentence. And a third one.",
    ]
    # The float 3.14 must survive intact (its "." is protected from
    # the sentence splitter).
    assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
    # Whitespace-only input produces no segments at all.
    assert split_text("   ", 10) == []
    assert split_text("a", 10) == ["a"]

    text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
    # With no sentence enders, commas become the split points.
    assert split_text(text, 50) == [
        "This is a test sentence with only commas,",
        "and no dots, and no exclamation marks,",
        "and no question marks, and no newlines.",
    ]

    text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
    # First half split at " ", second half split at ","
    assert split_text(text, 50) == [
        "This is a test sentence This is a test sentence",
        "This is a test sentence. This is a test sentence,",
        "This is a test sentence, This is a test sentence.",
    ]

    # NOTE(review): the expected output ends in "." while the input ends
    # in "。" -- presumably clean_text normalizes full-width punctuation
    # to ASCII; verify against fish_speech/text/clean.py.
    text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
    assert split_text(text, 50) == [
        "这是一段很长的中文文本,",
        "而且没有句号,也没有感叹号,",
        "也没有问号,也没有换行符.",
    ]

+ 2 - 2
tools/api.py

@@ -168,7 +168,7 @@ class InvokeRequest(BaseModel):
     reference_text: Optional[str] = None
     reference_audio: Optional[str] = None
     max_new_tokens: int = 0
-    chunk_length: Annotated[int, Field(ge=0, le=200, strict=True)] = 30
+    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 150
     top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
     repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.5
     temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
@@ -399,7 +399,7 @@ if __name__ == "__main__":
                 reference_text=None,
                 reference_audio=None,
                 max_new_tokens=0,
-                chunk_length=30,
+                chunk_length=150,
                 top_p=0.7,
                 repetition_penalty=1.5,
                 temperature=0.7,

+ 2 - 32
tools/llama/generate.py

@@ -20,7 +20,7 @@ from tqdm import tqdm
 from transformers import AutoTokenizer
 
 from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
-from fish_speech.text.clean import clean_text
+from fish_speech.text import clean_text, split_text
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
@@ -416,36 +416,6 @@ def load_model(
     return model.eval(), decode_one_token
 
 
-def split_text(text, min_length):
-    text = clean_text(text)
-    segments = []
-    curr = ""
-
-    def clean_add(curr):
-        curr = curr.strip()
-        if curr and not all(c.isspace() or c in string.punctuation for c in curr):
-            segments.append(curr)
-
-    def is_float(value):
-        try:
-            float(value)
-            return True
-        except ValueError:
-            return False
-
-    for index, char in enumerate(text):
-        curr += char
-        if char not in [".", "!", "?"]:
-            continue
-        if len(curr) >= min_length and not is_float(text[index - 1 : index + 2]):
-            clean_add(curr)
-            curr = ""
-
-    clean_add(curr)
-
-    return segments
-
-
 @dataclass
 class GenerateResponse:
     action: Literal["sample", "next"]
@@ -468,7 +438,7 @@ def generate_long(
     compile: bool = False,
     iterative_prompt: bool = True,
     max_length: int = 2048,
-    chunk_length: int = 30,
+    chunk_length: int = 150,
     speaker: Optional[str] = None,
     prompt_text: Optional[str] = None,
     prompt_tokens: Optional[torch.Tensor] = None,

+ 2 - 2
tools/webui.py

@@ -263,7 +263,7 @@ def build_app():
                             label=i18n("Iterative Prompt Length, 0 means off"),
                             minimum=0,
                             maximum=500,
-                            value=30,
+                            value=150,
                             step=8,
                         )
 
@@ -461,7 +461,7 @@ if __name__ == "__main__":
             reference_audio=None,
             reference_text="",
             max_new_tokens=0,
-            chunk_length=0,
+            chunk_length=150,
             top_p=0.7,
             repetition_penalty=1.5,
             temperature=0.7,