# eval_in_context.py
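"""Plot per-frame semantic loss for several text2semantic checkpoints.

The script runs the same dataloader batches through each checkpoint listed in
`tests`, averages the codebook cross-entropy at every frame position, and
saves the smoothed curves to semantic_loss.png. The relative data and
checkpoint paths assume the script is run from the project root.
"""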
import pyrootutils
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from transformers import AutoTokenizer

# Register the eval resolver and project root; this must run before the
# project imports below so they resolve against the repo root.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from torch.utils.data import DataLoader

from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
from tools.llama.generate import load_model

def smooth(
    scalars: list[float], weight: float
) -> list[float]:  # Weight between 0 and 1
    last = scalars[0]  # First value in the plot (first timestep)
    smoothed = list()
    for point in scalars:
        smoothed_val = last * weight + (1 - weight) * point  # Calculate smoothed value
        smoothed.append(smoothed_val)  # Save it
        last = smoothed_val  # Anchor the last smoothed value

    return smoothed

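
# `smooth` is a standard exponential moving average (the same idea as
# TensorBoard's scalar-smoothing slider):
#   smoothed[t] = weight * smoothed[t-1] + (1 - weight) * raw[t]
# e.g. smooth([1.0, 0.0, 0.0], 0.5) -> [1.0, 0.5, 0.25].
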

@torch.inference_mode()
def analyze_one_model(loader, config, weight, max_length):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(
        config,
        weight,
        device,
        torch.bfloat16,
        max_length,
        compile=False,
    )[0]
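    # (load_model returns a tuple; the [0] above selects the model itself.)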

    current_step = 0
    model.eval()

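    # Per-frame accumulators: semantic_loss_sum[i] collects the summed codebook
    # loss at frame position i across all non-padded samples, and counter[i]
    # counts how many samples contributed to that position.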
    semantic_loss_sum = torch.zeros(
        max_length,
        dtype=torch.float32,
        device=device,
    )
    counter = torch.zeros(
        max_length,
        dtype=torch.long,
        device=device,
    )

    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        labels = batch["labels"]
        outputs = model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )

        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits
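        # The dual-AR model exposes two heads: token_logits for the base token
        # stream and codebook_logits for the per-codebook predictions whose
        # loss is analyzed below.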

        # Token-level (base) loss over the first label row; computed for
        # reference but not plotted below.
        base_loss = F.cross_entropy(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
            reduction="none",
        )
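
        # Label rows 1..num_codebooks hold the codebook targets; .mT swaps the
        # last two dims to (batch, time, codebooks) so they line up with the
        # codebook logits.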
        codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        base_loss = base_loss.reshape(labels[:, 0].shape)
        semantic_loss = semantic_loss.reshape(codebook_labels.shape)

        semantic_loss_frame = semantic_loss.mean(-1)
        pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
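        # Since real token ids are non-negative, a row sums to
        # -100 * num_codebooks only when every codebook label at that frame is
        # the ignore index, i.e. the frame is pure padding.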

        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
            semantic_loss_sum[~pad] += loss_sample[~pad]
            counter[~pad] += 1
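
        # Only the first 10 batches are evaluated, to keep the sweep cheap.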
        current_step += 1
        if current_step == 10:
            break

    # Move the per-frame accumulators to the CPU before the Python loop below.
    semantic_loss_sum = semantic_loss_sum.cpu()
    counter = counter.cpu()
    xs, ys = [], []

    for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
        if count > 0:
            xs.append(i)
            ys.append((loss / count).item())  # for better loss visualization

    smoothed_ys = smooth(ys, 0.95)

    # Unload model
    del model
    torch.cuda.empty_cache()

    return xs, ys, smoothed_ys


def main():
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    max_length = 4096

    ds = AutoAugTextDataset(
        ["data/protos/sft/云天河"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        max_length=max_length,
    )

    loader = DataLoader(
        ds,
        batch_size=8,
        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
        num_workers=0,
        shuffle=False,
    )
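    # shuffle=False with num_workers=0 keeps the batch order deterministic, so
    # every checkpoint in `tests` is scored on the same batches.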

    plt.figure(figsize=(10, 5), dpi=200)
    plt.xlabel("Frame")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.title("Semantic Loss")
    plt.grid(which="both", axis="both")
    plt.xlim(0, max_length)

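    # Each entry is (legend label, model config name, checkpoint path).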
    tests = [
        (
            "pretrain-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
        ),
        (
            "sft-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
        ),
        (
            "sft-large",
            "dual_ar_2_codebook_large",
            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
        ),
    ]

    for name, config, weight in tests:
        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
        plt.plot(xs, smoothed_ys, label=name)

    plt.legend()
    plt.savefig("semantic_loss.png")


if __name__ == "__main__":
    main()