@@ -98,21 +98,34 @@ def decode_one_token(
 ) -> torch.Tensor:
     assert input_pos.shape[-1] == 1
 
-    logits = model.forward_generate(x, input_pos)
+    x, logits = model.forward_generate_slow(x, input_pos)
     codebooks = [
         sample(
-            logits.token_logits,
+            logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
             **sampling_kwargs,
         )[0]
     ]
 
-    # Disable <s> and </s> tokens for codebooks
-    if model.config.num_codebooks != 0:
-        for i in range(model.config.num_codebooks):
-            codebooks.append(
-                torch.argmax(logits.codebook_logits[:, :, i], dim=-1).view(1)
-            )
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    for codebook_idx in range(model.config.num_codebooks):
+        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
+        logits = model.forward_generate_fast(x, input_pos)
+        a = sample(
+            logits,
+            previous_tokens=(
+                previous_tokens[codebook_idx + 1]
+                if previous_tokens is not None
+                else None
+            ),
+            **sampling_kwargs,
+        )[0]
+        x = model.fast_embeddings(a)
+        codebooks.append(a)
 
     return torch.stack(codebooks, dim=0)
@@ -121,20 +134,32 @@ def prefill(
     model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
 ) -> torch.Tensor:
     # input_pos: [B, S]
-    logits = model.forward_generate(x, input_pos)
+    x, logits = model.forward_generate_slow(x, input_pos)
+    print("---", x.shape, logits.shape)
     codebooks = [
         sample(
-            logits.token_logits,
+            logits,
             previous_tokens=None,
             **sampling_kwargs,
         )[0]
     ]
 
-    if model.config.num_codebooks != 0:
-        for i in range(model.config.num_codebooks):
-            codebooks.append(
-                torch.argmax(logits.codebook_logits[:, :, i], dim=-1).view(1)
-            )
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    for codebook_idx in range(model.config.num_codebooks):
+        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
+        logits = model.forward_generate_fast(x, input_pos)
+        # print(x.shape, logits.shape)
+        a = sample(
+            logits,
+            previous_tokens=None,
+            **sampling_kwargs,
+        )[0]
+        x = model.fast_embeddings(a)
+        codebooks.append(a)
 
     return torch.stack(codebooks, dim=0)
@@ -317,7 +342,10 @@ def encode_tokens(
 
     # Since 1.0, we use <s:xxx> to replace <semantic>
     main_token_ids = torch.tensor(
-        [[tokenizer.pad_token_id] * data.size(1)], dtype=torch.int, device=device
+        # TODO: replace this
+        [[tokenizer.pad_token_id] * data.size(1)],
+        dtype=torch.int,
+        device=device,
     )
 
     data = torch.cat((main_token_ids, data), dim=0)
@@ -397,7 +425,7 @@ def split_text(text, min_length):
 @click.option("--max-new-tokens", type=int, default=0)
 @click.option("--top-k", type=int, default=None)
 @click.option("--top-p", type=float, default=0.5)
-@click.option("--repetition-penalty", type=float, default=1.5)
+@click.option("--repetition-penalty", type=float, default=1.2)
 @click.option("--temperature", type=float, default=0.7)
 @click.option(
     "--checkpoint-path",
@@ -544,14 +572,14 @@ def main(
         # Put the generated tokens
         codes = y[1:, prompt_length:-1].clone()
 
-        if getattr(cfg, "use_delay_pattern", True):
-            new_codes = []
-            for j, code in enumerate(codes):
-                new_codes.append(
-                    code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
-                )
+        # if getattr(cfg, "use_delay_pattern", True):
+        #     new_codes = []
+        #     for j, code in enumerate(codes):
+        #         new_codes.append(
+        #             code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
+        #         )
 
-            codes = torch.stack(new_codes, dim=0)
+        # codes = torch.stack(new_codes, dim=0)
 
         codes = codes - 2
         if not (codes >= 0).all():