@@ -0,0 +1,171 @@
+import pyrootutils
+import torch
+import torch.nn.functional as F
+from matplotlib import pyplot as plt
+from transformers import AutoTokenizer
+
+# Register the project root and add it to PYTHONPATH
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from torch.utils.data import DataLoader
+
+from fish_speech.datasets.text import AutoAugTextDataset, TextDataCollator
+from tools.llama.generate import load_model
+
+
+def smooth(
+    scalars: list[float], weight: float
+) -> list[float]: # Weight between 0 and 1
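+    # Exponential moving average: each point is blended with the running value,
+    # e.g. smooth([1.0, 0.0, 0.0], 0.9) -> [1.0, 0.9, 0.81]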
+    last = scalars[0]  # First value in the plot (first timestep)
+    smoothed = list()
+    for point in scalars:
+        smoothed_val = last * weight + (1 - weight) * point  # Calculate smoothed value
+        smoothed.append(smoothed_val)  # Save it
+        last = smoothed_val  # Anchor the last smoothed value
+
+    return smoothed
+
|
|
|
+
|
|
|
+@torch.inference_mode()
|
|
|
+def analyze_one_model(loader, config, weight, max_length):
|
|
|
+ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
+ model = load_model(
|
|
|
+ config,
|
|
|
+ weight,
|
|
|
+ device,
|
|
|
+ torch.bfloat16,
|
|
|
+ max_length,
|
|
|
+ compile=False,
|
|
|
+ )[0]
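+    # load_model returns a tuple; we only need the model (first element) here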
+
+    current_step = 0
+    model.eval()
+
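+    # Per-frame accumulators: running loss total and sample count at each position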
+    semantic_loss_sum = torch.zeros(
+        max_length,
+        dtype=torch.float32,
+        device=device,
+    )
+    counter = torch.zeros(
+        max_length,
+        dtype=torch.long,
+        device=device,
+    )
+
+    for batch in loader:
+        batch = {k: v.to(device) for k, v in batch.items()}
+
+ labels = batch["labels"]
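+        # Label layout (per the slices below): row 0 holds the base token
+        # targets, rows 1..num_codebooks hold the codebook targets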
+        outputs = model(
+            inp=batch["inputs"],
+            key_padding_mask=batch["attention_masks"],
+        )
+
+        token_logits = outputs.token_logits
+        codebook_logits = outputs.codebook_logits
+
+        # Base (text) token loss, computed for reference but not plotted
+        base_loss = F.cross_entropy(
+            token_logits.reshape(-1, token_logits.size(-1)),
+            labels[:, 0].reshape(-1),
+            ignore_index=-100,
+            reduction="none",
+        )
+
+ codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
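+        # .mT swaps the last two axes, giving [batch, seq, num_codebooks] so the
+        # labels line up with codebook_logits once both are flattened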
+        semantic_loss = F.cross_entropy(
+            codebook_logits.reshape(-1, codebook_logits.size(-1)),
+            codebook_labels.reshape(-1),
+            ignore_index=-100,
+            reduction="none",
+        )
+
+        base_loss = base_loss.reshape(labels[:, 0].shape)
+        semantic_loss = semantic_loss.reshape(codebook_labels.shape)
+
+ semantic_loss_frame = semantic_loss.mean(-1)
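+        # A frame is padding when every codebook label is the ignore_index
+        # (-100), i.e. when the labels sum to -100 * num_codebooks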
+ pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
+
+        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
+            semantic_loss_sum[~pad] += loss_sample[~pad]
+            counter[~pad] += 1
+
+        current_step += 1
+        if current_step == 10:
+            break
+
+    semantic_loss_sum = semantic_loss_sum.cpu()
+    counter = counter.cpu()
+ xs, ys = [], []
+
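+    # Average the accumulated loss per frame position, skipping unseen positions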
+    for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
+        if count > 0:
+            xs.append(i)
+            ys.append((loss / count).item())  # Mean loss at this frame position
+
+    smoothed_ys = smooth(ys, 0.95)
+
+    # Unload model
+    del model
+    torch.cuda.empty_cache()
+
+    return xs, ys, smoothed_ys
+
+
+def main():
+    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
+    max_length = 4096
+
+    ds = AutoAugTextDataset(
+        ["data/protos/sft/云天河"],
+        tokenizer=tokenizer,
+        use_speaker=False,
+        interactive_prob=1.0,
+        max_length=max_length,
+    )
+
+    loader = DataLoader(
+        ds,
+        batch_size=8,
+        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
+        num_workers=0,
+        shuffle=False,
+    )
+
+    plt.figure(figsize=(10, 5), dpi=200)
+
+    plt.xlabel("Frame")
+    plt.ylabel("Loss")
+    plt.yscale("log")
+    plt.title("Semantic Loss")
+    plt.grid(which="both", axis="both")
+    plt.xlim(0, max_length)
+
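+    # Each entry: (plot label, model config name, checkpoint path)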
+    tests = [
+        (
+            "pretrain-medium",
+            "dual_ar_2_codebook_medium",
+            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
+        ),
+        (
+            "sft-medium",
+            "dual_ar_2_codebook_medium",
+            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
+        ),
+        (
+            "sft-large",
+            "dual_ar_2_codebook_large",
+            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
+        ),
+    ]
+
+    for name, config, weight in tests:
+        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
+        plt.plot(xs, smoothed_ys, label=name)
+
+    plt.legend()
+    plt.savefig("semantic_loss.png")
+
+
+if __name__ == "__main__":
+    main()