Train a Denoising Diffusion Probabilistic Model from scratch¶
Welcome! In this exercise you will train a DDPM from scratch. After training, the model will be able to generate images of cars.
Let's get started!
Initial setup¶
Here we import a few modules and we set up the notebook so it is fully reproducible:
# Make results fully reproducible:
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import random
import torch
import numpy as np
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.use_deterministic_algorithms(True)
# Import a few things we will need
import torch.nn.functional as F
import torch
from torch.optim import Adam, RAdam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import multiprocessing
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
Dataset¶
Let's start by loading our training dataset. We are going to use the Stanford Cars dataset, which consists of 196 classes of cars with a total of 16,185 images. For this exercise we do not need any labels or a separate test set, so we are going to load both the training and the test split and concatenate them. We also resize the images to 64x64 so the exercise completes more quickly. To set up the dataset you will need:
- The original Kaggle dataset: https://www.kaggle.com/datasets/jessicali9530/stanford-cars-dataset?datasetId=30084&sortBy=dateCreated&select=cars_test
- The devkit: car_devkit.tgz
- The cars_test_annos_withlabels.mat file: https://www.kaggle.com/code/subhangaupadhaya/pytorch-stanfordcars-classification/input?select=cars_test_annos_withlabels+%281%29.mat
Once the missing file is added, the directory structure should look like this:
└── stanford_cars
    ├── cars_test_annos_withlabels.mat
    ├── cars_train
    │   └── *.jpg
    ├── cars_test
    │   └── *.jpg
    └── devkit
        ├── cars_meta.mat
        ├── cars_test_annos.mat
        ├── cars_train_annos.mat
        ├── eval_train.m
        ├── README.txt
        └── train_perfect_preds.txt
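Before loading anything, it can help to verify this layout. The snippet below is a minimal sanity check, assuming the files were unpacked under data/stanford_cars (which is where torchvision will look for them given the root we pass later):
# Minimal sanity check of the expected layout (assumes the files sit under data/stanford_cars)
import os
root = "data/stanford_cars"
expected = [
    "cars_test_annos_withlabels.mat",
    "cars_train",
    "cars_test",
    os.path.join("devkit", "cars_meta.mat"),
]
for item in expected:
    path = os.path.join(root, item)
    print(("OK      " if os.path.exists(path) else "MISSING ") + path)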
IMG_SIZE = 64
BATCH_SIZE = 50  # reduce depending on your GPU - the initial value was 100
def get_dataset(path):
data_transform = transforms.Compose(
[
transforms.Resize((IMG_SIZE, IMG_SIZE)),
# We flip horizontally with probability 50%
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# Scales data into [-1,1]
transforms.Normalize(0.5, 0.5)
]
)
train = torchvision.datasets.StanfordCars(root=path, download=False,
transform=data_transform)
test = torchvision.datasets.StanfordCars(root=path, download=False,
transform=data_transform, split='test')
return torch.utils.data.ConcatDataset([train, test])
data = get_dataset("data/")
dataloader = DataLoader(
data,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=False,
pin_memory=True,
num_workers=multiprocessing.cpu_count(),
persistent_workers=True
)
Let's look at a batch of images:
# Get a batch
batch, _ = next(iter(dataloader))
# Display it
def display_sequence(imgs, dpi=75, nrow=8):
fig, sub = plt.subplots(dpi=dpi)
sub.imshow(
np.transpose(
make_grid(
imgs,
padding=0,
normalize=True,
nrow=nrow,
).cpu(),
(1,2,0)
)
)
_ = sub.axis("off")
return fig
_ = display_sequence(batch[:8], dpi=150)
Noise scheduling and precomputation¶
In the forward process we need to add random noise according to a schedule. Here we use a linear schedule with 512 diffusion steps.
Let's define it:
# Define beta schedule
T = 512 # number of diffusion steps
# YOUR CODE HERE
betas = torch.linspace(start=0.0001, end=0.02, steps=T) # linear schedule
plt.plot(range(T), betas.numpy(), label='Beta Values')
plt.xlabel('Diffusion Step')
plt.ylabel('Beta Value')
_ = plt.title('Beta Schedule over Diffusion Steps')
As we have seen in the lesson, we need to use a re-parametrization of the forward process that allows us to generate noisy images at any step without having to sequentially go through all the previous steps:
$$ \left\{ \begin{align*} \bar{\alpha}_t &= \prod_{s=1}^t (1 - \beta_s) \\ q(x_t | x_0) &= \mathcal{N}\left(\sqrt{\bar{\alpha}_t} x_0, \ (1 - \bar{\alpha}_t) \mathbf{I}\right) \end{align*} \right. $$
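Equivalently, drawing a sample from $q(x_t | x_0)$ can be written with the reparametrization trick, which is the exact form the code later uses to noise an image in a single step:
$$ x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \mathbf{I}) $$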
At inference time we will also need the quantities involved in these other formulas:
$$ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta (x_t, t) \right) + \sigma_t z $$
$$ \sigma_t^2 = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t $$
Here we define and precompute all these constants:
# Pre-calculate different terms for closed form
alphas = 1. - betas
# alpha bar
alphas_cumprod = torch.cumprod(alphas, axis=0)
# alpha bar at t-1
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# sqrt of alpha bar
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
# Inference:
# 1 / sqrt(alpha)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# sqrt of one minus alpha bar
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# sigma_t^2, the variance of the reverse-process posterior
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
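As a quick sanity check on these tensors (a minimal sketch, no training involved): alpha bar should decay from close to 1 towards 0, so the first steps barely perturb the image while the last steps reduce it to almost pure noise:
# Quick sanity check: alpha bar should decay monotonically from ~1 towards ~0
print(f"alpha_bar[0]   = {alphas_cumprod[0].item():.4f}")   # close to 1: x_1 is barely noised
print(f"alpha_bar[T-1] = {alphas_cumprod[-1].item():.4f}")  # close to 0: x_T is almost pure noise
assert torch.all(alphas_cumprod[1:] <= alphas_cumprod[:-1])  # monotonically decreasing
assert torch.all(posterior_variance >= 0)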
Here we define two utility functions, one to visualize the forward diffusion process, and the other one to make an inference call on an existing DDPM:
Fill the sections marked with YOUR CODE HERE
# Make sure CUDA is available (i.e. the GPU is setup correctly)
assert torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@torch.no_grad()
def forward_diffusion_viz(image, device='cpu', num_images=16, dpi=75, interleave=False):
"""
Generate the forward sequence of noisy images taking the input image to pure noise
"""
# Visualize only num_images diffusion steps, instead of all of them
stepsize = int(T/num_images)
imgs = []
noises = []
for i in range(0, T, stepsize):
t = torch.full((1,), i, device=device, dtype=torch.long)
# Forward diffusion process
bs = image.shape[0]
noise = torch.randn_like(image, device=device)
img = (
sqrt_alphas_cumprod[t].view(bs, 1, 1, 1) * image +
sqrt_one_minus_alphas_cumprod[t].view(bs, 1, 1, 1) * noise
)
imgs.append(torch.clamp(img, -1, 1).squeeze(dim=0))
noises.append(torch.clamp(noise, -1, 1).squeeze(dim=0))
if interleave:
imgs = [item for pair in zip(imgs, noises) for item in pair]
fig = display_sequence(imgs, dpi=dpi)
return fig, imgs[-1]
@torch.no_grad()
def make_inference(input_noise, return_all=False):
"""
Implements the sampling algorithm from the DDPM paper
"""
x = input_noise
bs = x.shape[0]
imgs = []
# YOUR CODE HERE
for time_step in range(0, T)[::-1]:
noise = torch.randn_like(x) if time_step > 0 else 0
t = torch.full((bs,), time_step, device=device, dtype=torch.long)
# YOUR CODE HERE
x = sqrt_recip_alphas[t].view(bs, 1, 1, 1) * (
x - betas[t].view(bs, 1, 1, 1) * model(x, t) /
sqrt_one_minus_alphas_cumprod[t].view(bs, 1, 1, 1)
) + torch.sqrt(posterior_variance[t].view(bs, 1, 1, 1)) * noise
imgs.append(torch.clamp(x, -1, 1))
if return_all:
return imgs
else:
return imgs[-1]
Forward process¶
Let's now simulate our forward process. If everything went well, you should see a few images like this one:
which show a few of the diffusion steps, from the original image on the left all the way to pure noise on the right.
for image in batch[:5]:
_ = forward_diffusion_viz(image.unsqueeze(dim=0), num_images=7, dpi=150, interleave=False)
"""
---
title: U-Net model for Denoising Diffusion Probabilistic Models (DDPM)
summary: >
UNet model for Denoising Diffusion Probabilistic Models (DDPM)
---
# U-Net model for [Denoising Diffusion Probabilistic Models (DDPM)](index.html)
This is a [U-Net](../../unet/index.html) based model to predict noise
$\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$.
U-Net gets its name from the U shape in the model diagram.
It processes a given image by progressively lowering (halving) the feature map resolution and then
increasing the resolution.
There are pass-through connections at each resolution.
![U-Net diagram from paper](../../unet/unet.png)
This implementation contains a number of modifications to the original U-Net (residual blocks, multi-head attention)
and also adds time-step embeddings $t$.
The MIT License (MIT)
Copyright (c) 2020 Varuna Jayasiri
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
(taken from https://github.com/labmlai/annotated_deep_learning_paper_implementations)
"""
import math
from typing import Optional, Tuple, Union, List
import torch
from torch import nn
class Swish(nn.Module):
"""
    ### Swish activation function
$$x \cdot \sigma(x)$$
"""
def forward(self, x):
return x * torch.sigmoid(x)
class TimeEmbedding(nn.Module):
"""
### Embeddings for $t$
"""
def __init__(self, n_channels: int):
"""
* `n_channels` is the number of dimensions in the embedding
"""
super().__init__()
self.n_channels = n_channels
# First linear layer
self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
# Activation
self.act = Swish()
# Second linear layer
self.lin2 = nn.Linear(self.n_channels, self.n_channels)
def forward(self, t: torch.Tensor):
# Create sinusoidal position embeddings
# [same as those from the transformer](../../transformers/positional_encoding.html)
#
# \begin{align}
# PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
# PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
# \end{align}
#
# where $d$ is `half_dim`
half_dim = self.n_channels // 8
emb = math.log(10_000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=1)
# Transform with the MLP
emb = self.act(self.lin1(emb))
emb = self.lin2(emb)
#
return emb
class ResidualBlock(nn.Module):
"""
### Residual block
A residual block has two convolution layers with group normalization.
Each resolution is processed with two residual blocks.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
time_channels: int,
n_groups: int = 32,
dropout: float = 0.1,
):
"""
* `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
* `time_channels` is the number channels in the time step ($t$) embeddings
* `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
* `dropout` is the dropout rate
"""
super().__init__()
# Group normalization and the first convolution layer
self.norm1 = nn.GroupNorm(n_groups, in_channels)
self.act1 = Swish()
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)
)
# Group normalization and the second convolution layer
self.norm2 = nn.GroupNorm(n_groups, out_channels)
self.act2 = Swish()
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)
)
# If the number of input channels is not equal to the number of output channels we have to
# project the shortcut connection
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
else:
self.shortcut = nn.Identity()
# Linear layer for time embeddings
self.time_emb = nn.Linear(time_channels, out_channels)
self.time_act = Swish()
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
* `x` has shape `[batch_size, in_channels, height, width]`
* `t` has shape `[batch_size, time_channels]`
"""
# First convolution layer
h = self.conv1(self.act1(self.norm1(x)))
# Add time embeddings
h += self.time_emb(self.time_act(t))[:, :, None, None]
# Second convolution layer
h = self.conv2(self.dropout(self.act2(self.norm2(h))))
# Add the shortcut connection and return
return h + self.shortcut(x)
class AttentionBlock(nn.Module):
"""
### Attention block
This is similar to [transformer multi-head attention](../../transformers/mha.html).
"""
def __init__(
self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32
):
"""
* `n_channels` is the number of channels in the input
* `n_heads` is the number of heads in multi-head attention
* `d_k` is the number of dimensions in each head
* `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
"""
super().__init__()
# Default `d_k`
if d_k is None:
d_k = n_channels
# Normalization layer
self.norm = nn.GroupNorm(n_groups, n_channels)
# Projections for query, key and values
self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
# Linear layer for final transformation
self.output = nn.Linear(n_heads * d_k, n_channels)
# Scale for dot-product attention
self.scale = d_k**-0.5
#
self.n_heads = n_heads
self.d_k = d_k
def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
"""
* `x` has shape `[batch_size, in_channels, height, width]`
* `t` has shape `[batch_size, time_channels]`
"""
        # `t` is not used, but it is kept in the arguments so that the function
        # signature matches `ResidualBlock`.
_ = t
# Get shape
batch_size, n_channels, height, width = x.shape
# Change `x` to shape `[batch_size, seq, n_channels]`
x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
# Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]`
qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
# Split query, key, and values. Each of them will have shape `[batch_size, seq, n_heads, d_k]`
q, k, v = torch.chunk(qkv, 3, dim=-1)
# Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$
attn = torch.einsum("bihd,bjhd->bijh", q, k) * self.scale
# Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
attn = attn.softmax(dim=2)
# Multiply by values
res = torch.einsum("bijh,bjhd->bihd", attn, v)
# Reshape to `[batch_size, seq, n_heads * d_k]`
res = res.view(batch_size, -1, self.n_heads * self.d_k)
# Transform to `[batch_size, seq, n_channels]`
res = self.output(res)
# Add skip connection
res += x
# Change to shape `[batch_size, in_channels, height, width]`
res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
#
return res
class DownBlock(nn.Module):
"""
### Down block
This combines `ResidualBlock` and `AttentionBlock`. These are used in the first half of U-Net at each resolution.
"""
def __init__(
self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool
):
super().__init__()
self.res = ResidualBlock(in_channels, out_channels, time_channels)
if has_attn:
self.attn = AttentionBlock(out_channels)
else:
self.attn = nn.Identity()
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res(x, t)
x = self.attn(x)
return x
class UpBlock(nn.Module):
"""
### Up block
This combines `ResidualBlock` and `AttentionBlock`. These are used in the second half of U-Net at each resolution.
"""
def __init__(
self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool
):
super().__init__()
# The input has `in_channels + out_channels` because we concatenate the output of the same resolution
# from the first half of the U-Net
self.res = ResidualBlock(
in_channels + out_channels, out_channels, time_channels
)
if has_attn:
self.attn = AttentionBlock(out_channels)
else:
self.attn = nn.Identity()
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res(x, t)
x = self.attn(x)
return x
class MiddleBlock(nn.Module):
"""
### Middle block
It combines a `ResidualBlock`, `AttentionBlock`, followed by another `ResidualBlock`.
This block is applied at the lowest resolution of the U-Net.
"""
def __init__(self, n_channels: int, time_channels: int):
super().__init__()
self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
self.attn = AttentionBlock(n_channels)
self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res1(x, t)
x = self.attn(x)
x = self.res2(x, t)
return x
class Upsample(nn.Module):
"""
### Scale up the feature map by $2 \times$
"""
def __init__(self, n_channels):
super().__init__()
self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used, but it is kept in the arguments so that the function
        # signature matches `ResidualBlock`.
_ = t
return self.conv(x)
class Downsample(nn.Module):
"""
### Scale down the feature map by $\frac{1}{2} \times$
"""
def __init__(self, n_channels):
super().__init__()
self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used, but it is kept in the arguments so that the function
        # signature matches `ResidualBlock`.
_ = t
return self.conv(x)
class UNet(nn.Module):
"""
## U-Net
"""
def __init__(
self,
image_channels: int = 3,
n_channels: int = 64,
ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
n_blocks: int = 2,
):
"""
* `image_channels` is the number of channels in the image. $3$ for RGB.
* `n_channels` is number of channels in the initial feature map that we transform the image into
* `ch_mults` is the list of channel numbers at each resolution. The number of channels is `ch_mults[i] * n_channels`
* `is_attn` is a list of booleans that indicate whether to use attention at each resolution
* `n_blocks` is the number of `UpDownBlocks` at each resolution
"""
super().__init__()
# Number of resolutions
n_resolutions = len(ch_mults)
# Project image into feature map
self.image_proj = nn.Conv2d(
image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1)
)
# Time embedding layer. Time embedding has `n_channels * 4` channels
self.time_emb = TimeEmbedding(n_channels * 4)
# #### First half of U-Net - decreasing resolution
down = []
# Number of channels
out_channels = in_channels = n_channels
# For each resolution
for i in range(n_resolutions):
# Number of output channels at this resolution
out_channels = in_channels * ch_mults[i]
# Add `n_blocks`
for _ in range(n_blocks):
down.append(
DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i])
)
in_channels = out_channels
# Down sample at all resolutions except the last
if i < n_resolutions - 1:
down.append(Downsample(in_channels))
# Combine the set of modules
self.down = nn.ModuleList(down)
# Middle block
self.middle = MiddleBlock(
out_channels,
n_channels * 4,
)
# #### Second half of U-Net - increasing resolution
up = []
# Number of channels
in_channels = out_channels
# For each resolution
for i in reversed(range(n_resolutions)):
# `n_blocks` at the same resolution
out_channels = in_channels
for _ in range(n_blocks):
up.append(
UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i])
)
# Final block to reduce the number of channels
out_channels = in_channels // ch_mults[i]
up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
in_channels = out_channels
# Up sample at all resolutions except last
if i > 0:
up.append(Upsample(in_channels))
# Combine the set of modules
self.up = nn.ModuleList(up)
# Final normalization and convolution layer
self.norm = nn.GroupNorm(8, n_channels)
self.act = Swish()
self.final = nn.Conv2d(
in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1)
)
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
* `x` has shape `[batch_size, in_channels, height, width]`
* `t` has shape `[batch_size]`
"""
# Get time-step embeddings
t = self.time_emb(t)
# Get image projection
x = self.image_proj(x)
# `h` will store outputs at each resolution for skip connection
h = [x]
# First half of U-Net
for m in self.down:
x = m(x, t)
h.append(x)
# Middle (bottom)
x = self.middle(x, t)
# Second half of U-Net
for m in self.up:
if isinstance(m, Upsample):
x = m(x, t)
else:
# Get the skip connection from first half of U-Net and concatenate
s = h.pop()
x = torch.cat((x, s), dim=1)
#
x = m(x, t)
# Final normalization and convolution
return self.final(self.act(self.norm(x)))
model = UNet(ch_mults = (1, 2, 1, 1))
# Uncomment this
# if you want to do the _VERY_ long training,
# model = UNet(ch_mults = (1, 2, 2, 2))
n_params = sum(p.numel() for p in model.parameters())
print(
f"Number of parameters: {n_params:,}"
)
Our model has around 9.1 million parameters. Compared to Stable Diffusion, which has around 1 billion parameters, it is very small! However, for this dataset it can still give remarkable results.
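Since the U-Net predicts the noise added to the image, its output must have exactly the same shape as its input. Here is a minimal shape check with a random batch (purely illustrative, just to validate the wiring before training):
# Quick shape check: the predicted noise must have the same shape as the input images
with torch.no_grad():
    dummy_x = torch.randn(2, 3, IMG_SIZE, IMG_SIZE)
    dummy_t = torch.randint(0, T, (2,))
    print(model(dummy_x, dummy_t).shape)  # expected: torch.Size([2, 3, 64, 64])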
Training loop¶
Let's now do some preparation for the training loop. First we transfer the model as well as all our precomputed quantities to the GPU, so they can be used efficiently during training:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Move everything to GPU
model.to(device)
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device)
alphas = alphas.to(device)
alphas_cumprod = alphas_cumprod.to(device)
alphas_cumprod_prev = alphas_cumprod_prev.to(device)
sqrt_recip_alphas = sqrt_recip_alphas.to(device)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device)
posterior_variance = posterior_variance.to(device)
betas = betas.to(device)
Now we can define the loss we are going to minimize:
Complete the section marked with YOUR CODE HERE
# YOUR CODE HERE
criterion = torch.nn.MSELoss()
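For reference, this mean squared error between the true and the predicted noise is the simplified objective from the DDPM paper:
$$ L_{\text{simple}} = \mathbb{E}_{t,\, x_0,\, \epsilon} \left[ \left\lVert \epsilon - \epsilon_\theta\!\left( \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon,\ t \right) \right\rVert^2 \right] $$
In the training loop below we draw a random step t for each image, noise it with the closed form above, and ask the model to recover the injected noise.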
Then we define a few parameters for our training. We are going to use Cosine Annealing for the learning rate, with a warmup period. This means that we are going to start from a very low learning rate, increase it linearly for a few epochs, then start decreasing it again with a cosine shape:
base_lr = 0.0006 # Maximum learning rate we will use
epochs = 300 # Total number of epochs
T_max = epochs # Number of epochs for Cosine Annealing. We do only one cycle
warmup_epochs = 2 # Number of warm-up epochs
# Uncomment the following lines
# if you want to do the _VERY_ long training,
# base_lr = 0.0001 # Maximum learning rate we will use
# epochs = 300 # Total number of epochs
# T_max = epochs # Number of epochs for Cosine Annealing. We do only one cycle
# warmup_epochs = 10 # Number of warm-up epochs
optimizer = Adam(model.parameters(), lr=base_lr)
scheduler = CosineAnnealingLR(
optimizer,
T_max=T_max - warmup_epochs,
    eta_min=base_lr / 10  # minimum LR, reached at the end of the cosine cycle
)
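To see what this schedule looks like before training, here is a minimal sketch that replays the same rule epoch by epoch (the linear warm-up used in the loop below, then the closed form of CosineAnnealingLR), without touching the real optimizer:
# Visualize the schedule: linear warm-up followed by cosine annealing down to base_lr / 10
import math
eta_min = base_lr / 10
lrs = []
for epoch in range(epochs):
    if epoch < warmup_epochs:
        # Same warm-up rule as in the training loop below
        lrs.append(base_lr * (epoch + 1) / warmup_epochs)
    else:
        # Closed form of CosineAnnealingLR; step() is called at the start of each post-warm-up epoch
        progress = (epoch - warmup_epochs + 1) / (T_max - warmup_epochs)
        lrs.append(eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress)))
plt.plot(range(epochs), lrs)
plt.xlabel("Epoch")
plt.ylabel("Learning rate")
_ = plt.title("Learning rate schedule")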
Finally, let's train! With the settings above, the model trains for 300 epochs, which takes a long time. If you only want a quick check, reduce the number of epochs to something like 5 (around 20 minutes of training); that won't get us to a good result, but you will see a few hints of cars appearing little by little.
Complete the section marked with YOUR CODE HERE
# We will use this noise to generate some images during training to check
# where we stand
fixed_noise = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
alpha = 0.1 # Smoothing factor
ema_loss = None # Initialize EMA loss
for epoch in range(epochs):
if epoch < warmup_epochs:
# Linear warm-up
lr = base_lr * (epoch + 1) / warmup_epochs
for param_group in optimizer.param_groups:
param_group['lr'] = lr
else:
# Cosine Annealing after warm-up
scheduler.step()
current_lr = optimizer.param_groups[0]['lr']
for batch, _ in tqdm(dataloader):
batch = batch.to(device)
bs = batch.shape[0]
optimizer.zero_grad()
# YOUR CODE HERE
t = torch.randint(0, T, (batch.shape[0],), device=device).long()
# Generate targets for the UNet and apply them to the images
noise = torch.randn_like(batch, device=device)
x_noisy = (
sqrt_alphas_cumprod[t].view(bs, 1, 1, 1) * batch +
sqrt_one_minus_alphas_cumprod[t].view(bs, 1, 1, 1) * noise
)
noise_pred = model(x_noisy, t)
loss = criterion(noise, noise_pred)
loss.backward()
optimizer.step()
if ema_loss is None:
# First batch
ema_loss = loss.item()
else:
# Exponential moving average of the loss
ema_loss = alpha * loss.item() + (1 - alpha) * ema_loss
if epoch == epochs-1:
with torch.no_grad():
# fig, _ = sample_image(fixed_noise, forward=False, device=device)
imgs = make_inference(fixed_noise, return_all=True)
fig = display_sequence([imgs[0].squeeze(dim=0)] + [x.squeeze(dim=0) for x in imgs[63::64]], nrow=9, dpi=150)
plt.show(fig)
os.makedirs("diffusion_output_long", exist_ok=True)
fig.savefig(f"diffusion_output_long/frame_{epoch:05d}.png")
#plt.close(fig)
print(f"epoch {epoch+1}: loss: {ema_loss:.3f}, lr: {current_lr:.6f}")
Inference¶
We can now have a look at what our model can produce:
Complete the section marked with YOUR CODE HERE
# YOUR CODE HERE
input_noise = torch.randn((8, 3, IMG_SIZE, IMG_SIZE), device=device)
imgs = make_inference(input_noise)
_ = display_sequence(imgs, dpi=75, nrow=4)
This is a fairly good result considering how small the model is and how little we trained it. We can already tell that it is indeed creating cars, with windshields and wheels, although it is still very early on. If we were to train for much longer, and/or use a larger model (for example, the one defined above in the commented lines has 55 Million parameters) and let it train for several hours, we would get something even better, like this: