Przeglądaj źródła

Optimize inifinity generation

Lengyue 2 lat temu
rodzic
commit
e7e7a8c2e3
1 zmienionych plików z 2 dodań i 2 usunięć
  1. 2 2
      tools/llama/generate.py

+ 2 - 2
tools/llama/generate.py

@@ -109,7 +109,7 @@ def decode_one_token(
 
     # Disable <s> and </s> tokens for codebooks
     if model.config.num_codebooks != 0:
-        logits.codebook_logits[:, :, :, :2] = -float("Inf")
+        logits.codebook_logits[:, :, :, :1] = -float("Inf")
 
         for i in range(model.config.num_codebooks):
             codebooks.append(
@@ -194,7 +194,7 @@ def decode_n_tokens(
         )
 
         # TODO: use tokenizer's eos
-        if (cur_token[0, 0, -1] == eos_token_id).any():
+        if cur_token[0, 0, -1] == eos_token_id or (cur_token[0, 1:, -1] == 1).any():
             break
 
     return previous_tokens[:, : i + 1]