| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- import pyrootutils
- import torch
- import torch.nn.functional as F
- from matplotlib import pyplot as plt
- from transformers import AutoTokenizer
- # register eval resolver and root
- pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
- from torch.utils.data import DataLoader
- from fish_speech.datasets.text import AutoAugTextDataset, TextDataCollator
- from tools.llama.generate import load_model
def smooth(
    scalars: list[float], weight: float
) -> list[float]:  # Weight between 0 and 1
    """Return an exponentially smoothed copy of ``scalars``.

    Each output value is ``previous * weight + (1 - weight) * point``, where
    ``previous`` is the prior smoothed value. The running value is seeded
    with the first element, so the first output equals the first input.

    Args:
        scalars: The raw series, in plotting order.
        weight: Smoothing factor; 0 returns the input unchanged, values
            close to 1 smooth more heavily.

    Returns:
        A new list with the same length as ``scalars``; an empty list when
        ``scalars`` is empty.
    """
    # Nothing to seed the running value with; an empty series stays empty.
    if not scalars:
        return []

    last = scalars[0]  # First value in the plot (first timestep)
    smoothed = []
    for point in scalars:
        # The new smoothed value also becomes the anchor for the next step.
        last = last * weight + (1 - weight) * point
        smoothed.append(last)

    return smoothed
@torch.inference_mode()
def analyze_one_model(loader, config, weight, max_length, max_batches=10):
    """Measure the mean semantic (codebook) loss per frame position.

    Loads one model, runs it over the first ``max_batches`` batches of
    ``loader`` and averages the codebook cross-entropy at every sequence
    position over all samples that are not padded at that position.

    Args:
        loader: Iterable of dict batches holding tensors under the keys
            ``"inputs"``, ``"attention_masks"`` and ``"labels"``. Along
            dim 1 of ``labels``, rows ``1 .. num_codebooks`` are read as the
            codebook targets, with ``-100`` marking ignored positions.
            (Assumes labels are laid out as (batch, rows, time) - confirm
            against the collator.)
        config: Model config name, forwarded to ``load_model``.
        weight: Checkpoint location, forwarded to ``load_model``.
        max_length: Number of frame positions tracked; also forwarded to
            ``load_model``.
        max_batches: Maximum number of batches to evaluate. ``None``
            evaluates the whole loader.

    Returns:
        A tuple ``(xs, ys, smoothed_ys)``: the frame indices that received
        at least one non-padded sample, the mean loss at those indices, and
        the same series smoothed with a factor of 0.95. All three are empty
        lists when no frame was observed.
    """
    from itertools import islice

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(
        config,
        weight,
        device,
        torch.bfloat16,
        max_length,
        compile=False,
    )[0]
    model.eval()
    num_codebooks = model.config.num_codebooks

    # Per-position accumulators: float32 so sums of low-precision losses
    # do not lose accuracy.
    semantic_loss_sum = torch.zeros(
        max_length,
        dtype=torch.float32,
        device=device,
    )
    counter = torch.zeros(
        max_length,
        dtype=torch.long,
        device=device,
    )

    for batch in islice(loader, max_batches):
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch["labels"]

        outputs = model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        codebook_logits = outputs.codebook_logits

        # Move the codebook axis last so labels line up with the logits.
        codebook_labels = labels[:, 1 : 1 + num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
            reduction="none",
        ).reshape(codebook_labels.shape)

        # Average over codebooks to get one loss value per frame.
        semantic_loss_frame = semantic_loss.mean(-1)
        # A frame is padding when every codebook label is the ignore index.
        pad_pos = codebook_labels.sum(-1) == -100 * num_codebooks

        # A batch may be shorter than max_length, so only the leading
        # positions of the accumulators are updated.
        seq_len = min(pad_pos.size(-1), max_length)
        pad = pad_pos[:, :seq_len]
        frame_loss = semantic_loss_frame[:, :seq_len].float()

        semantic_loss_sum[:seq_len] += frame_loss.masked_fill(pad, 0.0).sum(0)
        counter[:seq_len] += (~pad).sum(0)

    # Bring the accumulators to the host once, instead of synchronising
    # with the device for every frame position.
    semantic_loss_sum = semantic_loss_sum.cpu()
    counter = counter.cpu()

    # Keep only positions that were actually observed.
    seen = counter > 0
    xs = torch.nonzero(seen).flatten().tolist()
    ys = (semantic_loss_sum[seen] / counter[seen]).tolist()

    # Smoothed curve for better loss visualization.
    smoothed_ys = smooth(ys, 0.95) if ys else []

    # Unload model
    del model
    torch.cuda.empty_cache()

    return xs, ys, smoothed_ys
def main():
    """Plot the smoothed per-frame semantic loss of several checkpoints.

    Builds one data loader over the SFT proto data, evaluates each listed
    checkpoint with ``analyze_one_model`` and draws all smoothed curves on
    a single log-scale figure, saved as ``semantic_loss.png``.
    """
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    max_length = 4096

    ds = AutoAugTextDataset(
        ["data/protos/sft/云天河"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        max_length=max_length,
    )

    loader = DataLoader(
        ds,
        batch_size=8,
        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
        num_workers=0,
        shuffle=False,
    )

    plt.figure(figsize=(10, 5), dpi=200)

    plt.xlabel("Frame")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.title("Semantic Loss")
    plt.grid(which="both", axis="both")
    plt.xlim(0, max_length)

    # (legend label, model config name, checkpoint path)
    tests = [
        (
            "pretrain-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
        ),
        (
            "sft-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
        ),
        (
            "sft-large",
            "dual_ar_2_codebook_large",
            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
        ),
    ]

    for name, config, weight in tests:
        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
        plt.plot(xs, smoothed_ys, label=name)

    plt.legend()
    plt.savefig("semantic_loss.png")
    # Release the figure once it has been written to disk.
    plt.close()


if __name__ == "__main__":
    main()
|