|
|
@@ -109,8 +109,6 @@ def decode_one_token(
|
|
|
|
|
|
# Disable <s> and </s> tokens for codebooks
|
|
|
if model.config.num_codebooks != 0:
|
|
|
- logits.codebook_logits[:, :, :, :1] = -float("Inf")
|
|
|
-
|
|
|
for i in range(model.config.num_codebooks):
|
|
|
codebooks.append(
|
|
|
sample(
|
|
|
@@ -122,12 +120,7 @@ def decode_one_token(
|
|
|
)[0]
|
|
|
)
|
|
|
|
|
|
- codebooks = torch.stack(codebooks, dim=0)
|
|
|
- codebooks[1] = torch.where(
|
|
|
- codebooks[0] <= 32311, codebooks[1], codebooks[0] - 32311 + 2
|
|
|
- )
|
|
|
-
|
|
|
- return codebooks
|
|
|
+ return torch.stack(codebooks, dim=0)
|
|
|
|
|
|
|
|
|
def prefill(
|
|
|
@@ -143,10 +136,7 @@ def prefill(
|
|
|
)[0]
|
|
|
]
|
|
|
|
|
|
- # Disable <s> and </s> tokens for codebooks
|
|
|
if model.config.num_codebooks != 0:
|
|
|
- logits.codebook_logits[:, :, :, :2] = -float("Inf")
|
|
|
-
|
|
|
for i in range(model.config.num_codebooks):
|
|
|
codebooks.append(
|
|
|
sample(
|
|
|
@@ -330,9 +320,9 @@ def encode_tokens(
|
|
|
data = data[:num_codebooks]
|
|
|
|
|
|
# Since 1.0, we use <s:xxx> to replace <semantic>
|
|
|
- main_tokens = [f"<s:{i}>" for i in data[0]]
|
|
|
- main_token_ids = tokenizer.convert_tokens_to_ids(main_tokens)
|
|
|
- main_token_ids = torch.tensor([main_token_ids], dtype=torch.int, device=device)
|
|
|
+ main_token_ids = torch.tensor(
|
|
|
+ [[tokenizer.pad_token_id] * data.size(1)], dtype=torch.int, device=device
|
|
|
+ )
|
|
|
|
|
|
data = torch.cat((main_token_ids, data), dim=0)
|
|
|
prompt = torch.cat((prompt, data), dim=1)
|
|
|
@@ -502,7 +492,7 @@ def main(
|
|
|
decode_one_token, mode="reduce-overhead", fullgraph=True
|
|
|
)
|
|
|
|
|
|
- for i in range(num_samples):
|
|
|
+ for idx in range(num_samples):
|
|
|
torch.cuda.synchronize()
|
|
|
|
|
|
t0 = time.perf_counter()
|
|
|
@@ -518,7 +508,7 @@ def main(
|
|
|
repetition_penalty=repetition_penalty,
|
|
|
)
|
|
|
|
|
|
- if i == 0 and compile:
|
|
|
+ if idx == 0 and compile:
|
|
|
logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
@@ -535,11 +525,18 @@ def main(
|
|
|
)
|
|
|
|
|
|
codes = y[1:, prompt_length:-1]
|
|
|
+ new_codes = []
|
|
|
+ for j, code in enumerate(codes):
|
|
|
+ new_codes.append(
|
|
|
+ code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
|
|
|
+ )
|
|
|
+
|
|
|
+ codes = torch.stack(new_codes, dim=0)
|
|
|
codes = codes - 2
|
|
|
assert (codes >= 0).all(), "Codes should be >= 0"
|
|
|
|
|
|
- np.save(f"codes_{i}.npy", codes.cpu().numpy())
|
|
|
- logger.info(f"Saved codes to codes_{i}.npy")
|
|
|
+ np.save(f"codes_{idx}.npy", codes.cpu().numpy())
|
|
|
+ logger.info(f"Saved codes to codes_{idx}.npy")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|