Unconditional Diffusion With a UNET

Trying to rewrite the whole guided latent diffusion thing is hard, so instead here are some smushy images of butterflies.
machine learning
code
Author

Daniel Claborne

Published

May 19, 2025

A while ago this whole guided diffusion thing made waves in the ML world and provided us with amazing images of anthropomorphized cucumbers walking chihuahuas. Of course, I tried my best to understand this amazing technology as well as play around with various open-source implementations.

My post on vector-quantized VAEs was the first part of trying to implement parts of these Rube Goldberg machines. Unfortunately, life and laziness happened, and the remaining components are a bit more difficult to implement imo. I decided to try to recreate a core component of a popular form of guided diffusion model, the UNET (Ronneberger, Fischer, and Brox 2015), based on the implementation in the codebase for stable diffusion (Rombach et al. 2022).

In the end, it turns out that following some code pretty closely while trying to also ‘code it on your own’ is still kinda frustrating. The reasons they code things up a certain way are not always clear, and I ended up breaking things and had a rough time debugging various discrepancies and unintended memory issues with my broken implementations. Additionally, the UNET in stable diffusion processes the images after they’ve been transformed into a latent space by a VAE… My initial attempts at putting these two pieces together produced some… uninspiring results, as seen in Figure 1.

Figure 1: Look closely, you can see some butterflies in there

Anyway, this post is gonna be about my UNET implementation, and there will be some images of fake butterflies at the end! As always, I’m going to scold you for reading some random dude’s blog instead of reading the actual papers: Ho, Jain, and Abbeel (2020) and Rombach et al. (2022).

If you wanna skip to the code portion of this post, here it is, or you can see the associated colab notebook.

A Brief Primer on Diffusion Models

Diffusion models are based around the random process of… diffusion… in which an input has noise slowly added to it until it is pure noise. There’s some physics background on this, but sorry physicists, the AI people are in charge of all science now and have even taken your Nobel Prizes. I present an incomplete description of the process here, but really just read Ho, Jain, and Abbeel (2020).

We consider a series of inputs \(x_t\), \(t \in \{0,1,...,T\}\). They describe the model as \(p_{\theta}(x_0) := \int p_{\theta}(x_{0:T})dx_{1:T}\). They begin by defining the ‘reverse process’, that is, the process that takes noise and maps it back to the input. It is defined as a Markov chain with Gaussian transitions: \[ \begin{align} p_{\theta}(x_{0:T}):= p(x_T) \prod_{t=1}^T p_{\theta}(x_{t-1}\vert x_t) \quad \quad p_{\theta}(x_{t-1}\vert x_t) := \mathcal{N}(x_{t-1};\mu_{\theta}(x_t,t), \Sigma_{\theta}(x_t, t)) \end{align} \label{eq1}\tag{eqn-1} \]

i.e. we have some image \(x_t\), and we ‘transition’ to a different (hopefully denoised) image \(x_{t-1}\) by drawing from a normal distribution whose mean and covariance depend on \(x_t\) and the timestep \(t\).

The other part of the model is the forward process, which adds noise based on a series of coefficients \(\beta_1, \beta_2, ..., \beta_T \in (0,1)\):

\[ q(x_{1:T} \vert x_0):= \prod_{t=1}^T q(x_t \vert x_{t-1})\quad \quad q(x_t \vert x_{t-1}):= \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t \mathbf{I}) \label{eq2}\tag{eqn-2} \]

That is, the forward process takes an input (say, an image) \(x_0\) and gradually transitions to a noise distribution centered at 0. If you look at the equation, it should make sense why this is: Each step has a mean that is closer to zero than the mean before it, and after sufficient steps, this is gonna be around 0 with high probability.

From here you can optimize the evidence lower bound as described in the VAE literature, except with \(x_0\) as the evidence and \(x_{1:T}\) as the latent representation \(z\). See my other post about VAEs.

One particular property of the forward process is that we can sample an \(x_t\) at an arbitrary timestep in closed form:

\[q(x_t \vert x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)\mathbf{I}) \quad \quad \alpha_t = 1 - \beta_t \quad \quad \bar{\alpha}_t= \prod_{s=1}^t \alpha_s \label{eq3}\tag{eqn-3}\]

It is a nice exercise to try and show that this is the case… no, I’m not going to show it, I did it once and forgot, okay? Look at (\(\ref{eq2}\)) and cascade that second equality all the way back to \(x_0\) to convince yourself.
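
In code, (\(\ref{eq3}\)) is just an indexing operation and a reparameterized normal draw. Here’s a minimal sketch; the names `betas`, `alphas_cumprod`, and `q_sample` are mine, not from any particular library, and the linear \(\beta_t\) schedule is the one used in Ho, Jain, and Abbeel (2020):

Code
import torch

# linear beta schedule from the DDPM paper
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # \bar{\alpha}_t

def q_sample(x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """Draw x_t ~ q(x_t | x_0) per eqn-3 for a batch of images x0 and integer timesteps t."""
    abar = alphas_cumprod.to(x0.device)[t].view(-1, 1, 1, 1)  # broadcast over C, H, W
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise

x0 = torch.randn(8, 3, 64, 64)    # pretend batch of images
t = torch.randint(0, T, (8,))     # random timesteps
x_t = q_sample(x0, t, torch.randn_like(x0))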

The objective reduces to something where we randomly sample elements of the forward process and optimize the evidence lower bound by optimizing the parameters of the model, which are currently \(\beta_t\) and \(\theta\). In Ho, Jain, and Abbeel (2020) they hold \(\beta_t\) fixed, as well as reparametrize the elements of the reverse process \(p_{\theta}(x_{0:T})\), which we noticed were defined by two parameterized functions \(\mu_{\theta}(x_t,t)\) and \(\Sigma_{\theta}(x_t, t)\). Well, first they actually just fix the variance to \(\Sigma_{\theta}(x_t, t) = \sigma_t^2\mathbf{I}\), where they experiment with a couple of schemes for the values of \(\sigma_t^2\) based on \(\beta_t\), both of which seemed to work well.

They eventually get to the parametrization:

\[\mu_{\theta}(x_t, t) = \tilde{\mu}_t\left(x_t, \frac{1}{\sqrt{\bar{\alpha}_t}}(x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon_{\theta}(x_t))\right) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_{\theta}(x_t, t)\right)\]

Forget what \(\tilde{\mu}\) is, or read the paper for the details … the tldr of this whole thing is that what we need to optimize in the end is this \(\epsilon_{\theta}(x_t, t)\) function which predicts the noise component of \(x_t\). That is, the training simply consists of training a model to predict the noise component given an input \(x_t\) and timestep \(t\).

The math works out that to get the ‘previous sample’ \(x_{t-1} \sim p_{\theta}(x_{t-1} \vert x_t)\) we can simply compute: \[x_{t-1} = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_{\theta}(x_t, t)) + \sigma_t z \quad \quad z \sim \mathcal{N}(0, \mathbf{I})\label{eq4}\tag{eqn-4}\]

Mmmkay so this is the framework for training and sampling from an unconditional diffusion model. We will simply do the following:

  1. Represent \(\epsilon_{\theta}(x_t, t)\) as a UNET that takes in a noisy image (\(x_t\)) and \(t\) as a timestep embedding.
  2. Draw some \(\mathcal{N}(0, \mathbf{I})\) noise, some random timesteps, and produce noisy images according to (\(\ref{eq3}\))
  3. Make ‘time embeddings’ from the integer timesteps and shove the time embeddings and noisy image through the UNET.
  4. Compute the MSE loss between the output and the noise, and update (a minimal sketch of one such training step follows this list).
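
Steps 1–4 in code are only a few lines. Here’s a rough sketch of a single training step, reusing the `q_sample` helper above and assuming `model` is something implementing \(\epsilon_{\theta}(x_t, t)\) (the UNET built later in this post) and `optimizer` is any torch optimizer; the real loop is in the Training section further down:

Code
def train_step(model, optimizer, x0):
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)  # step 2: random timesteps
    noise = torch.randn_like(x0)                               # step 2: N(0, I) noise
    x_t = q_sample(x0, t, noise)                               # step 2: noisy images per eqn-3
    pred_noise = model(x_t, timesteps=t)                       # steps 1 and 3: UNET predicts the noise
    loss = torch.nn.functional.mse_loss(pred_noise, noise)     # step 4: MSE against the true noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()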

Implementing the UNET

As for what I actually coded up for this, it was the UNET, the model which predicts the noise component from \(x_t\). I wanted to just…practice model building, so I built the UNET based on the OpenAI implementation in the codebase for stable diffusion. “So you copied it?”. No! I used it as a guide for developing my code. “So you copied it”. Sigh, fine yea pretty much, but even so it was a pain in the ass to actually get it to line up perfectly with their implementation:

  • My first attempt blew up the A100s in Colab; I still have no idea why.
  • I had to meticulously go through each layer and figure out where discrepancies in layer/block sizes were. The stable diffusion code is a bit opaque…

Anyway, this post uses the code listed here, and I’ll go through some of it in this post; the notebook that uses the code in that repo is here.

The things to implement for the UNET:

  1. The time embeddings
  2. ResNet blocks that can incorporate a time embedding.
  3. Attention blocks (attention + feed-forward)
  4. Spatial transformers (basically flattens the spatial dimensions into a sequence dimension and attends over it)

Technically for the unconditional diffusion I don’t need 3. and 4., but I’ll include them anyway. Okay, first the time embedding; this is almost literally copied from the stable diffusion repo. Let’s take a look:

Code
def time_embeddings(t: torch.tensor, out_dim: int, max_period: int = 10000):
    """Create timestep embeddings from a vector of ints representing timesteps
    Args:
        t (torch.Tensor): Tensor of shape (B,) containing timesteps
        out_dim (int): Dimension of the output embeddings
        max_period (int): Maximum period for the sine and cosine functions

    Returns:
        torch.Tensor: Tensor of shape (B, out_dim) containing the timestep embeddings
    """
    half = out_dim // 2
    denom = torch.exp(
        -torch.tensor(max_period).log() * torch.arange(0, half, dtype=torch.float32) / half
    )

    denom = denom.to(t.device)

    phases = t[:, None] * denom[None]

    # concatenate to form a tensor of shape B x out_dim
    out_emb = torch.cat([phases.sin(), phases.cos()], dim=-1)

    # if out_dim is odd, add a zero column
    if out_dim % 2:
        out_emb = torch.cat([out_emb, torch.zeros_like(out_emb[:, :1])], dim=-1)

    return out_emb

This is essentially creating these \(d\)-dimensional sine-cosine embeddings according to

\[ \begin{align} &P(t,j) = \sin\left(\frac{t}{10,000^{2j/d}}\right), \quad j \in \{0,1,...,d/2-1\} \\ &P(t,j + d/2) = \cos\left(\frac{t}{10,000^{2j/d}}\right), \quad j \in \{0,1,...,d/2-1\} \end{align} \]

Yes, it really is; look at the form of what’s inside the function and do some algebra, it is the same. Am I going to show you? No.
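
A quick sanity check on the output (the values here are arbitrary, just to show the shape and range):

Code
t = torch.arange(8)                        # eight integer timesteps
emb = time_embeddings(t, out_dim=128)
print(emb.shape)                           # torch.Size([8, 128]) -- one 128-dim embedding per timestep
print(emb.min().item(), emb.max().item())  # everything is a sine or cosine, so it lives in [-1, 1]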

Okay, we have these \(d\)-dimensional embeddings that will get injected into each of the ResNet blocks. We’ll see how that happens below. The ResNet blocks consist of the following steps:

  1. Group normalization
  2. SiLU activation
  3. Optional resampling (either up-sampling or down-sampling), which doubles or halves the spatial dimensions of the image.
  4. Convolution layer
  5. Addition of time embeddings
  6. Output block consisting of groupnorm -> SiLU -> dropout -> convolution

These layers are used to change the spatial dimensions of the image in 3. and to increase (or decrease) the channel dimension in 4. The time embeddings work by first projecting the \(d\)-dimensional embedding to the channel dimension via a linear layer and then element-wise adding this new embedding at every spatial location.

Code
class ResBlock(nn.Module):
    def __init__(self, channels_in, d_emb, channels_out=None, resample=None, dropout=0.0):
        super().__init__()

        self.channels_in = channels_in
        self.channels_out = channels_out or channels_in

        self.group_norm_in = nn.GroupNorm(32, channels_in)

        self.conv_in = conv_nd(2, channels_in, self.channels_out, 3, padding=1)
        self.conv_out = conv_nd(2, self.channels_out, self.channels_out, 3, padding=1)

        if resample == "down":
            self.resample_h = DownSample(channels_in, dims=2, use_conv=False)
            self.resample_x = DownSample(channels_in, dims=2, use_conv=False)
        elif resample == "up":
            self.resample_h = Upsample(channels_in, dims=2, use_conv=False)
            self.resample_x = Upsample(channels_in, dims=2, use_conv=False)
        else:
            self.resample_x = self.resample_h = nn.Identity()

        if self.channels_out == channels_in:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = conv_nd(2, channels_in, self.channels_out, 1)

        self.time_emb = nn.Sequential(nn.SiLU(), nn.Linear(d_emb, self.channels_out))

        self.out_block = nn.Sequential(
            nn.GroupNorm(32, self.channels_out), nn.SiLU(), nn.Dropout(dropout), self.conv_out
        )

    def forward(self, x: torch.tensor, emb: torch.tensor):
        """
        Args:
            x (torch.tensor): Input tensor of shape (B, C, H, W)
            emb (torch.tensor): Time embedding of dimension (B, D)

        Returns:
            torch.tensor: Output tensor of shape (B, C*, H*, W*)
        """

        h = self.group_norm_in(x)
        h = F.silu(h)
        h = self.resample_h(h)
        h = self.conv_in(h)

        emb = self.time_emb(emb)
        h = h + emb[:, :, None, None]  # expand spatial dims, add along channel dim

        h = self.out_block(h)

        x = self.resample_x(x)

        return h + self.skip_connection(x)

In the above, the DownSample class is implemented either as a convolution with stride 2 or as a simple average pooling operation that cuts the spatial dimensions in half, and Upsample goes the other direction (interpolation, optionally followed by a convolution). In the forward implementation, you can see I’m adding the time embedding to the channel dimension via broadcasting over all spatial dimensions.
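
Those helpers (`conv_nd`, `DownSample`, `Upsample`) live in the linked repo rather than in this post. Here’s a minimal sketch of roughly what they do, mirroring the stable diffusion versions (the details in my repo may differ slightly):

Code
import torch.nn as nn
import torch.nn.functional as F

def conv_nd(dims, *args, **kwargs):
    # the repo version dispatches on `dims`; only the 2d case is needed here
    assert dims == 2
    return nn.Conv2d(*args, **kwargs)

class DownSample(nn.Module):
    """Halve the spatial dimensions, either with a stride-2 conv or average pooling."""
    def __init__(self, channels, dims=2, use_conv=False):
        super().__init__()
        if use_conv:
            self.op = conv_nd(dims, channels, channels, 3, stride=2, padding=1)
        else:
            self.op = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        return self.op(x)

class Upsample(nn.Module):
    """Double the spatial dimensions via nearest-neighbor interpolation, optionally followed by a conv."""
    def __init__(self, channels, dims=2, use_conv=False):
        super().__init__()
        self.conv = conv_nd(dims, channels, channels, 3, padding=1) if use_conv else nn.Identity()

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x)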

The other main layer type in the UNET is the spatial transformer, which, as I said before, is not really necessary in this unguided (no text embeddings) version of the diffusion model. There are, uh, a few pieces to this. There’s a standard cross-attention layer, which is meant to take in context. In this case, we are not using context embeddings, so we fix the query dimension to be the same as the key and value dimensions and simply feed the same input to the Q, K, and V projection layers… see below:

Cross-attention layers
class FeedForward(nn.Sequential):
    """Feed forward module for the attention block.  Has two  linear layers with a GeLU and dropout layer in between.

    Args:
        d_in (int): Input dimension to the first linear layer.
        d_out (int): Output dimension of the second linear layer.
        mult (int): Multiplier (of the input dimension) for the hidden dimension. Default: 4
        dropout (float): Dropout rate. Default: 0.1
    """

    def __init__(self, d_in, d_out, mult=4, dropout=0.1):
        super().__init__()

        self.proj_in = nn.Linear(d_in, int(d_in * mult))
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.proj_out = nn.Linear(int(d_in * mult), d_out)


# Cross attention module, from scratch
class CrossAttention(nn.Module):
    """Basic cross attention module.

    Args:
        d_q (int): Input dimension of the query.
        d_model (int): Inner dimension of the QKV projection layers. Default: 512
        d_cross (int): Input dimension of the key and value inputs, for cross attention. Default: None
        n_heads (int): Number of attention heads. Default: 8
        dropout (float): Dropout rate. Default: 0.0
    """

    def __init__(self, d_q, d_model=512, d_cross=None, n_heads=8, dropout=0.0):
        super().__init__()

        assert d_model % n_heads == 0, f"n_heads {n_heads} must divide d_model {d_model}"

        if d_cross is None:
            d_cross = d_q

        self.proj_q = nn.Linear(d_q, d_model, bias=False)
        self.proj_k = nn.Linear(d_cross, d_model, bias=False)
        self.proj_v = nn.Linear(d_cross, d_model, bias=False)

        self.proj_out = nn.Linear(d_model, d_q)
        self.dropout = nn.Dropout(dropout)
        self.n_heads = n_heads

    def forward(self, x, context=None, mask=None):
        # prevent einops broadcasting
        if context is not None:
            assert (
                x.shape[0] == context.shape[0]
            ), f"Batch size of x and context must match, found {x.shape[0]} and {context.shape[0]}"

        if context is None:
            context = x

        q = self.proj_q(x)
        k = self.proj_k(context)
        v = self.proj_v(context)

        # at this point we've already flattened the h/w of the input
        q = einops.rearrange(q, "b n (h d) -> b h n d", h=self.n_heads)
        k = einops.rearrange(k, "b m (h d) -> b h m d", h=self.n_heads)
        v = einops.rearrange(v, "b m (h d) -> b h m d", h=self.n_heads)

        qk = einops.einsum(q, k, "b h n d, b h m d -> b h n m") / (q.shape[-1] ** 0.5)

        if mask is not None:
            # mask initially of shape b x m, need to expand to b x h x 1 x m
            mask = einops.repeat(mask, "b m -> b h () m", h=self.n_heads)
            min_value = -torch.finfo(qk.dtype).max
            qk.masked_fill_(~mask, min_value)

        qk = F.softmax(qk, dim=-1)
        out = einops.einsum(qk, v, "b h n m, b h m d -> b h n d")
        out = einops.rearrange(out, "b h n d -> b n (h d)")

        out = self.dropout(self.proj_out(out))

        return out
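
One other piece not shown above is the AttentionBlock that the spatial transformer stacks: self-attention, cross-attention over optional context, and the FeedForward module, each with a pre-norm and a residual connection. The real class is in the repo; here’s a minimal sketch under the assumption that it mirrors stable diffusion’s BasicTransformerBlock:

Code
class AttentionBlock(nn.Module):
    """Transformer block used inside the spatial transformer: self-attention,
    cross-attention (context optional), and a feed-forward layer, each preceded
    by a LayerNorm and wrapped in a residual connection."""

    def __init__(self, d_q, d_cross=None, d_model=512, n_heads=8, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_q)
        self.attn1 = CrossAttention(d_q, d_model=d_model, n_heads=n_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(d_q)
        self.attn2 = CrossAttention(d_q, d_model=d_model, d_cross=d_cross, n_heads=n_heads, dropout=dropout)
        self.norm3 = nn.LayerNorm(d_q)
        self.ff = FeedForward(d_q, d_q, dropout=dropout)

    def forward(self, x, context=None):
        x = x + self.attn1(self.norm1(x))                    # self-attention
        x = x + self.attn2(self.norm2(x), context=context)   # cross-attention; context=None -> self-attention
        x = x + self.ff(self.norm3(x))
        return x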

The spatial transformer is essentially a wrapper around some number of these attention blocks. In the forward pass, we flatten the spatial dimensions into a single sequence dimension and attend over it:

Code
class SpatialTransformer(nn.Module):
    """Spatial transformer module for the UNet architecture.  Contains cross-attention layers that attend over the spatial dimensions of an image, while ingesting cross-attention embeddings from e.g. a text embedding model.

    Args:
        in_channels (int): Number of input channels.
        d_q (int): Input dimension of the query.
        d_cross (int): Input dimension of the key and value inputs, for cross attention. Default: None
        d_model (int): Inner dimension of the QKV projection layers. Default: 512
        n_heads (int): Number of attention heads. Default: 8
        dropout (float): Dropout rate. Default: 0.0
        depth (int): Number of attention blocks. Default: 1
    """

    def __init__(
        self, in_channels, d_q, d_cross=None, d_model=512, n_heads=8, dropout=0.0, depth=1
    ):
        super().__init__()

        self.in_channels = in_channels
        self.d_q = d_q
        self.d_cross = d_cross
        self.d_model = d_model
        self.n_heads = n_heads

        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )

        self.conv_in = nn.Conv2d(self.in_channels, d_q, kernel_size=1, stride=1, padding=0)

        self.blocks = nn.ModuleList(
            [AttentionBlock(d_q, d_cross, d_model, n_heads, dropout) for _ in range(depth)]
        )

        self.conv_out = nn.Conv2d(d_q, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, context=None):
        x_in = x

        x = self.norm(x)
        x = self.conv_in(x)  # B, d_q, H, W

        b, d_q, h, w = x.shape

        x = einops.rearrange(x, "b c h w -> b (h w) c")  # attention mechanism expects B T C

        for block in self.blocks:
            x = block(x, context)

        x = einops.rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.conv_out(x)
        return x + x_in

MMkay, the definition of the full UNET model essentially builds these blocks into the UNET in a loop:

Code
class EmbeddingWrapper(nn.Sequential):
    """
    Wrapper for a sequence of layers that take an input tensor and optionally either a tensor of time embeddings or context embeddings.
    """

    def forward(self, x: torch.tensor, t_emb: torch.tensor = None, context: torch.tensor = None):
        """
        Args:
            x (torch.tensor): main input image tensor B x C x H x W
            t_emb (torch.tensor): time embedding B x t_emb_dim
            context (torch.tensor): context tensor B x d_cross to be passed as context to cross-attention mechanism.

        Returns:
            torch.tensor: output tensor
        """
        for layer in self:
            if isinstance(layer, ResBlock):
                x = layer(x, t_emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)

        return x


class UNET(nn.Module):
    """A simpler implementation of the UNET at https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/openaimodel.py
    Here I force the use of spatial attention when adding the guidance layers.

    Args:
        channels_in (int): number of input channels
        channels_model (int): number of initial channels which is then multiplied by the values in `channel_mults`
        channels_out (int): number of output channels
        context_dim (int): context dimension when performing guided diffusion
        d_model (int): embedding dimension of the attention layers
        t_emb_dim (int): time embedding dimension
        channel_mults (list): list of channel multipliers which will determine the number of channels at each block depending on `channels_model`
        attention_resolutions (list): list of attention resolutions where attention is applied
        dropout (float): dropout rate
    """

    def __init__(
        self,
        channels_in: int,
        channels_model: int,
        channels_out: int,
        context_dim: int,
        d_model: int,
        t_emb_dim: int = None,
        channel_mults: list[int] = [1, 2, 4, 8],
        attention_resolutions: list[int] = [2, 4],
        dropout: float = 0.0,
    ):
        super().__init__()

        self.channel_mults = channel_mults
        self.channels_in = channels_in
        self.channels_model = channels_model
        self.context_dim = context_dim
        self.t_emb_dim = t_emb_dim or channels_model

        # will fill up the downsampling and upsampling trunks in for loops
        self.down_blocks = nn.ModuleList(
            [EmbeddingWrapper(nn.Conv2d(channels_in, channels_model, kernel_size=3, padding=1))]
        )
        self.up_blocks = nn.ModuleList()

        track_chans = [channels_model]
        ch_in = channels_model

        # create each downsampling trunk
        for i, mult in enumerate(channel_mults):
            # assume depth 1 for each downsampling level:
            # append the ResBlock first, then (optionally) a spatial transformer
            ch_out = channels_model * mult

            resblock = ResBlock(channels_in=ch_in, d_emb=self.t_emb_dim, channels_out=ch_out)
            layers = [resblock]

            if i in attention_resolutions:
                sp_transformer = SpatialTransformer(
                    in_channels=ch_out,
                    d_q=ch_out,
                    d_cross=context_dim if context_dim else ch_out,
                    d_model=d_model,
                    dropout=dropout,
                    n_heads=2,
                )
                layers.append(sp_transformer)

            self.down_blocks.append(EmbeddingWrapper(*layers))

            track_chans.append(ch_out)

            # downsample after every mult except the last
            if i != len(channel_mults) - 1:
                res_ds = ResBlock(
                    ch_out, d_emb=self.t_emb_dim, channels_out=ch_out, resample="down"
                )
                self.down_blocks.append(EmbeddingWrapper(res_ds))
                track_chans.append(ch_out)
                ch_in = ch_out

        # middle block, this is Res, Attention, Res
        # ch_out is the last channel dimension for constructing the downsampling layers
        self.middle_block = EmbeddingWrapper(
            ResBlock(channels_in=ch_out, d_emb=self.t_emb_dim),
            SpatialTransformer(
                in_channels=ch_out,
                d_q=ch_out,
                d_cross=context_dim if context_dim else ch_out,
                d_model=d_model,
                dropout=dropout,
            ),
            ResBlock(channels_in=ch_out, d_emb=self.t_emb_dim),
        )

        # upsampling block
        # this block has 2x the channels, why?  Because of the UNET architecture, we concatenate the channels from the corresponding layer of the downsample section.
        # There is also an additional res + attention block, this additional block 'matches' the channel dimension of the downsampling module in the downsampling trunk.
        for i, mult in reversed(list(enumerate(channel_mults))):
            # We assume there's two resblocks here for simplicity

            # first res block
            down_ch = track_chans.pop()

            # We have two of these, one that matches the Res + Transformer block, and another that matches the downsampling block

            # first res block
            res1 = ResBlock(
                ch_out + down_ch, d_emb=self.t_emb_dim, channels_out=channels_model * mult
            )

            # this block will output this many channels, we set it here since we want the next iteration to start the channels of the previous block.
            ch_out = channels_model * mult

            layers = [res1]

            if i in attention_resolutions:
                sp_trf = SpatialTransformer(
                    in_channels=channels_model * mult,
                    d_q=channels_model * mult,
                    d_cross=context_dim if context_dim else channels_model * mult,
                    d_model=d_model,
                    dropout=dropout,
                )
                layers.append(sp_trf)

            self.up_blocks.append(EmbeddingWrapper(*layers))

            down_ch = track_chans.pop()

            # and again, same dimension ...
            layers = []

            layers.append(ResBlock(ch_out + down_ch, d_emb=self.t_emb_dim, channels_out=ch_out))

            if i in attention_resolutions:
                layers.append(
                    SpatialTransformer(
                        in_channels=ch_out,
                        d_q=ch_out,
                        d_cross=context_dim if context_dim else ch_out,
                        d_model=d_model,
                        dropout=dropout,
                    )
                )

            # ... with an upsampling layer at all but the last, since at the last we are matching the initial convolutional layer and a res + spatial transformer block, not an upsampling layer
            if i > 0:
                layers.append(
                    ResBlock(ch_out, d_emb=self.t_emb_dim, channels_out=ch_out, resample="up")
                )

            self.up_blocks.append(EmbeddingWrapper(*layers))

        # output block that normalizes and maps back to the output channel dimension
        self.out_block = nn.Sequential(
            nn.GroupNorm(32, channels_model),
            nn.SiLU(),
            nn.Conv2d(channels_model, channels_out, kernel_size=3, padding=1),
        )

    def forward(self, x, timesteps=None, context=None):
        """
        Args:
            x (torch.tensor): input tensor
            timesteps (torch.tensor): Size (B,) tensor containing time indices to be turned into embeddings.
            context (torch.tensor): context tensor
        """
        if context is None:
            assert self.context_dim is None, "Must pass context if context_dimension is set"
        else:
            assert (
                self.context_dim is not None
            ), "You must set context_dim when creating the model if planning on passing context embeddings."

        if timesteps is not None:
            timesteps = time_embeddings(timesteps, self.t_emb_dim)

        # downsample
        downsampled = []
        for block in self.down_blocks:
            x = block(x, timesteps, context)
            downsampled.append(x)

        # middle block
        x = self.middle_block(x, timesteps, context)

        # upsample
        for block in self.up_blocks:
            x = torch.cat([x, downsampled.pop()], dim=1)
            x = block(x, timesteps, context)

        x = self.out_block(x)
        return x

Bleh, what a mess. You can look through the loop, but essentially we identify the number of input and output channels, as well as a ‘model channels’ dimension. This is the channel dimension which is multiplied by the i-th element of channel_mults at the i-th level of the UNET when downsampling, and in reverse (x8, x4, x2, x1) when upsampling. We additionally choose the levels at which we apply attention (and, for guided diffusion, cross-attention over context embeddings) with attention_resolutions, but again, we ignore context here by setting context_dim to None.
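
To make that a bit more concrete, here’s a rough usage sketch with made-up hyperparameters (not the ones in the notebook), just to show what goes in and what comes out:

Code
model = UNET(
    channels_in=3,
    channels_model=64,        # illustrative; see the notebook for the real settings
    channels_out=3,
    context_dim=None,         # unconditional: no text/context embeddings
    d_model=256,
    channel_mults=[1, 2, 4, 8],
    attention_resolutions=[2],
)

x = torch.randn(2, 3, 128, 128)    # a (tiny) batch of noisy images
t = torch.randint(0, 1000, (2,))   # integer timesteps
eps_hat = model(x, timesteps=t)    # predicted noise, same shape as x: (2, 3, 128, 128)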

Training

Just see the notebook if you wanna run something, but I’ll go over the inner part of training as well.

Code
for step, batch in enumerate(train_dataloader):
    with torch.autocast("cuda"):
        input = batch['images'].to(device)

        timepoints = torch.LongTensor(input.shape[0]).random_(0, dconfig.n_timesteps).to(device) # 1

        noise = torch.randn_like(input).to(device) # 2

        noisy_input = noise_scheduler.add_noise(input, noise, timepoints) # 3

        with accelerator.accumulate(model):
            out = model(noisy_input, timesteps = timepoints) # 4

            # compute MSE between this and the NOISE
            loss = torch.nn.functional.mse_loss(out, noise) # 5

            optimizer.zero_grad()
            accelerator.backward(loss)

            accelerator.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            lr_scheduler.step()

            epoch_loss.append(loss.item())

Ignore the stuff that isn’t defined (a rough sketch of that setup follows the list); we are doing the following:

  1. We create a random set of integer timepoints between 0 and a user-selected maximum number of timepoints, here 1000.
  2. We sample some \(\mathcal{N}(0, \mathbf{I})\) noise.
  3. We use \(\ref{eq3}\) to produce noisy samples from the forward process at random timepoints. The diffusers package has implemented this computation.
  4. Use our UNET to predict the noise from the noisy inputs, also passing the timepoints, which internally get embedded as we discussed.
  5. We compute the MSE loss between the predicted noise and the \(\mathcal{N}(0, \mathbf{I})\) noise and perform model updates.
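
The undefined objects in that loop come from the notebook; the setup is roughly along these lines (the specific values here are placeholders, and `dconfig.n_timesteps` is just the 1000 mentioned above):

Code
from accelerate import Accelerator
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup

n_timesteps = 1000   # standing in for dconfig.n_timesteps
n_epochs = 50        # placeholder

# implements the forward noising of eqn-3 via .add_noise(...) and the reverse step of eqn-4 via .step(...)
noise_scheduler = DDPMScheduler(num_train_timesteps=n_timesteps)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=500, num_training_steps=len(train_dataloader) * n_epochs
)

accelerator = Accelerator(gradient_accumulation_steps=1)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

epoch_loss = []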

Sampling

To do this, we simply implement \(\ref{eq4}\), using implementations in the diffusers package:

Code
# Do the diffusion process, start with noise...
cur_img = torch.randn(16, 3, 128, 128, dtype=torch.float32).to(device)

pbar = tqdm.tqdm(noise_scheduler.timesteps)

for t in pbar:
  with torch.no_grad():
      timesteps = torch.LongTensor([t]*cur_img.shape[0]).to(device)
      # predict the noise component
      noisy_residual = model(cur_img, timesteps)

  # reverse the diffusion process using the predicted noise
  previous_noisy_sample = noise_scheduler.step(noisy_residual, t, cur_img).prev_sample
  cur_img = previous_noisy_sample

Specifically, this noise_scheduler object is implementing (\(\ref{eq4}\)), and you can see it ingests everything needed in the equation. I’m assuming \(\sigma_t\) is set to something internally, perhaps \(\sigma_t^2 = \beta_t\) as suggested in the paper, though there are different schedulers that probably implement different schemes. It would be a nice exercise to dig in and verify what’s going on in there…
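
For reference, here’s roughly what a single reverse step of (\(\ref{eq4}\)) looks like written out by hand, reusing the `betas`/`alphas_cumprod` from the earlier sketch and taking \(\sigma_t^2 = \beta_t\). I haven’t checked that this matches `DDPMScheduler.step` exactly, so treat it as a sketch:

Code
import math

def reverse_step(model, x_t, t: int):
    """One draw x_{t-1} ~ p_theta(x_{t-1} | x_t) per eqn-4, with sigma_t^2 = beta_t."""
    beta_t = betas[t].item()
    abar_t = alphas_cumprod[t].item()

    with torch.no_grad():
        ts = torch.full((x_t.shape[0],), t, dtype=torch.long, device=x_t.device)
        eps = model(x_t, timesteps=ts)  # predicted noise component

    mean = (x_t - beta_t / math.sqrt(1.0 - abar_t) * eps) / math.sqrt(1.0 - beta_t)
    if t == 0:
        return mean                     # no noise added at the final step
    return mean + math.sqrt(beta_t) * torch.randn_like(x_t)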

Anyway, after training and producing samples on the Smithsonian Butterflies dataset, we can produce some weird butterflies from pure noise:

Figure 2: Honestly, kind of an interesting watercolor look to em.

Next Up…

Okay so now I want to actually try this with some text guidance, and with the VAE at the start. The idea is essentially, instead of an image, we’ll ingest the lower-dimensional representation at the bottleneck of a pretrained VAE, and actually add some context embeddings via some pretrained text embedding model like CLIP (Radford et al. (2021)). Also I need a dataset that actually has image-text pairs…sigh.

References

Ho, Jonathan, Ajay Jain, and Pieter Abbeel. 2020. “Denoising Diffusion Probabilistic Models.” arXiv. https://doi.org/10.48550/arXiv.2006.11239.
Radford, Alec, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, et al. 2021. “Learning Transferable Visual Models From Natural Language Supervision.” arXiv. https://doi.org/10.48550/arXiv.2103.00020.
Rombach, Robin, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. 2022. “High-Resolution Image Synthesis with Latent Diffusion Models.” arXiv. https://doi.org/10.48550/arXiv.2112.10752.
Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. 2015. “U-Net: Convolutional Networks for Biomedical Image Segmentation.” arXiv. https://doi.org/10.48550/arXiv.1505.04597.