@@ -1,202 +0,0 @@
-from typing import Any, Optional
-
-import lightning as L
-import torch
-import torch.nn.functional as F
-from lightning.pytorch.utilities.types import OptimizerLRScheduler
-
-import fish_speech.utils as utils
-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
-from fish_speech.models.text2semantic.llama import NaiveTransformer
-
-log = utils.RankedLogger(__name__, rank_zero_only=True)
-
-
-class TextToSemantic(L.LightningModule):
-    def __init__(
-        self,
-        model: NaiveTransformer,
-        optimizer: Any,
-        lr_scheduler: Any,
-    ):
-        super().__init__()
-
-        self.model = model
-        self.optimizer_builder = optimizer
-        self.lr_scheduler_builder = lr_scheduler
-
-    def forward(self, x):
-        return self.model(x)
-
-    def on_save_checkpoint(self, checkpoint):
-        # Save only LoRA parameters
-        state_dict = checkpoint["state_dict"]
-        use_lora = any("lora" in name for name in state_dict.keys())
-        if not use_lora:
-            return
-
-        for name in list(state_dict.keys()):
-            if "lora" not in name:
-                state_dict.pop(name)
-
-    def configure_optimizers(self) -> OptimizerLRScheduler:
-        # Get weight decay parameters
-        weight_decay_parameters, other_parameters = [], []
-        for name, param in self.named_parameters():
-            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
-                other_parameters.append(param)
-            else:
-                weight_decay_parameters.append(param)
-
-        optimizer = self.optimizer_builder(
-            [
-                {"params": weight_decay_parameters},
-                {"params": other_parameters, "weight_decay": 0.0},
-            ]
-        )
-
-        # Print the parameters and their weight decay
-        for i in optimizer.param_groups:
-            log.info(
-                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
-            )
-
-        lr_scheduler = self.lr_scheduler_builder(optimizer)
-
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": lr_scheduler,
-                "interval": "step",
-            },
-        }
-
-    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
-    def get_batch_logps(
-        self,
-        logits: torch.FloatTensor,
-        labels: torch.LongTensor,
-        average_log_prob: bool = False,
-    ) -> torch.FloatTensor:
-        """Compute the log probabilities of the given labels under the given logits.
-
-        Args:
-            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
-            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
-            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
-
-        Returns:
-            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
-        """
-        assert logits.shape[:-1] == labels.shape
-
-        labels = labels.clone()
-        loss_mask = labels != -100
-
-        # dummy token; we'll ignore the losses on these tokens later
-        labels[labels == -100] = 0
-
-        per_token_logps = torch.gather(
-            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
-        ).squeeze(-1)
-
-        if average_log_prob:
-            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-        else:
-            return (per_token_logps * loss_mask).sum(-1)
-
-    def _step(self, batch, batch_idx, stage: str):
-        is_train = stage == "train"
-
-        if is_train:
-            # Key part to make lora work
-            # Otherwise the parameters are merged, which lead to incorrect gradients
-            self.model.train()
-
-        # Do positive and negative samples in the same batch to speed up training
-        labels = batch["labels"]
-        outputs = self.model(
-            inp=batch["inputs"],
-            key_padding_mask=batch["attention_masks"],
-        )
-        token_logits = outputs.token_logits
-        codebook_logits = outputs.codebook_logits
-
-        # Generate labels
-        base_loss = F.cross_entropy(
-            token_logits.view(-1, token_logits.size(-1)),
-            labels[:, 0].reshape(-1),
-            ignore_index=-100,
-        )
-
-        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
-        semantic_loss = F.cross_entropy(
-            codebook_logits.view(-1, codebook_logits.size(-1)),
-            codebook_labels.reshape(-1),
-            ignore_index=-100,
-        )
-
-        loss = base_loss + semantic_loss
-
-        self.log(
-            f"{stage}/loss",
-            loss,
-            on_step=is_train,
-            on_epoch=not is_train,
-            prog_bar=True,
-            logger=True,
-            sync_dist=not is_train,
-        )
-
-        self.log(
-            f"{stage}/base_loss",
-            base_loss,
-            on_step=is_train,
-            on_epoch=not is_train,
-            prog_bar=False,
-            logger=True,
-            sync_dist=not is_train,
-        )
-
-        self.log(
-            f"{stage}/semantic_loss",
-            semantic_loss,
-            on_step=is_train,
-            on_epoch=not is_train,
-            prog_bar=False,
-            logger=True,
-            sync_dist=not is_train,
-        )
-
-        # Top-5 accuracy
-        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
-        self.log(
-            f"{stage}/top_5_accuracy",
-            accuracy,
-            on_step=is_train,
-            on_epoch=not is_train,
-            prog_bar=True,
-            logger=True,
-            sync_dist=not is_train,
-        )
-
-        return loss
-
-    def get_accuracy(self, logits, labels):
-        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
-        if mask.sum() == 0:
-            return torch.tensor(0.0, device=logits.device)
-
-        _, indices = logits.topk(5, dim=-1)
-        correct = indices.eq(labels.unsqueeze(-1))
-        correct[~mask] = 0
-        correct = correct.sum()
-        accuracy = correct / mask.sum()
-
-        return accuracy
-
-    def training_step(self, batch, batch_idx):
-        return self._step(batch, batch_idx, "train")
-
-    def validation_step(self, batch, batch_idx):
-        return self._step(batch, batch_idx, "val")