Procházet zdrojové kódy

Add workaround for english no-g2p decoding

Lengyue před 2 roky
rodič
revize
a0bf975f55
1 změnil soubory, kde provedl 14 přidání a 7 odebrání
  1. 14 7
      tools/llama/generate.py

+ 14 - 7
tools/llama/generate.py

@@ -292,13 +292,20 @@ def encode_tokens(
 
     string = f"[INST] {string} [/INST]"
 
-    tokens = tokenizer.encode(
-        string,
-        max_length=10**6,
-        add_special_tokens=bos,
-        truncation=False,
-    )
-    tokens = torch.tensor([tokens], dtype=torch.int, device=device)
+    # Handle English less frequent words
+    # TODO: update tokenizer to handle this
+    sub_strings = string.split(" ")
+    new_tokens = []
+    for i, string in enumerate(sub_strings):
+        tokens = tokenizer.encode(
+            string,
+            add_special_tokens=i == 0 and bos,
+            max_length=10**6,
+            truncation=False,
+        )
+        new_tokens.extend(tokens)
+
+    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
 
     # Codebooks
     zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)