Fix split_text cutting sentences in the middle at decimal points (#176)

* split_text treated a decimal point as a sentence terminator, so a sentence containing decimal numbers, for example "据海关统计,今年前4个月,我国货物贸易进出口总值13.81万亿元,同比增长5.7%, 其中,出口7.81万亿元,增长4.9%,进口6万亿元,增长6.8%。", was split into two at a decimal point and cut off in the middle.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
duliangang, 1 year ago
Commit ba74994696
1 changed file with 9 additions and 3 deletions

+ 9 - 3
tools/llama/generate.py

@@ -426,12 +426,18 @@ def split_text(text, min_length):
         if curr and not all(c.isspace() or c in string.punctuation for c in curr):
             segments.append(curr)
 
-    for char in text:
+    def is_float(value):
+        try:
+            float(value)
+            return True
+        except ValueError:
+            return False
+
+    for index, char in enumerate(text):
         curr += char
         if char not in [".", "!", "?"]:
             continue
-
-        if len(curr) >= min_length:
+        if len(curr) >= min_length and not is_float(text[index - 1 : index + 2]):
             clean_add(curr)
             curr = ""
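
One thing worth checking in the patched condition: `float()` ignores surrounding whitespace and accepts a trailing dot, so a three-character window such as `"0. "` (a sentence that ends in a digit, followed by a space) also parses as a float and would suppress a legitimate split. The sketch below is not part of this commit; it shows a stricter check that only treats `.` as a decimal point when it has a digit on both sides. The names `is_decimal_point` and `split_sentences` are hypothetical, and `split_sentences` is a simplified stand-in for `split_text`, not the function in `tools/llama/generate.py`.

```python
def is_decimal_point(text, index):
    """Return True if text[index] is a '.' with a digit directly on each side."""
    if text[index] != ".":
        return False
    has_digit_before = index > 0 and text[index - 1].isdigit()
    has_digit_after = index + 1 < len(text) and text[index + 1].isdigit()
    return has_digit_before and has_digit_after


def split_sentences(text, min_length=10):
    """Simplified stand-in for split_text (hypothetical, for illustration only).

    Breaks after '.', '!' or '?' once the current segment is long enough,
    but never at a '.' that sits between two digits.
    """
    segments = []
    curr = ""
    for index, char in enumerate(text):
        curr += char
        if char not in (".", "!", "?"):
            continue
        if len(curr) >= min_length and not is_decimal_point(text, index):
            segments.append(curr)
            curr = ""
    # Keep whatever is left after the last terminator, unless it is only whitespace.
    if curr.strip():
        segments.append(curr)
    return segments


if __name__ == "__main__":
    sample = "Exports grew 5.7% this year. Imports rose in 2020. Done!"
    for segment in split_sentences(sample):
        print(repr(segment))
```

With this check, "5.7%" stays inside its sentence, while "2020." followed by a space still ends one. Whether the stricter rule suits the project depends on the inputs it sees; inputs like ".5" with no leading digit would need separate handling.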