|
@@ -112,9 +112,11 @@ def decode_one_token(
|
|
|
layer.attention.kv_cache.k_cache.fill_(0)
|
|
layer.attention.kv_cache.k_cache.fill_(0)
|
|
|
layer.attention.kv_cache.v_cache.fill_(0)
|
|
layer.attention.kv_cache.v_cache.fill_(0)
|
|
|
|
|
|
|
|
|
|
+ buffer = [x.view(1, 1, -1)]
|
|
|
for codebook_idx in range(model.config.num_codebooks):
|
|
for codebook_idx in range(model.config.num_codebooks):
|
|
|
input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
|
|
input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
|
|
|
logits = model.forward_generate_fast(x, input_pos)
|
|
logits = model.forward_generate_fast(x, input_pos)
|
|
|
|
|
+ # print(x.shape, logits.shape)
|
|
|
a = sample(
|
|
a = sample(
|
|
|
logits,
|
|
logits,
|
|
|
previous_tokens=(
|
|
previous_tokens=(
|
|
@@ -126,6 +128,20 @@ def decode_one_token(
|
|
|
)[0]
|
|
)[0]
|
|
|
x = model.fast_embeddings(a)
|
|
x = model.fast_embeddings(a)
|
|
|
codebooks.append(a)
|
|
codebooks.append(a)
|
|
|
|
|
+ # x = torch.cat(buffer, dim=1)
|
|
|
|
|
+ # logits = model.forward_fast(x)[:, -1:, :]
|
|
|
|
|
+ # a = sample(
|
|
|
|
|
+ # logits,
|
|
|
|
|
+ # previous_tokens=(
|
|
|
|
|
+ # previous_tokens[codebook_idx + 1]
|
|
|
|
|
+ # if previous_tokens is not None
|
|
|
|
|
+ # else None
|
|
|
|
|
+ # ),
|
|
|
|
|
+ # **sampling_kwargs,
|
|
|
|
|
+ # )[0]
|
|
|
|
|
+ # x = model.fast_embeddings(a)
|
|
|
|
|
+ # codebooks.append(a)
|
|
|
|
|
+ # buffer.append(x.view(1, 1, -1))
|
|
|
|
|
|
|
|
return torch.stack(codebooks, dim=0)
|
|
return torch.stack(codebooks, dim=0)
|
|
|
|
|
|
|
@@ -135,7 +151,7 @@ def prefill(
|
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
|
# input_pos: [B, S]
|
|
# input_pos: [B, S]
|
|
|
x, logits = model.forward_generate_slow(x, input_pos)
|
|
x, logits = model.forward_generate_slow(x, input_pos)
|
|
|
- print("---", x.shape, logits.shape)
|
|
|
|
|
|
|
+
|
|
|
codebooks = [
|
|
codebooks = [
|
|
|
sample(
|
|
sample(
|
|
|
logits,
|
|
logits,
|
|
@@ -149,6 +165,7 @@ def prefill(
|
|
|
layer.attention.kv_cache.k_cache.fill_(0)
|
|
layer.attention.kv_cache.k_cache.fill_(0)
|
|
|
layer.attention.kv_cache.v_cache.fill_(0)
|
|
layer.attention.kv_cache.v_cache.fill_(0)
|
|
|
|
|
|
|
|
|
|
+ buffer = [x.view(1, 1, -1)]
|
|
|
for codebook_idx in range(model.config.num_codebooks):
|
|
for codebook_idx in range(model.config.num_codebooks):
|
|
|
input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
|
|
input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
|
|
|
logits = model.forward_generate_fast(x, input_pos)
|
|
logits = model.forward_generate_fast(x, input_pos)
|
|
@@ -160,6 +177,15 @@ def prefill(
|
|
|
)[0]
|
|
)[0]
|
|
|
x = model.fast_embeddings(a)
|
|
x = model.fast_embeddings(a)
|
|
|
codebooks.append(a)
|
|
codebooks.append(a)
|
|
|
|
|
+ # x = torch.cat(buffer, dim=1)
|
|
|
|
|
+ # logits = model.forward_fast(x)[:, -1:, :]
|
|
|
|
|
+ # a = sample(
|
|
|
|
|
+ # logits,
|
|
|
|
|
+ # **sampling_kwargs,
|
|
|
|
|
+ # )[0]
|
|
|
|
|
+ # x = model.fast_embeddings(a)
|
|
|
|
|
+ # codebooks.append(a)
|
|
|
|
|
+ # buffer.append(x.view(1, 1, -1))
|
|
|
|
|
|
|
|
return torch.stack(codebooks, dim=0)
|
|
return torch.stack(codebooks, dim=0)
|
|
|
|
|
|
|
@@ -211,6 +237,7 @@ def decode_n_tokens(
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
@torch.no_grad()
|
|
|
|
|
+@torch.inference_mode()
|
|
|
def generate(
|
|
def generate(
|
|
|
*,
|
|
*,
|
|
|
model: Transformer,
|
|
model: Transformer,
|
|
@@ -424,7 +451,7 @@ def split_text(text, min_length):
|
|
|
@click.option("--num-samples", type=int, default=1)
|
|
@click.option("--num-samples", type=int, default=1)
|
|
|
@click.option("--max-new-tokens", type=int, default=0)
|
|
@click.option("--max-new-tokens", type=int, default=0)
|
|
|
@click.option("--top-k", type=int, default=None)
|
|
@click.option("--top-k", type=int, default=None)
|
|
|
-@click.option("--top-p", type=float, default=0.5)
|
|
|
|
|
|
|
+@click.option("--top-p", type=float, default=0.9)
|
|
|
@click.option("--repetition-penalty", type=float, default=1.2)
|
|
@click.option("--repetition-penalty", type=float, default=1.2)
|
|
|
@click.option("--temperature", type=float, default=0.7)
|
|
@click.option("--temperature", type=float, default=0.7)
|
|
|
@click.option(
|
|
@click.option(
|