Unica2804

Posted on • Originally published at Medium
Tiny Diffusion

Have you ever wondered how diffusion models work? I wondered about it for a long time too. It feels like magic: you type a prompt and, within a few seconds to minutes, a picture or video appears.
Introduction

In thermodynamics, diffusion is the process by which particles spread from regions of high concentration to regions of low concentration. The diffusion models we use follow a similar idea. Imagine dropping ink into a glass of water. At first a distinct shape forms, then the ink slowly diffuses into the water until it turns a uniform blue. Our model does something similar: we turn a perfectly good image into pure noise by adding a little bit of noise at each step. This process is called Forward Diffusion.

Forward process

Now, back to the glass of water. What if we filmed the whole process of the ink diffusing and played it in reverse? We would see the blue water un-mix and reform the initial drop of ink. Likewise, we train our model to predict the noise that was added at each step by showing it the noising process thousands of times; subtracting that predicted noise step by step recovers the image. This process is called Reverse Diffusion.

These two processes are the core ideas behind diffusion, though there are more moving parts: a scheduler that decides how much noise is added to the image over time, and a sampling algorithm that helps us reverse the noise. Also, nowadays people have largely moved from U-Net-based diffusion models to Diffusion Transformers (DiT).

Moving From Images to Motion

This is the interesting part. A 2D diffusion model only needs to fix a single frame, but once we treat time as a third dimension for video, we need consistency: if the model denoises Frame 1 into a "3" and Frame 2 into a "5," the animation would look like a flickering nightmare. For this we need temporal attention, i.e., the model shouldn't only look at a pixel but also look through time, at the frames before and after. To build training data, I ran the Euclidean Distance Transform (EDT) with masking to convert each static 28x28 MNIST image into a 15x28x28 stack of frames. A single MNIST digit thus becomes a batch of 15 images covering the digit's transformation across timesteps. That solves the data problem; next we need a modified architecture.

EDT transformed data
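The frame-generation idea can be sketched roughly like this. This is a hypothetical reconstruction using SciPy's `distance_transform_edt`, with my own function name, threshold, and growth rule, not necessarily the project's exact preprocessing:

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

def digit_to_animation(img, n_frames=15, thresh=0.5):
    """Turn a static 28x28 digit into an (n_frames, 28, 28) animation
    where the digit grows from its thickest strokes outward."""
    mask = img > thresh  # binarize the digit
    # for each digit pixel: Euclidean distance to the nearest background pixel
    inner = distance_transform_edt(mask)
    frames = []
    for t in range(n_frames):
        # shrink the reveal threshold over time: thick strokes appear
        # first, thin edges last, so the digit "grows" into place
        radius = (1 - t / (n_frames - 1)) * inner.max()
        frames.append((mask & (inner >= radius)).astype(np.float32))
    return np.stack(frames)  # shape: (n_frames, 28, 28)
```

The last frame is the complete digit, and every earlier frame is a subset of it, which is what makes the sequence a smooth, low-frequency target for the model.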

Motion MNIST Architecture

The architecture follows the DDPM paper, with some modifications since we are dealing with video. Normally, you would use a 3D convolution layer and pass everything through it. But due to memory constraints, I split it in two: one layer is responsible for spatial data, and another for temporal (time) data. That's why I used kernels of (1x3x3) and (3x1x1).

nn.Sequential(
    nn.Conv3d(in_ch, out_ch, kernel_size=(1, 3, 3), padding=(0, 1, 1)),  # spatial: H and W only
    nn.BatchNorm3d(out_ch),
    nn.ReLU(),
    nn.Conv3d(out_ch, out_ch, kernel_size=(3, 1, 1), padding=(1, 0, 0)),  # temporal: T only
)
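To see the split in action, here is a sketch of the factorized block applied to a dummy video tensor (the channel sizes are assumptions). `Conv3d` expects input shaped (B, C, T, H, W), and the chosen paddings preserve all three of T, H, and W:

```python
import torch
import torch.nn as nn

# Factorized "(2+1)D" block as described above; channel sizes are assumptions.
in_ch, out_ch = 1, 32
block = nn.Sequential(
    nn.Conv3d(in_ch, out_ch, kernel_size=(1, 3, 3), padding=(0, 1, 1)),  # spatial: H, W
    nn.BatchNorm3d(out_ch),
    nn.ReLU(),
    nn.Conv3d(out_ch, out_ch, kernel_size=(3, 1, 1), padding=(1, 0, 0)),  # temporal: T
)

# Two videos, 1 channel, 15 frames of 28x28 -> same T, H, W out, 32 channels.
x = torch.randn(2, in_ch, 15, 28, 28)
y = block(x)
print(tuple(y.shape))  # (2, 32, 15, 28, 28)
```

Splitting one (3x3x3) kernel into (1x3x3) + (3x1x1) also uses fewer parameters than the full 3D kernel, which is what makes it attractive under memory constraints.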

For the time embedding, the DDPM paper uses a sinusoidal embedding inspired by the transformer positional encoding. But I used a simple learned time embedding, and it worked.

# maps a scalar timestep t to a t_dim-dimensional embedding
self.time_mlp = nn.Sequential(
            nn.Linear(1, t_dim),
            nn.ReLU(),
            nn.Linear(t_dim, t_dim)
        )

It worked because the problem is easy: the model is predicting geometry, since the data comes from the EDT (Euclidean Distance Transform), which is smooth and low-frequency. A real video has both low-frequency and high-frequency components, and the network has to model physics, which is not at all smooth. That's why the sinusoidal embedding is required in that case.
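For comparison, the sinusoidal embedding from the DDPM paper (borrowed from the transformer positional encoding) can be sketched like this; the function name and exact layout are my own:

```python
import math
import torch

def sinusoidal_time_embedding(t, dim):
    """Transformer-style sinusoidal embedding of diffusion timesteps.
    t: (B,) tensor of integer timesteps; returns a (B, dim) tensor."""
    half = dim // 2
    # geometric progression of frequencies from 1 down to ~1/10000
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    args = t.float()[:, None] * freqs[None, :]          # (B, half)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
```

The mix of frequencies is what lets the network distinguish nearby timesteps sharply, which matters once the denoising target is no longer a smooth, low-frequency function of t.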
After that, everything is a standard U-Net. The image data is compressed while non-linearity is stacked up, so the network can learn complex patterns; it is then upscaled again and concatenated with the corresponding downscaling step via skip connections, so it can learn both simple and complex structure.

Architecture Image

Fun fact: the U-Net was first introduced for medical image segmentation.

Training loop

It consists of a scheduler and a fairly normal training loop. The scheduler decides how much noise should be added at each step. In the original DDPM formulation, to get to step 500 you'd logically have to add noise 500 times in a row, which is slow and costly. To fix this, the authors used a property of Gaussians: adding Gaussian noise to Gaussian noise just results in a bigger Gaussian. Instead of stepping through the mud 1,000 times, a closed-form formula lets us teleport from the clean image x_0 directly to any noisy step x_t.
The refinement: it's not just that the noise is treated like a constant; rather, we can express the noisy image at step t as a linear combination of the original image and one single chunk of noise.

Scheduler Equation
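In code, the teleport is x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. Here is a sketch with a linear beta schedule as in DDPM; the exact schedule values are assumptions, not this project's config:

```python
import torch

# Linear beta schedule (DDPM defaults); alpha_bar is the running product
# of (1 - beta), i.e. how much of the clean signal survives at step t.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)

def q_sample(x0, t, eps):
    """Teleport a clean image x0 straight to noise level t in one step."""
    a = alpha_bar[t].sqrt()            # weight on the clean image
    b = (1.0 - alpha_bar[t]).sqrt()    # weight on one chunk of Gaussian noise
    return a * x0 + b * eps
```

At t = 0 the image is almost untouched; by t = 999 alpha_bar is nearly zero, so the sample is essentially pure noise.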

The loss function is simple: take the frame where the digit is complete, noise it to a random step, have the model predict the noise, compare the prediction against the true noise with Mean Squared Error, and repeat until the loss converges.
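A minimal sketch of that loop, assuming a model that takes (noisy_x, t) and returns a noise prediction of the same shape; the helper name and optimizer handling here are my own:

```python
import torch
import torch.nn.functional as F

def train_step(model, x0, alpha_bar, optimizer, T=1000):
    """One DDPM-style training step: noise a clean batch to random
    timesteps, predict the noise, and regress it with MSE."""
    t = torch.randint(0, T, (x0.shape[0],))                # random step per sample
    eps = torch.randn_like(x0)                             # the true noise
    a = alpha_bar[t].view(-1, *([1] * (x0.dim() - 1)))     # broadcast to x0's shape
    x_t = a.sqrt() * x0 + (1 - a).sqrt() * eps             # closed-form forward jump
    loss = F.mse_loss(model(x_t, t), eps)                  # predict the noise itself
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Note that the target is the noise eps, not the clean image; predicting the noise directly is the parameterization the DDPM paper found to train best.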

Inference

This is where the actual magic happens. During training we jumped steps using the reparameterization trick, but at inference we can't do that to reverse the noise: the model has no idea about the global structure, whether the result will be a '3' or a '5'. It only knows how to predict the amount of noise at each step, so it has to walk through all 1,000 steps as per DDPM. There is also a faster method, known as DDIM, which gets it done in about 50 steps. After the model predicts the noise, the sampler subtracts it from the image.

Sampler Equation
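The standard DDPM reverse update can be sketched as follows; this is the textbook formula with the sigma_t^2 = beta_t variance choice, not necessarily the exact sampler used in the project:

```python
import torch

@torch.no_grad()
def ddpm_step(model, x_t, t, betas, alphas, alpha_bar):
    """One reverse step x_t -> x_{t-1}; `model(x_t, t)` predicts the noise."""
    eps = model(x_t, torch.tensor([t]))
    # posterior mean: remove the predicted noise, then rescale
    coef = betas[t] / (1 - alpha_bar[t]).sqrt()
    mean = (x_t - coef * eps) / alphas[t].sqrt()
    if t == 0:
        return mean                        # final step: no fresh noise added
    z = torch.randn_like(x_t)
    return mean + betas[t].sqrt() * z      # re-inject noise with sigma_t^2 = beta_t
```

Running this loop from t = T-1 down to 0, starting from pure Gaussian noise, is the full DDPM sampler; DDIM replaces the update rule so that most steps can be skipped.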

The Secret Sauce: Temporal Consistency
Since we are building an animation model, our sampler has an extra responsibility. In a normal image model, the sampler only cares about one frame. In ours, the sampler ensures that as Frame 1 becomes a "3," Frame 2 is doing the same thing in a slightly different position. By running the denoising process across the entire temporal volume (Frames x C x H x W) simultaneously, the sampler keeps the movement fluid. Without this, our animation would look like a flickering glitch rather than a moving digit.

Conclusion

So that's what I did for the project. It can't capture the whole essence of video diffusion, but it gives a simplified view of how it works; in real video diffusion, the conditioning is done on text as well as time. For a more conceptual understanding of diffusion models, check out 3Blue1Brown's and Welch Labs' videos on how AI images and videos work. If you want to check out my code, you can go through my GitHub.
