@@ -1,23 +1,87 @@
import platform
-from typing import Any, Optional
+from dataclasses import dataclass
+from typing import Any, Dict, Optional

import lightning as L
+import loralib as lora
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler

import fish_speech.utils as utils
+from fish_speech.models.text2semantic.llama import Transformer

log = utils.RankedLogger(__name__, rank_zero_only=True)


+@dataclass
+class LoraConfig:
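+    # LoRA rank (r), scaling factor (lora_alpha) and dropout applied to the adapter path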
+    r: int
+    lora_alpha: float
+    lora_dropout: float = 0.0
+
+
class TextToSemantic(L.LightningModule):
-    def __init__(self, model, optimizer: Any, lr_scheduler: Any):
+    def __init__(
+        self,
+        model: Transformer,
+        optimizer: Any,
+        lr_scheduler: Any,
+        lora_config: Optional[LoraConfig] = None,
+    ):
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler
+        self.lora_config = lora_config
+
+        if self.lora_config is not None:
+            self.setup_lora()
+
+    def setup_lora(self):
+        # Replace the embedding layer with a LoRA layer
+        self.model.embeddings = lora.Embedding(
+            num_embeddings=self.model.embeddings.num_embeddings,
+            embedding_dim=self.model.embeddings.embedding_dim,
+            padding_idx=self.model.embeddings.padding_idx,
+            r=self.lora_config.r,
+            lora_alpha=self.lora_config.lora_alpha,
+        )
+
+        # Collect the output projection for replacement with a LoRA layer
+        linears = [(self.model, "output")]
+
+        # Collect the attention and feed-forward projections of every transformer layer
+        for layer in self.model.layers:
+            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+            linears.extend(
+                [
+                    (layer.feed_forward, "w1"),
+                    (layer.feed_forward, "w2"),
+                    (layer.feed_forward, "w3"),
+                ]
+            )
+
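+        # Swap each collected nn.Linear for a loralib Linear of the same shape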
+        for module, layer in linears:
+            updated_linear = lora.Linear(
+                in_features=getattr(module, layer).in_features,
+                out_features=getattr(module, layer).out_features,
+                bias=getattr(module, layer).bias is not None,
+                r=self.lora_config.r,
+                lora_alpha=self.lora_config.lora_alpha,
+                lora_dropout=self.lora_config.lora_dropout,
+            )
+            setattr(module, layer, updated_linear)
+
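+        # NOTE: the new lora.Embedding / lora.Linear modules are freshly initialized,
+        # so any weights already loaded into the replaced layers are discarded here.
+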
+        # Mark only the LoRA layers as trainable; bias="lora_only" also keeps the
+        # bias terms of the LoRA-adapted modules trainable
+        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")

    def forward(self, x):
        return self.model(x)