```diff
@@ -163,7 +163,7 @@ def decode_n_tokens(
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
-        (model.config.num_codebooks + 1, num_new_tokens),
+        (model.config.num_codebooks + 1, model.config.max_seq_len),
         dtype=torch.int,
         device=cur_token.device,
     )
```