|
|
@@ -1,6 +1,5 @@
|
|
|
-import platform
|
|
|
from dataclasses import dataclass
|
|
|
-from typing import Any, Dict, Optional
|
|
|
+from typing import Any, Optional
|
|
|
|
|
|
import lightning as L
|
|
|
import loralib as lora
|
|
|
@@ -28,6 +27,9 @@ class TextToSemantic(L.LightningModule):
|
|
|
optimizer: Any,
|
|
|
lr_scheduler: Any,
|
|
|
lora_config: Optional[LoraConfig] = None,
|
|
|
+ save_lora_only: bool = False,
|
|
|
+ use_dpo: bool = False,
|
|
|
+ dpo_beta: float = 0.2,
|
|
|
):
|
|
|
super().__init__()
|
|
|
|
|
|
@@ -35,6 +37,9 @@ class TextToSemantic(L.LightningModule):
|
|
|
self.optimizer_builder = optimizer
|
|
|
self.lr_scheduler_builder = lr_scheduler
|
|
|
self.lora_config = lora_config
|
|
|
+ self.save_lora_only = save_lora_only
|
|
|
+ self.use_dpo = use_dpo # We don't support reference model yet
|
|
|
+ self.dpo_beta = dpo_beta
|
|
|
|
|
|
if self.lora_config is not None:
|
|
|
self.setup_lora()
|
|
|
@@ -81,10 +86,10 @@ class TextToSemantic(L.LightningModule):
|
|
|
return self.model(x)
|
|
|
|
|
|
def on_save_checkpoint(self, checkpoint):
|
|
|
- if self.lora_config is None:
|
|
|
+ if self.lora_config is None or self.save_lora_only is False:
|
|
|
return
|
|
|
|
|
|
- # Save the LoRA parameters
|
|
|
+ # Save only LoRA parameters
|
|
|
state_dict = checkpoint["state_dict"]
|
|
|
for name in list(state_dict.keys()):
|
|
|
if "lora" not in name:
|
|
|
@@ -102,16 +107,59 @@ class TextToSemantic(L.LightningModule):
|
|
|
},
|
|
|
}
|
|
|
|
|
|
+ # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
|
|
|
+ def get_batch_logps(
|
|
|
+ self,
|
|
|
+ logits: torch.FloatTensor,
|
|
|
+ labels: torch.LongTensor,
|
|
|
+ average_log_prob: bool = False,
|
|
|
+ ) -> torch.FloatTensor:
|
|
|
+ """Compute the log probabilities of the given labels under the given logits.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
|
|
|
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
|
|
|
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
|
|
+ """
|
|
|
+ assert logits.shape[:-1] == labels.shape
|
|
|
+
|
|
|
+ labels = labels.clone()
|
|
|
+ loss_mask = labels != -100
|
|
|
+
|
|
|
+ # dummy token; we'll ignore the losses on these tokens later
|
|
|
+ labels[labels == -100] = 0
|
|
|
+
|
|
|
+ per_token_logps = torch.gather(
|
|
|
+ logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
|
|
|
+ ).squeeze(-1)
|
|
|
+
|
|
|
+ if average_log_prob:
|
|
|
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
|
|
+ else:
|
|
|
+ return (per_token_logps * loss_mask).sum(-1)
|
|
|
+
|
|
|
def _step(self, batch, batch_idx, stage: str):
|
|
|
+ # Do positive and negative samples in the same batch to speed up training
|
|
|
outputs = self.model(
|
|
|
x=batch["inputs"],
|
|
|
key_padding_mask=batch["attention_masks"],
|
|
|
)
|
|
|
+ labels = batch["labels"]
|
|
|
+ token_logits = outputs.token_logits
|
|
|
+ codebook_logits = outputs.codebook_logits
|
|
|
+
|
|
|
+ if self.use_dpo:
|
|
|
+ # First half is positive, second half is negative
|
|
|
+ token_logits, negative_token_logits = token_logits.chunk(2)
|
|
|
+ codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
|
|
|
+ labels, negative_labels = labels.chunk(2)
|
|
|
|
|
|
# Generate labels
|
|
|
- labels = batch["labels"]
|
|
|
base_loss = F.cross_entropy(
|
|
|
- outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
|
|
|
+ token_logits.reshape(-1, token_logits.size(-1)),
|
|
|
labels[:, 0].reshape(-1),
|
|
|
ignore_index=-100,
|
|
|
)
|
|
|
@@ -120,7 +168,7 @@ class TextToSemantic(L.LightningModule):
|
|
|
if self.model.config.num_codebooks != 0:
|
|
|
codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
|
|
|
semantic_loss = F.cross_entropy(
|
|
|
- outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
|
|
|
+ codebook_logits.reshape(-1, codebook_logits.size(-1)),
|
|
|
codebook_labels.reshape(-1),
|
|
|
ignore_index=-100,
|
|
|
)
|
|
|
@@ -129,6 +177,72 @@ class TextToSemantic(L.LightningModule):
|
|
|
else:
|
|
|
loss = base_loss
|
|
|
|
|
|
+ # If DPO is enabled, add the preference loss computed on the codebook logits
|
|
|
+ if self.use_dpo:
|
|
|
+ negative_codebook_labels = negative_labels[
|
|
|
+ :, 1 : 1 + self.model.config.num_codebooks
|
|
|
+ ].mT
|
|
|
+
|
|
|
+ positive_codebook_logps = self.get_batch_logps(
|
|
|
+ codebook_logits, codebook_labels
|
|
|
+ )
|
|
|
+ negative_codebook_logps = self.get_batch_logps(
|
|
|
+ negative_codebook_logits, negative_codebook_labels
|
|
|
+ )
|
|
|
+
|
|
|
+ # TODO: implement the reference model so the implicit reward uses log-ratio terms instead of raw policy log-probs
|
|
|
+ dpo_loss = -F.logsigmoid(
|
|
|
+ (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
|
|
|
+ ).mean()
|
|
|
+
|
|
|
+ chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
|
|
|
+ rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
|
|
|
+ reward_accuracy = (
|
|
|
+ (positive_codebook_logps > negative_codebook_logps).float().mean()
|
|
|
+ )
|
|
|
+ chosen_rewards, rejected_rewards = (
|
|
|
+ chosen_rewards.mean(),
|
|
|
+ rejected_rewards.mean(),
|
|
|
+ )
|
|
|
+
|
|
|
+ loss = loss + dpo_loss
|
|
|
+
|
|
|
+ self.log(
|
|
|
+ f"{stage}/dpo_loss",
|
|
|
+ dpo_loss,
|
|
|
+ on_step=True,
|
|
|
+ on_epoch=False,
|
|
|
+ prog_bar=False,
|
|
|
+ logger=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.log(
|
|
|
+ f"{stage}/chosen_rewards",
|
|
|
+ chosen_rewards,
|
|
|
+ on_step=True,
|
|
|
+ on_epoch=False,
|
|
|
+ prog_bar=False,
|
|
|
+ logger=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.log(
|
|
|
+ f"{stage}/rejected_rewards",
|
|
|
+ rejected_rewards,
|
|
|
+ on_step=True,
|
|
|
+ on_epoch=False,
|
|
|
+ prog_bar=False,
|
|
|
+ logger=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.log(
|
|
|
+ f"{stage}/reward_accuracy",
|
|
|
+ reward_accuracy,
|
|
|
+ on_step=True,
|
|
|
+ on_epoch=False,
|
|
|
+ prog_bar=False,
|
|
|
+ logger=True,
|
|
|
+ )
|
|
|
+
|
|
|
self.log(
|
|
|
f"{stage}/loss",
|
|
|
loss,
|
|
|
@@ -159,15 +273,13 @@ class TextToSemantic(L.LightningModule):
|
|
|
|
|
|
# Top-5 accuracy
|
|
|
if self.model.config.num_codebooks == 0:
|
|
|
- _, indices = outputs.token_logits.topk(5, dim=-1)
|
|
|
+ _, indices = token_logits.topk(5, dim=-1)
|
|
|
correct = indices.eq(labels[:, 0].unsqueeze(-1))
|
|
|
correct[labels[:, 0] == -100] = 0
|
|
|
correct = correct.sum()
|
|
|
accuracy = correct / (labels[:, 0] != -100).sum()
|
|
|
else:
|
|
|
- _, indices = outputs.codebook_logits.topk(5, dim=-1)
|
|
|
- # print(codebook_labels[0, :10], torch.argmax(outputs.codebook_logits[0, :10], dim=-1))
|
|
|
- # print(codebook_labels[codebook_labels != -100][:10], indices[codebook_labels != -100][:10])
|
|
|
+ _, indices = codebook_logits.topk(5, dim=-1)
|
|
|
correct = indices.eq(codebook_labels.unsqueeze(-1))
|
|
|
correct[codebook_labels == -100] = 0
|
|
|
correct = correct.sum()
|