Просмотр исходного кода

Support prompt auto truncating codebook

Lengyue 2 года назад
Родитель
Коммит
65e646ae31
1 изменённый файл: 15 добавлено и 6 удалено
  1. 15 6
      tools/llama/generate.py

+ 15 - 6
tools/llama/generate.py

@@ -269,6 +269,7 @@ def encode_tokens(
     use_g2p=False,
     speaker=None,
     order="zh,jp,en",
+    num_codebooks=4,
 ):
     if prompt_text is not None:
         string = prompt_text + " " + string
@@ -298,7 +299,7 @@ def encode_tokens(
     tokens = torch.tensor([tokens], dtype=torch.int, device=device)
 
     # Codebooks
-    zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
+    zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
     prompt = torch.cat((tokens, zeros), dim=0)
 
     if prompt_tokens is None:
@@ -308,11 +309,18 @@ def encode_tokens(
     assert prompt_tokens.ndim == 2
     data = prompt_tokens + 2
 
-    zeros = (
-        torch.zeros((1, data.size(1)), dtype=torch.int, device=device)
-        + tokenizer.pad_token_id
-    )  # 32311 is the <pad> token
-    data = torch.cat((zeros, data), dim=0)
+    if prompt_tokens.shape[0] > num_codebooks:
+        logger.warning(
+            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+        )
+        data = data[:num_codebooks]
+
+    # Since 1.0, we use <s:xxx> to replace <semantic>
+    main_tokens = [f"<s:{i}>" for i in data[0]]
+    main_token_ids = tokenizer.convert_tokens_to_ids(main_tokens)
+    main_token_ids = torch.tensor([main_token_ids], dtype=torch.int, device=device)
+
+    data = torch.cat((main_token_ids, data), dim=0)
     prompt = torch.cat((prompt, data), dim=1)
 
     return prompt
@@ -434,6 +442,7 @@ def main(
         use_g2p=use_g2p,
         speaker=speaker,
         order=order,
+        num_codebooks=model.config.num_codebooks,
     )
     prompt_length = encoded.size(1)
     logger.info(f"Encoded prompt shape: {encoded.shape}")