Преглед изворни кода

Add in-context evaluation & smart pad util

Lengyue пре 1 година
родитељ
комит
18d71f72cb
2 измењених фајлова са 218 додато и 0 уклоњено
  1. 171 0
      tools/llama/eval_in_context.py
  2. 47 0
      tools/smart_pad.py

+ 171 - 0
tools/llama/eval_in_context.py

@@ -0,0 +1,171 @@
+import pyrootutils
+import torch
+import torch.nn.functional as F
+from matplotlib import pyplot as plt
+from transformers import AutoTokenizer
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from torch.utils.data import DataLoader
+
+from fish_speech.datasets.text import AutoAugTextDataset, TextDataCollator
+from tools.llama.generate import load_model
+
+
+def smooth(
+    scalars: list[float], weight: float
+) -> list[float]:  # Weight between 0 and 1
+    last = scalars[0]  # First value in the plot (first timestep)
+    smoothed = list()
+    for point in scalars:
+        smoothed_val = last * weight + (1 - weight) * point  # Calculate smoothed value
+        smoothed.append(smoothed_val)  # Save it
+        last = smoothed_val  # Anchor the last smoothed value
+
+    return smoothed
+
+
+@torch.inference_mode()
+def analyze_one_model(loader, config, weight, max_length):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = load_model(
+        config,
+        weight,
+        device,
+        torch.bfloat16,
+        max_length,
+        compile=False,
+    )[0]
+
+    current_step = 0
+    model.eval()
+
+    semantic_loss_sum = torch.zeros(
+        max_length,
+        dtype=torch.float32,
+        device=device,
+    )
+    counter = torch.zeros(
+        max_length,
+        dtype=torch.long,
+        device=device,
+    )
+
+    for batch in loader:
+        batch = {k: v.to(device) for k, v in batch.items()}
+
+        labels = batch["labels"]
+        outputs = model(
+            inp=batch["inputs"],
+            key_padding_mask=batch["attention_masks"],
+        )
+
+        token_logits = outputs.token_logits
+        codebook_logits = outputs.codebook_logits
+
+        # Generate labels
+        base_loss = F.cross_entropy(
+            token_logits.reshape(-1, token_logits.size(-1)),
+            labels[:, 0].reshape(-1),
+            ignore_index=-100,
+            reduction="none",
+        )
+
+        codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
+        semantic_loss = F.cross_entropy(
+            codebook_logits.reshape(-1, codebook_logits.size(-1)),
+            codebook_labels.reshape(-1),
+            ignore_index=-100,
+            reduction="none",
+        )
+
+        base_loss = base_loss.reshape(labels[:, 0].shape)
+        semantic_loss = semantic_loss.reshape(codebook_labels.shape)
+
+        semantic_loss_frame = semantic_loss.mean(-1)
+        pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
+
+        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
+            semantic_loss_sum[~pad] += loss_sample[~pad]
+            counter[~pad] += 1
+
+        current_step += 1
+        if current_step == 10:
+            break
+
+    semantic_loss = semantic_loss.cpu()
+    counter = counter.cpu()
+    xs, ys = [], []
+
+    for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
+        if count > 0:
+            xs.append(i)
+            ys.append((loss / count).item())  # for better loss visualization
+
+    smoothed_ys = smooth(ys, 0.95)
+
+    # Unload model
+    del model
+    torch.cuda.empty_cache()
+
+    return xs, ys, smoothed_ys
+
+
+def main():
+    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
+    max_length = 4096
+
+    ds = AutoAugTextDataset(
+        ["data/protos/sft/云天河"],
+        tokenizer=tokenizer,
+        use_speaker=False,
+        interactive_prob=1.0,
+        max_length=max_length,
+    )
+
+    loader = DataLoader(
+        ds,
+        batch_size=8,
+        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
+        num_workers=0,
+        shuffle=False,
+    )
+
+    plt.figure(figsize=(10, 5), dpi=200)
+
+    plt.xlabel("Frame")
+    plt.ylabel("Loss")
+    plt.yscale("log")
+    plt.title("Semantic Loss")
+    plt.grid(which="both", axis="both")
+    plt.xlim(0, max_length)
+
+    tests = [
+        (
+            "pertrain-medium",
+            "dual_ar_2_codebook_medium",
+            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
+        ),
+        (
+            "sft-medium",
+            "dual_ar_2_codebook_medium",
+            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
+        ),
+        (
+            "sft-large",
+            "dual_ar_2_codebook_large",
+            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
+        ),
+    ]
+
+    for name, config, weight in tests:
+        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
+        plt.plot(xs, smoothed_ys, label=name)
+
+    plt.legend()
+    plt.savefig("semantic_loss.png")
+
+
+if __name__ == "__main__":
+    main()

+ 47 - 0
tools/smart_pad.py

@@ -0,0 +1,47 @@
+import random
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import librosa
+import torch.nn.functional as F
+import torchaudio
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+threshold = 10 ** (-50 / 20.0)
+
+
+def process(file):
+    waveform, sample_rate = torchaudio.load(str(file), backend="sox")
+    loudness = librosa.feature.rms(
+        y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
+    )[0]
+    for i in range(len(loudness) - 1, 0, -1):
+        if loudness[i] > threshold:
+            break
+
+    silent_time = (len(loudness) - i) * 512 / sample_rate
+
+    if silent_time <= 0.3:
+        random_time = random.uniform(0.3, 0.7)
+        waveform = F.pad(
+            waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
+        )
+
+    torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
+
+
+@click.command()
+@click.argument("source", type=Path)
+@click.option("--num-workers", type=int, default=12)
+def main(source, num_workers):
+    files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
+
+    with Pool(num_workers) as p:
+        list(tqdm(p.imap_unordered(process, files), total=len(files)))
+
+
+if __name__ == "__main__":
+    main()