|
|
@@ -109,7 +109,7 @@ def decode_one_token(
|
|
|
|
|
|
# Disable <s> and </s> tokens for codebooks
|
|
|
if model.config.num_codebooks != 0:
|
|
|
- logits.codebook_logits[:, :, :, :2] = -float("Inf")
|
|
|
+ logits.codebook_logits[:, :, :, :1] = -float("Inf")
|
|
|
|
|
|
for i in range(model.config.num_codebooks):
|
|
|
codebooks.append(
|
|
|
@@ -194,7 +194,7 @@ def decode_n_tokens(
|
|
|
)
|
|
|
|
|
|
# TODO: use tokenizer's eos
|
|
|
- if (cur_token[0, 0, -1] == eos_token_id).any():
|
|
|
+ if cur_token[0, 0, -1] == eos_token_id or (cur_token[0, 1:, -1] == 1).any():
|
|
|
break
|
|
|
|
|
|
return previous_tokens[:, : i + 1]
|