@@ -68,6 +68,20 @@ class TextToSemantic(L.LightningModule):
             ]
         )

+        if hasattr(self.model, "fast_layers"):
+            # Dual-AR model
+            linears.extend([(self.model, "fast_output")])
+
+            for layer in self.model.fast_layers:
+                linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+                linears.extend(
+                    [
+                        (layer.feed_forward, "w1"),
+                        (layer.feed_forward, "w2"),
+                        (layer.feed_forward, "w3"),
+                    ]
+                )
+
         for module, layer in linears:
             updated_linear = lora.Linear(
                 in_features=getattr(module, layer).in_features,
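For context, the `(module, attribute)` pairs gathered above are consumed by the loop at the end of the hunk. Below is a minimal sketch of that replacement pattern, assuming a loralib-style `lora.Linear`; the function name and the `r`/`lora_alpha` hyperparameters are illustrative assumptions, not taken from this diff:

```python
import loralib as lora

def replace_with_lora(linears, r: int = 8, lora_alpha: int = 16) -> None:
    """Swap each plain nn.Linear for a LoRA-wrapped equivalent in place."""
    for module, attr in linears:
        old = getattr(module, attr)
        updated = lora.Linear(
            in_features=old.in_features,
            out_features=old.out_features,
            bias=old.bias is not None,
            r=r,                    # assumed rank; not shown in this hunk
            lora_alpha=lora_alpha,  # assumed scaling; not shown in this hunk
        )
        # Keep the pretrained weights so only the low-rank adapters start fresh.
        updated.weight = old.weight
        if old.bias is not None:
            updated.bias = old.bias
        setattr(module, attr, updated)
```

Because the Dual-AR fast layers apparently expose the same `attention`/`feed_forward` attribute names as the slow layers, extending `linears` is all the hunk needs to cover them.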
@@ -162,6 +176,8 @@ class TextToSemantic(L.LightningModule):
         return (per_token_logps * loss_mask).sum(-1)

     def _step(self, batch, batch_idx, stage: str):
+        is_train = stage == "train"
+
         # Do positive and negative samples in the same batch to speed up training
         labels = batch["labels"]
         outputs = self.model(
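The new `is_train` flag drives every logging change below: in Lightning, `on_step=True` records the raw per-batch value, while `on_epoch=True` accumulates values and logs the aggregate (mean-reduced by default) at epoch end. Spelled out per stage, the calls below behave as in this illustrative snippet, where `loss` stands in for any of the metrics:

```python
# Illustrative only; this is what the is_train flag expands to per stage.
if stage == "train":
    self.log("train/loss", loss, on_step=True, on_epoch=False)    # streamed per batch
else:
    self.log(f"{stage}/loss", loss, on_step=False, on_epoch=True)  # averaged per epoch
```

Previously, `on_step=True, on_epoch=False` was applied unconditionally, so validation and test metrics were emitted per batch instead of being aggregated over the evaluation epoch.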
@@ -224,8 +240,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/dpo_loss",
             dpo_loss,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=False,
             logger=True,
         )
@@ -233,8 +249,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/chosen_rewards",
             chosen_rewards,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=False,
             logger=True,
         )
@@ -242,8 +258,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/rejected_rewards",
             rejected_rewards,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=False,
             logger=True,
         )
@@ -251,8 +267,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/reward_accuracy",
             reward_accuracy,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=False,
             logger=True,
         )
@@ -260,8 +276,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/loss",
             loss,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=True,
             logger=True,
         )
@@ -269,8 +285,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/base_loss",
             base_loss,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=False,
             logger=True,
         )
@@ -278,8 +294,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/semantic_loss",
             semantic_loss,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=False,
             logger=True,
         )
@@ -289,8 +305,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/top_5_accuracy",
             accuracy,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=True,
             logger=True,
         )
@@ -304,8 +320,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/top_5_accuracy_in",
             accuracy,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=True,
             logger=True,
         )
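Since the same four-argument pattern now repeats across nine `self.log` calls, a small helper could centralize it. A possible consolidation, not part of this diff (`_log_metric` is a hypothetical name):

```python
def _log_metric(self, stage: str, name: str, value, prog_bar: bool = False) -> None:
    is_train = stage == "train"
    self.log(
        f"{stage}/{name}",
        value,
        on_step=is_train,       # per-step logging while training
        on_epoch=not is_train,  # epoch-level aggregation for val/test
        prog_bar=prog_bar,
        logger=True,
    )
```

Each call site would then shrink to one line, e.g. `self._log_metric(stage, "dpo_loss", dpo_loss)` or `self._log_metric(stage, "loss", loss, prog_bar=True)`.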