Generative Adversarial Network: training¶
In this exercise we will practice how to train a GAN on a real dataset and generate our first synthetic images. Let's get started!
Before we write any code, let's briefly review how GANs work.
Generative Adversarial Networks (GANs) represent a significant leap in the field of artificial intelligence, particularly in image generation. At its core, a GAN consists of two main components: the Generator and the Discriminator.
- The Generator: Crafting Synthetic Images
The Generator's role is to create images. It starts with a random noise vector sampled from a high-dimensional distribution (typically a standard Normal). This vector, known as the latent z, is passed through the Generator network, which uses transposed (fractionally strided) convolutions to turn this latent representation into a synthetic image. For example, a latent vector of 100 elements might be transformed into a 64x64 pixel image, i.e. 4096 numbers per channel.
- The Discriminator: The Arbiter of Realism
The Discriminator's job is to distinguish between real and generated (fake) images. It performs a binary classification to determine the authenticity of each image. In classical GANs, the Discriminator is typically a standard Convolutional Neural Network (CNN) used for image classification.
- Training the GAN: A Dance Between Generator and Discriminator
Training a GAN involves an iterative process where the Generator and Discriminator continuously improve through competition. Initially, the Generator creates images, and the Discriminator learns to distinguish them from real ones. As the Generator improves, it becomes better at fooling the Discriminator. The training process is a cycle of alternating between training the Discriminator and the Generator, each time making them more adept at their tasks.
In summary, GANs harness the power of two neural networks in a unique setup, where one creates and the other critiques, leading to the generation of increasingly realistic images. This dynamic interplay between the Generator and Discriminator underpins the success of GANs in creating convincing, high-quality synthetic images.
First, let's define some parameters for our GAN:
import multiprocessing
CONFIG = {
    # For repeatability we will fix the random seed
"manual_seed": 42,
# This defines a set of augmentations we will perform, see below
"policy": "color,translation", # ,cutout
# Dimension of the latent space
"latent_dimension": 256,
# Batch size for training
"batch_size": 256,
    # Number of epochs. We will train for 40 epochs, which corresponds to
    # approximately 20 min of training
"n_epochs": 40,
# Input images will be resized to this, the generator will generate
# images with this size
"image_size": 64, # 64x64 pixels
# Number of channels in the input images
"num_channels": 3, # RGB
# Learning rate
"lr": 0.002,
    # beta1 coefficient for Adam: in GANs you want to use a lower value than the
    # default 0.9 so the Generator and the Discriminator can react to each other more quickly
"beta1": 0.7,
# Number of feature maps in each layer of the Generator
"g_feat_map_size": 64,
# Number of feature maps in each layer of the Discriminator
"d_feat_map_size": 64,
# Where to save the data
"data_path": "data/",
# Number of workers to use to load the data
"workers": multiprocessing.cpu_count(),
# We will display progress every "save_iter" epochs
"save_iter": 10,
# Where to save the progress
"outdir": "data/stanford_cars/",
# Unused
"clip_value": 0.01,
}
In order to make the training repeatable, let's fix the random seed and set PyTorch to use deterministic algorithms. This is not strictly necessary, but it is a good habit that keeps your experimentation ordered. Keep in mind that the initial random seed can have quite an impact on the training of a GAN.
One thing to consider is that deterministic algorithms can be significantly slower than non-deterministic ones, so we pay a performance penalty for setting things this way. In a real training scenario you might want to reconsider this tradeoff.
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import random
import torch
random.seed(CONFIG["manual_seed"])
torch.manual_seed(CONFIG["manual_seed"])
torch.use_deterministic_algorithms(True)
Let's import a few other modules, methods and functions that we will need:
import argparse
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.nn.utils.spectral_norm as spectral_norm
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from torchvision.utils import make_grid
import os
from IPython.display import clear_output
from ema_pytorch import EMA
import time
import tqdm
# Create the output directory
output_dir = Path(CONFIG["outdir"])
output_dir.mkdir(parents=True, exist_ok=True)
# Save the configuration there for safekeeping
with open(output_dir / "config.json", "w") as f:
json.dump(CONFIG, f, indent=4)
# Make sure CUDA is available (i.e. the GPU is setup correctly)
assert torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
Helpers¶
def set_seeds(seed):
"""Set seeds for reproducibility."""
random.seed(seed)
torch.manual_seed(seed)
torch.use_deterministic_algorithms(True)
def initialize_weights(model):
"""Custom weight initialization."""
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif isinstance(m, nn.BatchNorm2d):
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
def get_positive_labels(size, device, smoothing=True, random_flip=0.05):
if smoothing:
# Random positive numbers between 0.8 and 1.2 (label smoothing)
labels = 0.8 + 0.4 * torch.rand(size, device=device)
else:
labels = torch.full((size,), 1.0, device=device)
if random_flip > 0:
# Let's flip some of the labels to make it slightly harder for the discriminator
num_to_flip = int(random_flip * labels.size(0))
# Get random indices and set the first "num_to_flip" of them to 0
indices = torch.randperm(labels.size(0))[:num_to_flip]
labels[indices] = 0
return labels
def get_negative_labels(size, device):
return torch.full((size,), 0.0, device=device)
Diff augmented¶
def DiffAugment(x, policy="", channels_first=True):
if policy:
if not channels_first:
x = x.permute(0, 3, 1, 2)
for p in policy.split(","):
for f in AUGMENT_FNS[p]:
x = f(x)
if not channels_first:
x = x.permute(0, 2, 3, 1)
x = x.contiguous()
return x
def rand_brightness(x):
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
return x
def rand_saturation(x):
x_mean = x.mean(dim=1, keepdim=True)
x = (x - x_mean) * (
torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
) + x_mean
return x
def rand_contrast(x):
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
x = (x - x_mean) * (
torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
) + x_mean
return x
def rand_translation(x, ratio=0.125):
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
translation_x = torch.randint(
-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device
)
translation_y = torch.randint(
-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device
)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(x.size(2), dtype=torch.long, device=x.device),
torch.arange(x.size(3), dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
x = (
x_pad.permute(0, 2, 3, 1)
.contiguous()[grid_batch, grid_x, grid_y]
.permute(0, 3, 1, 2)
.contiguous()
)
return x
def rand_cutout(x, ratio=0.5):
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
offset_x = torch.randint(
0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device
)
offset_y = torch.randint(
0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device
)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(
grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1
)
grid_y = torch.clamp(
grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1
)
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
mask[grid_batch, grid_x, grid_y] = 0
x = x * mask.unsqueeze(1)
return x
AUGMENT_FNS = {
"color": [rand_brightness, rand_saturation, rand_contrast],
"translation": [rand_translation],
"cutout": [rand_cutout],
}
Data¶
- The original Kaggle dataset: https://www.kaggle.com/datasets/jessicali9530/stanford-cars-dataset?datasetId=30084&sortBy=dateCreated&select=cars_test
- The devkit: car_devkit.tgz
- The cars_test_annos_withlabels.mat file: https://www.kaggle.com/code/subhangaupadhaya/pytorch-stanfordcars-classification/input?select=cars_test_annos_withlabels+%281%29.mat
The data should be arranged in the following directory structure:
└── stanford_cars
    ├── cars_test_annos_withlabels.mat
    ├── cars_train
    │   └── *.jpg
    ├── cars_test
    │   └── *.jpg
    └── devkit
        ├── cars_meta.mat
        ├── cars_test_annos.mat
        ├── cars_train_annos.mat
        ├── eval_train.m
        ├── README.txt
        └── train_perfect_preds.txt
def get_dataloader(
root_path,
image_size,
batch_size,
workers=multiprocessing.cpu_count(),
    download=False,
):
transform = transforms.Compose(
[
transforms.Resize(image_size),
transforms.CenterCrop((image_size, image_size)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
dataset_train = datasets.StanfordCars(
        root=root_path, download=download, split="train", transform=transform
)
dataset_test = datasets.StanfordCars(
        root=root_path, download=download, split="test", transform=transform
)
dataset = torch.utils.data.ConcatDataset([dataset_train, dataset_test])
print(f"Using {workers} workers")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
persistent_workers=True if workers > 0 else False,
# collate_fn=collate_fn
)
return dataloader
Viz¶
# Visualize a batch of images as a grid
def visualize_batch(batch):
b = batch.detach().cpu()
fig, sub = plt.subplots(dpi=150)
sub.imshow(np.transpose(make_grid(b, padding=0, normalize=True).cpu(), (1, 2, 0)))
_ = sub.axis("off")
def training_tracking(D_losses, G_losses, D_acc, fake_data):
fig = plt.figure(dpi=150)
gs = gridspec.GridSpec(2, 8)
# Create subplots
ax_a = fig.add_subplot(gs[0, :3]) # Top-left subplot
ax_b = fig.add_subplot(gs[1, :3]) # Bottom-left subplot
ax_c = fig.add_subplot(gs[:, 4:]) # Right subplot spanning both rows
subs = [ax_a, ax_b, ax_c]
# Losses
subs[0].plot(D_losses, label="Discriminator")
subs[0].plot(G_losses, label="Generator")
subs[0].legend()
subs[0].set_ylabel("Loss")
# Accuracy
subs[1].plot(D_acc)
subs[1].set_ylabel("D accuracy")
# Examples of generated images
subs[2].imshow(
np.transpose(
make_grid(
fake_data.detach().cpu(), padding=0, normalize=True, nrow=4
).cpu(),
(1, 2, 0),
)
)
subs[2].axis("off")
fig.tight_layout()
return fig
Generator¶
class Generator(nn.Module):
"""
Generator class for DCGAN.
    The latent is passed through the generator network, which outputs a synthetic image.
:param image_size: size of the input image (assumed to be square). Must be a power of 2
:param latent_dimension: dimension of the latent space
:param feat_map_size: number of feature maps in the last layer of the generator
:param num_channels: number of channels in the input image
"""
def __init__(self, image_size, latent_dimension, feat_map_size, num_channels):
super(Generator, self).__init__()
# The following defines the architecture in a way that automatically
# scales the number of blocks depending on the size of the input image
# Number of blocks between the first and the last (excluded)
n_blocks = int(np.log2(image_size)) - 3
# Initial multiplicative factor for the number of feature maps
factor = 2 ** (n_blocks)
# The first block takes us from the latent space to the feature space with a
# 4x4 kernel with stride 1 and no padding
blocks = [
self._get_transpconv_block(
latent_dimension, feat_map_size * factor, 4, 1, 0, nn.LeakyReLU(0.2)
)
]
# The following blocks are transposed convolutional layers with stride 2 and
        # kernel size 4x4. Every block halves the number of feature maps but doubles
        # the size of the image (upsampling)
# (NOTE that we loop in reverse order)
prev_dim = feat_map_size * factor
for f in range(int(np.log2(factor) - 1), -1, -1):
blocks.append(
self._get_transpconv_block(
prev_dim, feat_map_size * 2**f, 4, 2, 1, nn.LeakyReLU(0.2)
)
)
prev_dim = feat_map_size * 2**f
# Add last layer
blocks.append(
self._get_transpconv_block(
feat_map_size, num_channels, 4, 2, 1, nn.Tanh(), batch_norm=False
)
)
self.model = nn.Sequential(*blocks)
def _get_transpconv_block(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
activation,
batch_norm=True,
):
return nn.Sequential(
nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
),
nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
activation,
)
def forward(self, latents):
return self.model(latents)
Discriminator¶
class Discriminator(nn.Module):
"""
Discriminator class for DCGAN.
    The Discriminator network tries to distinguish fake images from real ones.
:param image_size: size of the input image (assumed to be square). Must be a power of 2
:param feat_map_size: number of feature maps in the first layer of the discriminator
:param num_channels: number of channels in the input image
:param dropout: dropout probability
"""
def __init__(self, image_size, feat_map_size, num_channels, dropout=0):
super(Discriminator, self).__init__()
blocks = []
prev_dim = num_channels
for i in range(int(np.log2(image_size)) - 2):
blocks.append(
self._get_conv_block(
in_channels=prev_dim,
out_channels=feat_map_size * (2**i),
kernel_size=4,
stride=2,
padding=1,
dropout=dropout,
activation=nn.LeakyReLU(0.2, inplace=True),
batch_norm=False if i == 0 else True,
)
)
prev_dim = feat_map_size * (2**i)
blocks.append(
self._get_conv_block(
in_channels=prev_dim,
out_channels=1,
kernel_size=4,
stride=1,
padding=0,
dropout=0,
activation=nn.Sigmoid(),
batch_norm=False,
)
)
self.model = nn.Sequential(*blocks)
def _get_conv_block(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dropout,
activation,
batch_norm=True,
):
return nn.Sequential(
spectral_norm(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=not batch_norm,
)
),
nn.Dropout(p=dropout) if dropout > 0 else nn.Identity(),
nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
activation,
)
def forward(self, images):
return self.model(images)
Input dataset: real data¶
In order to train a GAN we need to show it real data of the type we want to generate. In this case we are going to focus on the Stanford Cars dataset:
# Get data loader
dataloader = get_dataloader(
CONFIG["data_path"],
CONFIG["image_size"],
CONFIG["batch_size"],
CONFIG["workers"],
    download=False,
)
print(f"Total number of examples: {len(dataloader.dataset)}")
visualize_batch(next(iter(dataloader))[0][:16])
These are real images of cars from the Stanford Cars dataset, i.e., examples of what we want our GAN to learn to generate.
Generator¶
Generative Adversarial Networks (GANs) have revolutionized the field of AI-driven image generation. A crucial aspect of their success lies in the training of the Generator, which is responsible for creating realistic synthetic images.
The Generator's Objective
The Generator in a GAN starts with a random noise vector, known as latent z, and transforms it into a synthetic image. The goal of the Generator is to create images so convincing that they can fool the Discriminator into believing they are real. This is accomplished by trying to maximize the loss of the Discriminator on the fake data.
Training the Generator
The training process of the Generator involves several key steps:
- Generating fake images: the Generator creates fake images from the latent z vector. In our training loop we will actually reuse the fake images generated during the Discriminator training step, but this is just an optimization.
- Discriminator's evaluation: these fake images are then passed through the Discriminator, which is kept frozen during this phase. The Discriminator evaluates the images and assigns each a probability score indicating how likely it is to be real.
- Loss calculation and backpropagation: the Generator then adjusts its parameters to maximize the loss derived from the Discriminator's evaluation, i.e., to fool the Discriminator into classifying the fake images as real.
Binary Cross Entropy Trick
It can be shown mathematically that maximizing the Binary Cross Entropy loss of the Discriminator on fake data (with label y=0) has the same optimum as minimizing that same loss with the label set to y=1. The second formulation (often called the non-saturating loss) is preferred in practice because it provides much stronger gradients early in training, when the Discriminator easily rejects the fake images.
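As a quick numerical illustration (not part of the original exercise), we can check with PyTorch that the BCE loss with label y=1 is exactly -log(D(G(z))), and that it produces a much larger value (and gradient) than log(1 - D(G(z))) when the Discriminator confidently rejects a fake sample:
# Illustrative check of the BCE trick (uses torch and nn imported above)
d_of_gz = torch.tensor([0.01])  # D(G(z)) for a fake sample D confidently rejects
saturating = torch.log(1 - d_of_gz)                       # ~ -0.01: tiny signal
non_saturating = -torch.log(d_of_gz)                      # ~ 4.61: strong signal
bce_with_label_1 = nn.BCELoss()(d_of_gz, torch.ones(1))   # identical to non_saturating
print(saturating.item(), non_saturating.item(), bce_with_label_1.item())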
Let's create the Generator network and look at its architecture:
# Initialize models and optimizers
G = Generator(
CONFIG["image_size"],
CONFIG["latent_dimension"],
CONFIG["g_feat_map_size"],
CONFIG["num_channels"],
).to(device)
print(G)
Let's create a latent vector and put it through the Generator. What shape would you expect?
Complete the code marked by the YOUR CODE HERE placeholder
# Generate a latent vector of shape (1, CONFIG['latent_dimension'], 1, 1)
# Remember that the latent vector is just a vector of noise taken from a
# Normal distribution
# HINT: you can use torch.randn to sample from a Normal distribution
latent = torch.randn(1, CONFIG['latent_dimension'], 1, 1)
latent = latent.to(device)
fake_img = G(latent)
print(fake_img.shape)
Let's look at what the Generator is producing right now:
visualize_batch(G(latent))
This is of course just noise, because the Generator has not been trained yet. Now let's look at the shape of the tensor as it flows through the architecture:
x = latent
for i in range(5):
x = G.model[i](x.cuda())
b, c, w, h = x.shape
print(f"Channels: {c:3d}, w x h: {w:2d} x {h:2d}")
We can see that the first transposed convolution maps the input latent to 512 feature maps of size 4x4 pixels. The next block gives us 256 feature maps of size 8x8 pixels, and so on, until we get to 3 output channels and a size of 64x64 pixels, which is the expected size for our fake image (matching the size of the images in the input dataset).
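As an optional sanity check (not part of the original exercise), we can verify that the generated tensor has the shape implied by our configuration:
# The fake image should be (batch, num_channels, image_size, image_size)
expected_shape = (1, CONFIG["num_channels"], CONFIG["image_size"], CONFIG["image_size"])
assert tuple(fake_img.shape) == expected_shape, fake_img.shape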
Discriminator¶
Now let's have a look at the Discriminator:
The Discriminator takes an image as input and classifies it as real or fake.
- The Role of the Discriminator
The Discriminator's task in a GAN is to distinguish between real and generated images. During training, this component learns to identify nuances that differentiate authentic images from those created by the Generator.
- Training Process: The Split-Batch Method
A popular method for training the Discriminator is the 'split-batch' technique (a compact code sketch is shown after this list). It involves two main steps:
- Step 1: Handling real images. The Discriminator is fed real images and learns to identify them as authentic. This involves a forward pass of real data through the Discriminator, generating a probability score for each image being real. The Binary Cross Entropy (BCE) loss is then calculated by comparing the Discriminator's predictions against the true labels (real).
- Step 2: Dealing with fake images. Next, the Discriminator is presented with fake images produced by the Generator. These images undergo a similar process, with the Discriminator learning to label them as fake. The BCE loss is again used to compare the Discriminator's predictions against the true labels (fake).
- Updating the Discriminator
After processing both real and fake images, the Discriminator's weights are updated. This is done using the gradients accumulated from both sets of data, ensuring that the Discriminator improves its ability to differentiate real from fake images.
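To tie these steps together, here is a compact, illustrative sketch of a single split-batch Discriminator update. It is not the notebook's reference implementation (that is the full training loop further below), and it assumes a BCE criterion and an optimizer for D like the ones we set up in the next sections; the function name is just for illustration:
# Hypothetical helper sketching one split-batch update of the Discriminator
def discriminator_step_sketch(D, G, real_batch, criterion, optimizerD, latent_dim, device):
    D.zero_grad()
    # Step 1: real images with positive labels
    pred_real = D(real_batch).view(-1)
    loss_real = criterion(pred_real, torch.ones_like(pred_real))
    loss_real.backward()
    # Step 2: fake images (detached, so the Generator is not part of the graph)
    latent = torch.randn(real_batch.size(0), latent_dim, 1, 1, device=device)
    fake_batch = G(latent)
    pred_fake = D(fake_batch.detach()).view(-1)
    loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
    loss_fake.backward()
    # One optimizer step using the gradients accumulated from both halves
    optimizerD.step()
    return loss_real.item() + loss_fake.item()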
D = (
Discriminator(
CONFIG["image_size"], CONFIG["d_feat_map_size"], CONFIG["num_channels"], dropout=0.1
)
.to(device)#.eval()
)
print(D)
The Discriminator is composed of 5 blocks (from 0 to 4), represented by the Sequential modules. This is a standard CNN for binary classification, except that it does not use any pooling layer. Instead, all convolutional layers use a stride of 2, so the feature maps become smaller and smaller at every block:
x = fake_img
for i in range(5):
x = D.model[i](x)
b, c, w, h = x.shape
print(f"Channels: {c:3d}, w x h: {w:2d} x {h:2d}")
Loss and optimizers¶
Like in any other task involving the training of neural networks, we need to set up the loss function we want to minimize and the optimizer.
In the case of GANs, we have two optimizers: the optimizer for the Generator, and the optimizer for the Discriminator:
Complete the code marked by the YOUR CODE HERE placeholder
# Complete this code using the appropriate loss for the
# binary classification task of the Discriminator
# HINT: some possible losses available in pytorch are:
# nn.MSELoss()
# nn.BCELoss()
# nn.CrossEntropyLoss()
# nn.NLLLoss()
# Pick the one appropriate for binary classification
criterion = nn.BCELoss() # YOUR CODE HERE
# Optimizer for the Generator
# Instance the optimizer for the Generator
# HINT: the first parameter of optim.Adam()
# should be a list of parameters to optimize.
# Given a network N, you can obtain its parameters
# just by doing N.parameters(). Now do the same for the
# Generator
optimizerG = optim.Adam(
G.parameters(), # YOUR CODE HERE,
lr=CONFIG["lr"],
betas=(CONFIG["beta1"], 0.999),
)
# Optimizer for the Discriminator
# Do the same thing you did for the Generator, but this time
# for the Discriminator (i.e., complete the initialization
# of the Adam optimizer with the parameters of D)
optimizerD = optim.Adam(
D.parameters(), # YOUR CODE HERE,
lr=CONFIG["lr"] / 4,
betas=(CONFIG["beta1"], 0.999),
)
Trick 1: Exponential Moving Average¶
GANs are notoriously difficult to train, as the balance between the Generator and the Discriminator is easy to break. There are many tricks that can be used to stabilize training, and we're going to apply some of them here.
The first trick is the Exponential Moving Average (EMA): while the Generator is training, we keep a moving average of its weights, and at the end we use this smoothed version of the model for inference. The averaged model jumps around less and is less sensitive to sudden changes in the weights.
The EMA class accepts a parameter called beta, which controls the size of the window used for averaging. The number of steps (i.e. batches) we are going to average over is approximately equal to 1 / (1 - beta). Since there are 20 batches in our dataloader, if we use beta=0.995 we are averaging over 10 epochs:
ema_G = EMA(
G,
beta = 0.995, # average over the last ~10 epochs
update_after_step = 100, # start averaging after the first 5 epochs
update_every = 1
)
# Initialize weights
_ = G.apply(initialize_weights)
_ = D.apply(initialize_weights)
OK, we're now ready to start training! We will use the "split-batch" technique we have seen in the lesson, where we compute the gradients for the Discriminator separately: first on a batch of real images, then on a batch of fake images. The gradients from the two passes accumulate, and we then perform a single optimizer step. For the Generator, we adopt the trick of maximizing log(D(G(z))) instead of minimizing log(1−D(G(z))). This is accomplished by setting the labels for the fake images generated by the Generator to 1 ("real") instead of 0 ("fake"), as we have seen in the lesson.
But first let's look at some more tricks we're using in the training loop.
Trick 3: Label smoothing¶
Label smoothing is a general technique originally proposed in this paper and described in detail here. It consists of substituting the probability for the target class from 1 (hard labels) with something different from 1. In the case of binary classification, the BCELoss takes as input the probability of the positive class, so label smoothing becomes as simple as substituting 1 with a random number between 0.8 and 1.2. Label smoothing makes the Discriminator less overconfident and slows down its convergence, especially at the beginning, when the Generator is still pretty bad at generating realistic images.
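As a quick illustration (using the get_positive_labels helper defined earlier), with smoothing enabled the positive targets are spread around 1 instead of being exactly 1:
# Smoothed positive labels: random values in [0.8, 1.2) instead of exactly 1.0
smooth_labels = get_positive_labels(8, device, smoothing=True, random_flip=0.0)
print(smooth_labels)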
Trick 4: Random flipping¶
In order to make the work of the Discriminator a little harder and prevent it from immediately overwhelming the Generator, it is suggested to add some random noise to the labels used for the Discriminator. This is equivalent to flipping some labels from positive to negative, and it effectively prevents the Discriminator from ever achieving zero loss.
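Again using the helper defined earlier, we can see that roughly 5% of the positive labels get flipped to 0 (the exact fraction is set by random_flip):
# With random_flip=0.05, about 5% of the positive labels are set to 0 ("fake")
flipped_labels = get_positive_labels(1000, device, smoothing=False, random_flip=0.05)
print(f"{(flipped_labels == 0).sum().item()} out of {flipped_labels.numel()} labels were flipped to 0")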
Trick 5: DiffAugment¶
In this paper the authors introduce a simple set of augmentations, applied to both the real and the fake images, that prevents overfitting in the Discriminator. Since here we only have around 5000 examples of real images, overfitting is very easy and this technique will prove very useful.
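As an optional visual check (not part of the original exercise), we can apply the augmentation policy to a batch of real images using the DiffAugment function and the visualize_batch helper defined above:
# Apply the DiffAugment policy to a real batch and display the augmented images
real_batch = next(iter(dataloader))[0][:16].to(device)
visualize_batch(DiffAugment(real_batch, policy=CONFIG["policy"]))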
# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(16, CONFIG['latent_dimension'], 1, 1, device=device)
# Lists to keep track of progress
G_losses = []
D_losses = []
D_acc = []
Complete the code marked by the YOUR CODE HERE placeholder
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
print("Starting Training Loop...")
tstart = time.time()
n_frame = 0
for epoch in range(CONFIG["n_epochs"]):
# Keep track of losses and accuracy of the Discriminator
# for each batch, we will average them at the end of the
# epoch
batch_G_losses = []
batch_D_losses = []
batch_D_acc = []
# Loop over batch of real data (we throw away the labels
# provided by the dataloader by saving them into `_`)
for data, _ in tqdm.tqdm(dataloader, total=len(dataloader)):
# Move batch to GPU and record its size
# (remember that the last batch could be smaller than batch_size)
data = data.to(device)
b_size = data.size(0)
# This function implements tricks 3 and 4 (smoothing and random label flipping)
labels = get_positive_labels(b_size, device, smoothing=True, random_flip=0.2)
################################################
# Discriminator training #
################################################
        # In this phase only the Discriminator is updated; the Generator stays frozen
        D.zero_grad()  # Reset the gradients of the Discriminator
# Forward pass real batch through D using DiffAugment
# augmentation
D_pred = D(DiffAugment(data, policy=CONFIG["policy"])).view(
-1
) # probability to be real
# Measure accuracy for the positive batch
acc_pos = (D_pred > 0.5).sum() / D_pred.size(0)
# Loss on the real data
# Compute the loss on the real data by calling the
# criterion we have defined above
# >>>>>>>>>>>> YOUR CODE HERE
loss_on_real_data = criterion(D_pred, labels)
# Compute the gradients on the real data
        # in order to improve the Discriminator
# HINT: you can compute the gradients by calling
# .backward() on the loss
loss_on_real_data.backward() # YOUR CODE HERE
        # No optimizer .step() yet: we still need the gradients from the fake data
# Now pass a batch of fake data through the model
# Generate batch of latent vectors
# HINT: generate a latent using torch.randn, the shape of the
# latent should be (b_size, CONFIG['latent_dimension'], 1, 1)
# NOTE: add the device=device option as in torch.randn(..., device=device)
# so the latent is created on the GPU (otherwise you'll get an error later)
latent_vectors = torch.randn(
b_size, CONFIG["latent_dimension"], 1, 1, device=device
) # YOUR CODE HERE
# Generate fake image batch with G
# HINT: just call the generator using the latent
fake_data = G(latent_vectors) # YOUR CODE HERE
# Assign negative label as ground truth
# (ground truth labels)
labels.fill_(0) # 0 is the label for fake images
# Get predictions from the Discriminator
# (applying DiffAugment augmentations)
# NOTE: here it is VERY important to use .detach() on the (augmented)
# fake data because we do NOT want the Generator to be part of the computation
# graph used to compute the gradients (we don't want to update the Generator yet)
D_pred = D(
DiffAugment(fake_data, policy=CONFIG["policy"]).detach() # VERY IMPORTANT
).view(-1)
        # We reuse fake_data later for the Generator update; without .detach(),
        # the backward pass below would also traverse (and free) the Generator's graph
# Get accuracy for this all-fake batch
acc_neg = (D_pred < 0.5).sum() / D_pred.size(0)
# Loss on fake data
# HINT: call the criterion defined above providing the
# discriminator prediction D_pred and the ground truth
# labels
loss_on_fake_data = criterion(D_pred, labels) # YOUR CODE HERE
# This computes the gradients after the fake data
# forward pass and stores them in the tensors
# (model parameters are NOT updated here)
# Remember that .backward() by default does NOT replace
# the gradients coming from the backward pass on the real data.
# Instead, it sums the new gradients with the old gradients
loss_on_fake_data.backward()
# Now we can finally update the Discriminator
# HINT: call a step on the optimizer of the Discriminator
# (optimizerD)
# >>> YOUR CODE HERE
optimizerD.step()
# This will use the gradient accumulated on both: the real and
# the fake data
# Compute error of D as sum over the fake and the real batches
# for safekeeping
total_loss = loss_on_real_data + loss_on_fake_data
################################################
# Generator training #
################################################
# Depending on how good the discriminator is at this stage
# it will give us a higher or lower classification loss
# We then take a backward step and update the parameters
# of the generator in order to decrease the loss obtained
# from the prediction of the discriminator
        # The Discriminator is part of the computation graph for this step,
        # but its weights are not updated here, so it plays a passive part.
        # In this phase the weights of the Generator are changed
        # so that it fools the Discriminator more and more.
        # An explanation of the Generator objective can be found here:
        # https://www.youtube.com/watch?v=wMF0sQO7sNw&t=204s
        # In summary, the objective of the Generator is
        # equivalent to minimizing the BCE loss of D(G(z)) when
        # imposing the binary label equal to 1:
        # L = -y*log(D(G(z))) with y = 1, i.e. L = -log(D(G(z)))
        G.zero_grad()  # Reset the gradients of the Generator
        # Remember that BCELoss is −[y·log(x) + (1−y)·log(1−x)]
labels.fill_(1) # 1 is the label for "real".
# Since we just updated D, perform another forward pass of
        # the all-fake batch we already generated in the previous
        # step (with DiffAugment)
# NOTE how we are NOT using .detach now, as this time we want the
# gradients for this operation to be accumulated
D_pred = D(DiffAugment(fake_data, policy=CONFIG["policy"])).view(-1)
        # Probability of these images being real according to the Discriminator
        # (the Discriminator's weights were just updated, so this is a fresh evaluation).
        # Minimizing the GAN objective for the Generator is equivalent to minimizing
        # the BCE loss on the Discriminator's predictions when the labels are all
        # set to 1 (positive).
# Loss from the Discriminator prediction that is going
# to be used to update G
# HINT: call the criterion on the prediction of the discriminator
# D_pred and the labels
loss_on_fake_G = criterion(D_pred, labels) # YOUR CODE HERE
# Calculate gradients for G
# HINT: you did this before
loss_on_fake_G.backward() # YOUR CODE HERE
# Update G
# HINT: call a step on the optimizer for the Generator
# (optimizer G)
# >>> YOUR CODE HERE
optimizerG.step()
# Update the Exponential Moving Average copy
ema_G.update()
# Save all losses
batch_G_losses.append(loss_on_fake_G.item())
batch_D_losses.append(total_loss.item())
batch_D_acc.append((0.5 * (acc_pos + acc_neg)).item())
# Take the mean over the epoch
G_losses.append(np.mean(batch_G_losses))
D_losses.append(np.mean(batch_D_losses))
D_acc.append(np.mean(batch_D_acc))
if epoch % CONFIG["save_iter"] == 0:
with torch.no_grad():
fake_viz_data = G(fixed_noise).detach().cpu()
clear_output(wait=True)
fig = training_tracking(D_losses, G_losses, D_acc, fake_viz_data)
plt.show()
fig.savefig(f"{CONFIG['outdir']}/frame_{n_frame:05d}.png")
n_frame += 1
print(f"Finished in {(time.time() - tstart)/60:.1f} min")
visualize_batch(ema_G(fixed_noise).detach().cpu())
Cars are really starting to appear, although we would probably need quite a bit more training (and parameter tuning) to make this really work!
Challenges in Training GANs
Unstable Balance: Training GANs is delicate; if either the Generator (G) or Discriminator (D) becomes too proficient too quickly, the other lags, disrupting the training process.
No Clear Convergence Indicator: Unlike traditional neural networks, GANs lack a clear metric like validation loss to signify convergence, making it hard to determine the optimal stopping point.
Mode Collapse: A critical issue where the Generator discovers a specific image type that always fools the Discriminator, leading to a lack of diversity in generated images.
Advanced Variants of GANs
To address these challenges, several GAN variants have been developed:
Wasserstein GAN (W-GAN): Introduces a Critic instead of a Discriminator, which assigns continuous scores to images, enhancing training dynamics and reducing mode collapse.
Progressive GANs: These GANs begin by generating low-resolution images, progressively adding details. This approach aids in faster convergence and enables the creation of high-resolution images.
Style GANs (v1, v2 and v3): Incorporate a mapping network to convert the latent vector into a style vector, which is fed along with the latent into the Generator. This, combined with added random noise and a few other innovations, significantly enhances sample quality and robustness.
- Conditional GANs
A notable extension of GANs is the development of conditional GANs (see for example here). They allow for manipulation of specific attributes in the output images, such as changing the view angle, gender, or adding a smile.
The Pros and Cons of GANs
Generative Adversarial Networks (GANs) have made significant strides in the field of AI-driven image generation. Understanding their strengths and weaknesses is key to leveraging their full potential.
Pros of GANs:
Speed: One of the standout features of GANs is their speed during inference. They require only a single forward pass of the latent vector, resulting in sub-second latency on modern GPUs. This makes them incredibly efficient for generating images quickly.
High Sample Quality: GANs are renowned for their excellent sample quality. They consistently rank as state-of-the-art in terms of Fréchet Inception Distance (FID) across various datasets, even when compared to newer generation algorithms like diffusion models. The level of detail in synthetic faces and animals created by GANs is often astonishingly high.
Cons of GANs:
Poor Mode Coverage: A notable drawback of GANs is their tendency to have poor coverage. The Generator in a GAN often prefers exploitation over exploration. Once it finds a method to deceive the Discriminator or Critic, it tends to overuse this approach, leading to a lack of diversity in the generated images.
Training Challenges: GANs are notorious for being tricky to train. They require a deep understanding and implementation of numerous training techniques to function efficiently and produce high-quality results.
In summary, while GANs boast impressive speed and sample quality, they face challenges in terms of mode coverage and training complexity. These factors must be considered when deploying GANs for practical applications in image generation.
Created: 2024-10-23