| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- import platform
- from dataclasses import dataclass
- from typing import Any, Dict, Optional
- import lightning as L
- import loralib as lora
- import torch
- import torch.nn.functional as F
- from lightning.pytorch.utilities.types import OptimizerLRScheduler
- import fish_speech.utils as utils
- from fish_speech.models.text2semantic.llama import Transformer
# Module-level logger built from the project's RankedLogger helper.
# NOTE(review): `rank_zero_only=True` presumably limits output to the rank-0
# process in distributed runs — confirm against fish_speech.utils.
log = utils.RankedLogger(__name__, rank_zero_only=True)
@dataclass
class LoraConfig:
    """Hyper-parameters forwarded to the ``loralib`` layers.

    ``r`` and ``lora_alpha`` are required; ``lora_dropout`` defaults to
    ``0.0``.
    """

    # Passed as ``r=`` to the LoRA embedding and linear layers.
    r: int

    # Passed as ``lora_alpha=`` to the LoRA embedding and linear layers.
    lora_alpha: float

    # Passed as ``lora_dropout=`` to the LoRA linear layers only.
    lora_dropout: float = 0.0
class TextToSemantic(L.LightningModule):
    """Lightning module that trains a ``Transformer`` with cross-entropy losses.

    The wrapped model is called with ``x`` and ``key_padding_mask`` and must
    return an object exposing ``token_logits`` and, when
    ``model.config.num_codebooks`` is non-zero, ``codebook_logits``.

    Args:
        model: The network to train.
        optimizer: Callable invoked with ``self.parameters()``; must return the
            optimizer.
        lr_scheduler: Callable invoked with the optimizer; must return the
            learning-rate scheduler.
        lora_config: When given, ``setup_lora`` runs during construction and
            swaps the embedding and linear layers for LoRA layers.
    """

    def __init__(
        self,
        model: Transformer,
        optimizer: Any,
        lr_scheduler: Any,
        lora_config: Optional[LoraConfig] = None,
    ):
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler
        self.lora_config = lora_config

        if self.lora_config is not None:
            self.setup_lora()

    def setup_lora(self):
        """Swap the embedding and linear layers of ``self.model`` for LoRA layers.

        The replacement layers are newly constructed from the dimensions of the
        layers they replace; this method does not copy the existing weights
        into them. Afterwards ``lora.mark_only_lora_as_trainable`` is applied
        to the model with ``bias="lora_only"``.
        """
        config = self.lora_config

        # Replace the embedding layer with a LoRA layer
        embeddings = self.model.embeddings
        self.model.embeddings = lora.Embedding(
            num_embeddings=embeddings.num_embeddings,
            embedding_dim=embeddings.embedding_dim,
            padding_idx=embeddings.padding_idx,
            r=config.r,
            lora_alpha=config.lora_alpha,
        )

        # Collect (owner, attribute name) pairs: the output projection first,
        # then the attention and feed-forward projections of every layer.
        linears = [(self.model, "output")]
        for layer in self.model.layers:
            linears.extend(
                [
                    (layer.attention, "wqkv"),
                    (layer.attention, "wo"),
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

        for module, name in linears:
            original = getattr(module, name)
            updated_linear = lora.Linear(
                in_features=original.in_features,
                out_features=original.out_features,
                # ``bias`` is a flag, not a tensor: passing the bias parameter
                # itself is evaluated for truthiness by ``nn.Linear`` and
                # raises for any multi-element bias.
                bias=original.bias is not None,
                r=config.r,
                lora_alpha=config.lora_alpha,
                lora_dropout=config.lora_dropout,
            )
            setattr(module, name, updated_linear)

        # Mark only the LoRA layers as trainable
        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")

    def forward(self, x):
        """Return ``self.model(x)``."""
        return self.model(x)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build the optimizer and a scheduler configured with ``interval="step"``."""
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    @staticmethod
    def _top_k_accuracy(logits, targets, k: int = 5, ignore_index: int = -100):
        """Return the fraction of non-ignored targets found in the top-``k`` logits.

        ``targets`` must have the shape of ``logits`` without its last
        dimension. Positions equal to ``ignore_index`` count neither as hits
        nor in the denominator; if every position is ignored the denominator is
        zero and the result is not a finite number.
        """
        valid = targets != ignore_index
        _, indices = logits.topk(k, dim=-1)
        hits = indices.eq(targets.unsqueeze(-1)) & valid.unsqueeze(-1)
        return hits.sum() / valid.sum()

    def _log_step_metric(self, name: str, value, prog_bar: bool) -> None:
        """Log ``value`` per step (never per epoch) to the logger."""
        self.log(
            name,
            value,
            on_step=True,
            on_epoch=False,
            prog_bar=prog_bar,
            logger=True,
        )

    def _step(self, batch, batch_idx, stage: str):
        """Run the model on one batch, log metrics under ``stage`` and return the loss.

        Args:
            batch: Mapping with the keys ``"inputs"``, ``"attention_masks"``
                and ``"labels"``. ``labels[:, 0]`` holds the token targets and
                ``labels[:, 1 : 1 + num_codebooks]`` the codebook targets;
                ``-100`` marks positions to ignore.
                NOTE(review): presumably ``labels`` is
                (batch, 1 + num_codebooks, seq) — confirm against the dataset.
            batch_idx: Unused; present to mirror the Lightning hook signature.
            stage: Prefix for the logged metric names, e.g. ``"train"``.

        Returns:
            The token cross-entropy loss, plus the codebook cross-entropy loss
            when the model has codebooks.
        """
        outputs = self.model(
            x=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )

        labels = batch["labels"]
        token_labels = labels[:, 0]
        num_codebooks = self.model.config.num_codebooks
        has_codebooks = num_codebooks != 0

        base_loss = F.cross_entropy(
            outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
            token_labels.reshape(-1),
            ignore_index=-100,
        )

        # If we have a codebook, add the loss
        if has_codebooks:
            # ``.mT`` swaps the last two dimensions of the codebook targets.
            codebook_labels = labels[:, 1 : 1 + num_codebooks].mT
            semantic_loss = F.cross_entropy(
                outputs.codebook_logits.reshape(
                    -1, outputs.codebook_logits.size(-1)
                ),
                codebook_labels.reshape(-1),
                ignore_index=-100,
            )
            loss = base_loss + semantic_loss
        else:
            loss = base_loss

        self._log_step_metric(f"{stage}/loss", loss, prog_bar=True)

        # Top-5 accuracy is measured on the codebook head when it exists and
        # on the token head otherwise.
        if has_codebooks:
            self._log_step_metric(f"{stage}/base_loss", base_loss, prog_bar=False)
            self._log_step_metric(
                f"{stage}/semantic_loss", semantic_loss, prog_bar=False
            )
            accuracy = self._top_k_accuracy(outputs.codebook_logits, codebook_labels)
        else:
            accuracy = self._top_k_accuracy(outputs.token_logits, token_labels)

        self._log_step_metric(f"{stage}/top_5_accuracy", accuracy, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        return self._step(batch, batch_idx, "val")
|