class SinusoidalTimeEmbedding(nn.Module):
    """Transformer-style sin/cos encoding of the diffusion timestep.

    A single network handles every t in {0, ..., T-1}, so it needs to know
    *which* noise level it is denoising. The sin/cos embedding (used in
    transformers and the original DDPM paper) gives a smooth, continuous
    representation of t at many frequencies.

    Args:
        dim: Width of the embedding produced by ``forward``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed the timesteps ``t``.

        Args:
            t: Tensor of timesteps (any shape); it is cast to float before
                being multiplied by the frequencies.

        Returns:
            Tensor of shape ``t.shape + (dim,)``: ``dim // 2`` sine features
            followed by ``dim // 2`` cosine features. When ``dim`` is odd a
            single trailing zero column is appended so the width is exactly
            ``dim``.
        """
        half = self.dim // 2
        # Plain Python float keeps the product below a pure tensor op instead
        # of relying on NumPy-scalar / Tensor operator dispatch.
        log_base = float(np.log(10000))
        # Geometric series of frequencies spanning short to long timescales:
        # freqs[k] = 10000 ** (-k / half), for k = 0 .. half - 1.
        freqs = torch.exp(-log_base * torch.arange(half, device=t.device) / half)
        angles = t.float().unsqueeze(-1) * freqs
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover 2 * (dim // 2) columns; pad the last
            # one so downstream layers sized with `dim` receive `dim` features.
            emb = nn.functional.pad(emb, (0, 1))
        return emb
class NoisePredictor(nn.Module):
    """Noise-prediction network eps_theta(x_t, t).

    Given a noisy 2-D point x_t and its timestep t, the network returns an
    estimate eps_hat of the noise that produced x_t. The output is NOT the
    reverse-step mean mu_theta: the mean is derived analytically from
    eps_hat outside this class, in the sampling loop, via the Ho et al.
    reparameterisation. Equivalent parameterisations exist (predict mu
    directly, or predict x_0), but eps-prediction is better-conditioned
    because eps ~ N(0, I) at every timestep.

    A small MLP suffices for 2-D data; for images this becomes a U-Net.

    Args:
        time_dim: Width of the timestep embedding.
        hidden: Width of each hidden layer of the main MLP.
    """

    def __init__(self, time_dim=64, hidden=128):
        super().__init__()

        # Timestep branch: raw sin/cos features refined by a tiny MLP
        # (standard DDPM trick).
        time_layers = [
            SinusoidalTimeEmbedding(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        ]
        self.time_emb = nn.Sequential(*time_layers)

        # Main branch: (x_t, time-embedding) -> three hidden Linear+SiLU
        # stages -> linear head. Layers are created in the same order as
        # they appear in the stack.
        widths = [2 + time_dim, hidden, hidden, hidden]
        stack = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(n_in, n_out))
            stack.append(nn.SiLU())
        # Head: eps_hat in R^2 (same shape as x_t); the reverse mean is
        # derived from it, not predicted here.
        stack.append(nn.Linear(hidden, 2))
        self.net = nn.Sequential(*stack)

    def forward(self, x, t):
        """Return eps_hat for noisy points ``x`` at timesteps ``t``.

        The timestep embedding is joined to ``x`` along the last dimension
        before the result is passed through the MLP.
        """
        t_features = self.time_emb(t)
        joined = torch.cat([x, t_features], dim=-1)
        return self.net(joined)
# Build the noise predictor with its default sizes, then move its parameters
# to `device` (defined elsewhere in the file).
model = NoisePredictor()
model = model.to(device)