|
@@ -129,7 +129,6 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
|
|
|
num_layers: int = 20,
|
|
num_layers: int = 20,
|
|
|
dilation_cycle_length: int = 4,
|
|
dilation_cycle_length: int = 4,
|
|
|
time_embedding_type: str = "positional",
|
|
time_embedding_type: str = "positional",
|
|
|
- condition_dim: Optional[int] = None,
|
|
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
|
@@ -157,7 +156,7 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
|
|
|
timestep_input_dim,
|
|
timestep_input_dim,
|
|
|
intermediate_dim,
|
|
intermediate_dim,
|
|
|
act_fn="silu",
|
|
act_fn="silu",
|
|
|
- cond_proj_dim=condition_dim,
|
|
|
|
|
|
|
+ cond_proj_dim=None, # No conditional projection for now
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# Project to intermediate dim
|
|
# Project to intermediate dim
|
|
@@ -219,9 +218,12 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
|
|
|
|
|
|
|
|
# 1. time
|
|
# 1. time
|
|
|
t_emb = self.time_proj(timestep)
|
|
t_emb = self.time_proj(timestep)
|
|
|
- t_emb = self.time_mlp(sample=t_emb[:, None], condition=condition.mT).mT
|
|
|
|
|
|
|
+ t_emb = self.time_mlp(t_emb)[..., None]
|
|
|
|
|
|
|
|
# 2. pre-process
|
|
# 2. pre-process
|
|
|
|
|
+ if condition is not None:
|
|
|
|
|
+ sample = torch.cat([sample, condition], dim=1)
|
|
|
|
|
+
|
|
|
x = self.in_proj(sample)
|
|
x = self.in_proj(sample)
|
|
|
|
|
|
|
|
if sample_mask.ndim == 2:
|
|
if sample_mask.ndim == 2:
|