|
@@ -241,7 +241,9 @@ def generate(
|
|
|
|
|
|
|
|
codebook_dim = 1 + model.config.num_codebooks
|
|
codebook_dim = 1 + model.config.num_codebooks
|
|
|
# create an empty tensor of the expected final shape and fill in the current tokens
|
|
# create an empty tensor of the expected final shape and fill in the current tokens
|
|
|
- empty = torch.empty((codebook_dim, max_new_tokens), dtype=dtype, device=device)
|
|
|
|
|
|
|
+ empty = torch.empty(
|
|
|
|
|
+ (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
|
|
|
|
|
+ )
|
|
|
empty[:, :T] = prompt
|
|
empty[:, :T] = prompt
|
|
|
seq = empty
|
|
seq = empty
|
|
|
input_pos = torch.arange(0, T, device=device)
|
|
input_pos = torch.arange(0, T, device=device)
|