Browse Source

Implement windowed repeat penalty

Lengyue 2 năm trước cách đây
mục cha
commit
c54eecc63f

+ 2 - 2
fish_speech/configs/text2semantic_finetune.yaml

@@ -4,8 +4,8 @@ defaults:
 
 project: text2semantic_400m_finetune
 max_length: 4096
-ckpt_path: results/text2semantic_400m_pretrain/checkpoints/step_000065000.ckpt
-resume_weights_only: true
+# ckpt_path: results/text2semantic_400m_pretrain/checkpoints/step_000065000.ckpt
+# resume_weights_only: true
 
 # Lightning Trainer
 trainer:

+ 2 - 2
fish_speech/datasets/text.py

@@ -218,8 +218,8 @@ class AutoAugTextDataset(IterableDataset):
 
         final_text, final_semantic = [], []
 
-        # Shuffle unique lines
-        request = SampleDataRequest(num_samples=50)
+        # Shuffle unique lines, estimate that each sample is at least 20 tokens
+        request = SampleDataRequest(num_samples=self.max_length // 20)
         response = self.stub.SampleData(request)
         if len(response.samples) == 0:
             # Invalid group

+ 14 - 5
tools/llama/generate.py

@@ -43,11 +43,11 @@ def logits_to_probs(
 ):
     if previous_tokens is not None and repetition_penalty != 1.0:
         previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=-1, index=previous_tokens)
+        score = torch.gather(logits, dim=0, index=previous_tokens)
         score = torch.where(
             score < 0, score * repetition_penalty, score / repetition_penalty
         )
-        logits.scatter_(dim=-1, index=previous_tokens, src=score)
+        logits.scatter_(dim=0, index=previous_tokens, src=score)
 
     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -110,7 +110,9 @@ def decode_one_token(
             codebooks.append(
                 sample(
                     logits.codebook_logits[:, :, i],
-                    previous_tokens=previous_tokens[i + 1],
+                    previous_tokens=previous_tokens[i + 1]
+                    if previous_tokens is not None
+                    else None,
                     **sampling_kwargs,
                 )[0]
             )
@@ -162,6 +164,13 @@ def decode_n_tokens(
     )
 
     for i in tqdm(range(num_new_tokens)):
+        # We need to get windowed repeat penalty
+        win_size = 16
+        if i < win_size:
+            window = previous_tokens[:, :win_size]
+        else:
+            window = previous_tokens[:, i - win_size : i]
+
         with torch.backends.cuda.sdp_kernel(
             enable_flash=False, enable_mem_efficient=False, enable_math=True
         ):  # Actually better for Inductor to codegen attention here
@@ -169,7 +178,7 @@ def decode_n_tokens(
                 model,
                 cur_token,
                 input_pos,
-                previous_tokens,
+                window,
                 **sampling_kwargs,
             )
 
@@ -340,7 +349,7 @@ def load_model(config_name, checkpoint_path, device, precision):
 @click.option("--max_new_tokens", type=int, default=0)
 @click.option("--top-k", type=int, default=None)
 @click.option("--top-p", type=float, default=0.5)
-@click.option("--repetition-penalty", type=float, default=1.1)
+@click.option("--repetition-penalty", type=float, default=1.5)
 @click.option("--temperature", type=float, default=0.7)
 @click.option(
     "--checkpoint-path",