This DSPA Appendix shows how to design, train, and use diffusion models, and their underlying PDE formulation, for generative AI modeling. We use the torch package, an R interface to PyTorch, to build and train the diffusion model. In this example, we train a diffusion model on the MNIST handwritten digits dataset and use it to generate synthetic handwritten digits.
Diffusion models are a class of generative AI models that learn to generate data by reversing a gradual noising process. They start with pure noise and progressively, iteratively denoise it to produce data resembling the training data. This approach has been effective in generating high-quality images. We will examine the foundations of diffusion models, including their mathematical formulation.
Diffusion models define a Markov chain of successive latent variables \(\{x_t\}\) over discrete timesteps \(t = 0, 1, \dotsc, T\), starting from real data \(x_0\) and gradually adding Gaussian noise until the data is completely destroyed into a noise distribution \(x_T\). The generative process then involves learning to reverse this noising process to recover the original data from the noise.
The forward diffusion process \(q\) is a fixed Markov chain that gradually adds Gaussian noise to the data at each timestep \(t\)
\[q(x_{t} \mid x_{t-1}) = \mathcal{N}(x_{t}; \sqrt{\alpha_t} x_{t-1}, (1 - \alpha_t) \mathbf{I}) , \]
where \(\alpha_t \in (0, 1)\) is determined by a variance schedule that controls the amount of noise added at each timestep, and \(\mathcal{N}(\mu, \sigma^2 \mathbf{I})\) denotes a Gaussian distribution with mean \(\mu\) and covariance \(\sigma^2 \mathbf{I}\). The cumulative effect over \(t\) timesteps can be derived
\[q(x_t \mid x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) \mathbf{I}) , \]
where \(\bar{\alpha}_t = \prod_{s=1}^t \alpha_s\) is the cumulative product of \(\alpha_s\) up to time \(t\).
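To see where this closed form comes from, it helps to compose two forward steps explicitly (a brief sketch; the general case follows by induction on \(t\)):
\[x_2 = \sqrt{\alpha_2}\, x_1 + \sqrt{1 - \alpha_2}\, \epsilon_1 = \sqrt{\alpha_1 \alpha_2}\, x_0 + \sqrt{\alpha_2 (1 - \alpha_1)}\, \epsilon_0 + \sqrt{1 - \alpha_2}\, \epsilon_1 .\]
The last two terms are independent zero-mean Gaussians with total variance \(\alpha_2(1 - \alpha_1) + (1 - \alpha_2) = 1 - \alpha_1 \alpha_2 = 1 - \bar{\alpha}_2\), so \(x_2 = \sqrt{\bar{\alpha}_2}\, x_0 + \sqrt{1 - \bar{\alpha}_2}\, \epsilon\) with a single \(\epsilon \sim \mathcal{N}(0, \mathbf{I})\).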
The reverse diffusion process \(p_\theta\) is parameterized by a neural network and aims to reverse the diffusion process
\[p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) .\]
The goal is to learn \(\mu_\theta\) and \(\Sigma_\theta\) such that the reverse process reconstructs the data from noise.
Here is one approach to parameterize the mean \(\mu_\theta\) in terms of \(x_t\) and the predicted noise \(\epsilon_\theta(x_t, t)\)
\[\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) . \]
Often, the variance \(\Sigma_\theta\) is fixed or simplified.
The training objective is derived from variational inference by minimizing the variational lower bound (VLB) on the negative log-likelihood
\[\mathcal{L} = \mathbb{E}_q \left[ -\log p_\theta(x_0 \mid x_1) + \sum_{t=2}^T D_{\text{KL}} \left( q(x_{t-1} \mid x_t, x_0) \parallel p_\theta(x_{t-1} \mid x_t) \right) \right] .\]
However, this objective can be simplified. When using certain variance schedules and parameterizations, the training objective reduces to a mean squared error (MSE) between the true noise \(\epsilon\) and the predicted noise \(\epsilon_\theta(x_t, t)\)
\[\mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0, \epsilon, t} \left[ \left\| \epsilon - \epsilon_\theta\left( \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, t \right) \right\|^2 \right] ,\]
where \(x_0\) is a data sample from the real data distribution, \(\epsilon \sim \mathcal{N}(0, \mathbf{I})\) is standard Gaussian noise, and \(t\) is sampled uniformly from \(\{1, \dotsc, T\}\). This objective trains the model to predict the noise added at each timestep, enabling it to reverse the diffusion process during generation.
At each timestep \(t\), noise is added to the data as follows \(x_t = \sqrt{\alpha_t} x_{t-1} + \sqrt{1 - \alpha_t} \epsilon_{t-1},\) where \(\epsilon_{t-1} \sim \mathcal{N}(0, \mathbf{I})\). By recursively applying this, we can express \(x_t\) directly in terms of \(x_0\)
\[x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon ,\] where \(\epsilon \sim \mathcal{N}(0, \mathbf{I})\).
The reverse process aims to sample from \(p_\theta(x_{t-1} \mid x_t)\), approximating \(q(x_{t-1} \mid x_t, x_0)\).
The posterior distribution \(q(x_{t-1} \mid x_t, x_0)\) can be derived in closed form using Bayes’ theorem and properties of Gaussian distributions
\[q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\left( x_{t-1}; \tilde{\mu}(x_t, x_0), \tilde{\beta}_t \mathbf{I} \right) , \] where \(\tilde{\mu}(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t} (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t\), \(\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t\), and \(\beta_t = 1 - \alpha_t\).
Expressing \(x_0\) in terms of \(x_t\) and the noise, \(x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon\right)\), and simplifying yields
\[\tilde{\mu}(x_t, x_0) \approx \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon \right) .\]
We can replace \(\epsilon\) with \(\epsilon_\theta(x_t, t)\), the neural network’s prediction.
To simplify the training objective function, we can decompose the variational lower bound into several terms. The most significant term (in terms of learning useful representations) is
\[\mathcal{L}_{t} = D_{\text{KL}} \left( q(x_{t-1} \mid x_t, x_0) \parallel p_\theta(x_{t-1} \mid x_t) \right) .\]
Minimizing this Kullback–Leibler (KL) divergence is equivalent to maximizing the likelihood of the data under the model’s reverse transitions. Using the simplified parameterization, the training objective reduces to
\[\mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0, \epsilon, t} \left[ \left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2 \right] .\] At each timestep \(t\), the model is trained to predict the noise \(\epsilon\) that was added to \(x_0\) to obtain \(x_t\).
To generate new data, we start from pure noise \(x_T \sim \mathcal{N}(0, \mathbf{I})\) and sequentially apply the reverse transitions. For \(t = T, T-1, \dotsc, 1\)
\[x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z , \] where \(\sigma_t\) is the standard deviation of the added noise, often set to \(\sqrt{\beta_t}\), and \(z \sim \mathcal{N}(0, \mathbf{I})\) is noise added for stochasticity (no noise is added at \(t = 1\)). At the end of this process, \(x_0\) is an approximation of a sample from the data distribution.
The choice of \(\beta_t\) or \(\alpha_t\) (variance schedule) is crucial for the model’s performance. Common schedules include
Linear Schedule: \(\beta_t = \beta_{\text{start}} + \frac{t - 1}{T - 1} (\beta_{\text{end}} - \beta_{\text{start}}) .\)
Cosine Schedule: \(\bar{\alpha}_t = \frac{f(t / T)}{f(0)}\), where \(f(\cdot)\) is a cosine function.
These schedules control how quickly noise is added in the forward process and, consequently, the difficulty of the reverse denoising task.
The neural network \(\epsilon_\theta(x_t, t)\) predicts the added noise at each timestep based on the current noisy sample \(x_t\) and (an embedding of) the timestep \(t\).
A number of diffusion model extensions and variations have been proposed to improve inference. Examples include denoising diffusion implicit models (DDIM) for faster sampling, classifier guidance for conditional generation, and continuous-time score-based formulations built on SDEs.
Diffusion generative AI models offer a powerful framework for generative modeling by learning to reverse a gradual noising process applied to data. The key components include the forward diffusion process, the reverse denoising process, and the training objective that guides the model to predict the noise added at each timestep. With their strong theoretical foundation and impressive empirical results, diffusion models have become a cornerstone in generative AI research.
There is a direct connection between partial differential equations (PDEs) and diffusion generative AI models. Diffusion models are deeply rooted in concepts from stochastic differential equations (SDEs) and their associated PDEs, such as the Fokker-Planck equation.
Diffusion generative models are inspired by physical diffusion processes, where particles spread out over time due to random motion. In the context of generative modeling, data points are gradually transformed into noise through a diffusion process, and the model learns to reverse this process to generate new data. The forward diffusion process can be described by a continuous-time stochastic differential equation (SDE), \(dx = f(x, t) dt + g(t) dw\), where \(x\) is the state variable (e.g., an image), \(f(x, t)\) is the drift coefficient, \(g(t)\) is the diffusion coefficient (noise intensity), and \(dw\) is the Wiener process increment (standard Brownian motion).
In many diffusion models, the forward process is designed such that \(f(x, t) = 0\) and \(g(t) = \sqrt{\beta(t)}\), where \(\beta(t)\) is a time-dependent noise schedule.
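As a small, self-contained illustration (separate from the model code below, and assuming the same linear \(\beta(t)\) range used later), the forward SDE with \(f = 0\) and \(g(t) = \sqrt{\beta(t)}\) can be simulated with a basic Euler-Maruyama scheme in base R.
# Euler-Maruyama simulation of the forward SDE dx = sqrt(beta(t)) dw (illustrative sketch)
set.seed(1234)
n_steps  <- 1000                                      # discretization steps over t in [0, 1]
dt       <- 1 / n_steps
beta_lin <- function(t) 0.0001 + t * (0.02 - 0.0001)  # assumed linear noise schedule
x    <- 0.75                                          # a scalar "data" value to be diffused
path <- numeric(n_steps)
for (k in 1:n_steps) {
  t_k     <- k * dt
  x       <- x + sqrt(beta_lin(t_k)) * sqrt(dt) * rnorm(1)   # dx = g(t) dw, with dw ~ N(0, dt)
  path[k] <- x
}
# plot(path, type = "l", xlab = "step", ylab = "x_t")  # the trajectory accumulates Gaussian noise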
The reverse of this diffusion process, which the generative model aims to learn, can also be described by an SDE \[dx = \left[ f(x, t) - g(t)^2 \nabla_x \log p_t(x) \right] dt + g(t) dw ,\] where \(\nabla_x \log p_t(x)\) is the score function, the gradient of the log probability density at time \(t\), and the term \(-g(t)^2 \nabla_x \log p_t(x)\) adjusts the drift to reverse the diffusion. This reverse SDE allows sampling from the data distribution by starting from noise and integrating the reverse-time SDE.
The Fokker-Planck equation is a PDE that describes the time evolution of the probability density function \(p(x, t)\) of a stochastic process described by an SDE. For the forward diffusion process, the Fokker-Planck equation is: \[\frac{\partial p(x, t)}{\partial t} = -\nabla_x \cdot [f(x, t) p(x, t)] + \frac{1}{2} \nabla_x^2 [g(t)^2 p(x, t)] ,\] where \(\nabla_x \cdot\) denotes the divergence operator and \(\nabla_x^2\) is the Laplacian operator. This PDE describes how the probability density \(p(x, t)\) diffuses over time due to the stochastic dynamics.
In score-based generative models, the key idea is to estimate the score function \(\nabla_x \log p_t(x)\) at different noise levels \(t\). This estimation is often achieved by training a neural network \(s_\theta(x, t)\) to approximate the score function. The reverse diffusion process relies on solving the reverse-time SDE, which involves the estimated score function. Since the score function is related to the gradient of the log-density, which is a solution to the Fokker-Planck equation, there is a direct connection to PDEs.
There is also a connection between diffusion models and optimal transport theory through the Schrödinger bridge problem, which involves finding the most likely evolution between two probability distributions under a stochastic process. This problem leads to PDEs similar to the Fokker-Planck equation but conditioned on the initial and terminal distributions.
The forward SDE representing the diffusion process is \(dx = \sqrt{2} dW_t\). This represents a simple Wiener process (standard Brownian motion), where \(W_t\) is a Wiener process.
The probability density \(p(x, t)\) evolves according to the PDE \(\frac{\partial p}{\partial t} = \Delta p ,\) where \(\Delta\) is the Laplacian operator; this is the classic heat equation, a fundamental PDE in diffusion processes. To generate data, we solve the reverse-time SDE \[dx = \left[ -2 \nabla_x \log p_t(x) \right] dt + \sqrt{2} d\bar{W}_t ,\] where \(\bar{W}_t\) is a reverse-time Wiener process and the drift term \(-g(t)^2 \nabla_x \log p_t(x) = -2 \nabla_x \log p_t(x)\) involves the score function.
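As a concrete, fully tractable example, assume a one-dimensional Gaussian initial density \(p_0 = \mathcal{N}(0, \sigma_0^2)\) (an illustrative assumption). Then the heat equation has the closed-form solution
\[p_t(x) = \mathcal{N}\left(x;\, 0,\ \sigma_0^2 + 2t\right), \qquad \nabla_x \log p_t(x) = -\frac{x}{\sigma_0^2 + 2t} ,\]
so the reverse-time drift pulls samples back toward the data distribution as \(t\) decreases. For real data, \(p_0\) is unknown, which is precisely why the score function must be learned by a neural network.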
Estimating the score function \(\nabla_x \log p_t(x)\) can be connected to solving certain PDEs. Specifically, the score function is related to the solution of the Fokker-Planck equation. An alternative perspective involves a deterministic process described by an ordinary differential equation (ODE) that transports the data distribution to the noise distribution. This ODE is known as the probability flow ODE \[\frac{dx}{dt} = f(x, t) - \frac{1}{2} g(t)^2 \nabla_x \log p_t(x) .\] Solving this ODE requires knowledge of the score function and is directly connected to the underlying PDEs governing the probability densities.
By leveraging the connection to PDEs, we can use analytical tools from stochastic calculus and PDE theory to analyze and improve diffusion models.
Some models directly work in continuous time, leveraging SDE solvers and benefiting from adaptive step sizes and better approximations. The training objective can be derived from variational principles, connecting to PDEs through the minimization of divergences between probability densities.
The evolution of probability densities in diffusion models is governed by PDEs like the Fokker-Planck equation. Framing the diffusion process in terms of SDEs and associated PDEs facilitates leveraging mathematical tools from these fields to better understand, analyze, and improve diffusion-based generative models.
First, we should ensure we have all necessary packages installed.
# Install torch package if not already installed
if (!require(torch)) {
install.packages("torch")
torch::install_torch()
}
# Install torchvision for datasets and transforms
if (!require(torchvision)) {
install.packages("torchvision")
}
Let’s use the MNIST dataset to demonstrate the formulation and implementation of a simple diffusion generative AI model with a linear noise schedule and a basic U-Net architecture for the denoising function.
library(torch)
library(torchvision)
library(ggplot2)
library(coro)
# Set device to CPU or GPU if available
device <- if (cuda_is_available()) torch_device("cuda") else torch_device("cpu")
cat("Using device:", device$type, "\n")
## Using device: cpu
# Define the transformations
transform <- function(x) {
x <- torch_tensor(x)$unsqueeze(1)$to(dtype = torch_float())$div(255)
x
}
# Load the MNIST dataset
train_dataset <- torchvision::mnist_dataset(
root = "./data",
train = TRUE,
download = TRUE,
transform = transform
)
# Create a data loader
batch_size <- 128
train_loader <- dataloader(train_dataset, batch_size = batch_size, shuffle = TRUE)
# Number of diffusion steps
timesteps <- 1000L # Ensure timesteps is an integer
# Linear noise schedule
beta_start <- 0.0001
beta_end <- 0.02
betas <- torch_linspace(
start = beta_start,
end = beta_end,
steps = timesteps,
dtype = torch_float(),
device = device
)
# Precompute alphas and other terms
alphas <- 1 - betas
alphas_cumprod <- torch_cumprod(alphas, dim = 1)
alphas_cumprod_prev <- torch_cat(
list(torch_tensor(1.0, device = device, dtype = betas$dtype), alphas_cumprod[1:(timesteps - 1)]),
dim = 1
)
sqrt_alphas_cumprod <- torch_sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod <- torch_sqrt(1 - alphas_cumprod)
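As an optional sanity check (a short sketch using only the quantities defined above), we can plot how much signal, \(\sqrt{\bar{\alpha}_t}\), and noise, \(\sqrt{1 - \bar{\alpha}_t}\), each timestep carries under the linear schedule.
# Optional: visualize the signal/noise mix implied by the linear schedule
signal_level <- as.array(sqrt_alphas_cumprod$cpu())
noise_level  <- as.array(sqrt_one_minus_alphas_cumprod$cpu())
plot(signal_level, type = "l", ylim = c(0, 1), xlab = "timestep t", ylab = "coefficient",
     main = "Forward-process signal vs. noise levels")
lines(noise_level, lty = 2)
legend("right", legend = c("sqrt(alpha_bar_t)", "sqrt(1 - alpha_bar_t)"), lty = c(1, 2))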
# U-Net Model Definition
# Define a simplified U-Net model for the denoising function.
# Define conv_block
conv_block <- nn_module(
classname = "conv_block",
initialize = function(in_size, out_size) {
self$conv_block <- nn_sequential(
nn_conv2d(in_size, out_size, kernel_size = 3L, padding = 1L),
nn_relu(),
nn_dropout(0.6),
nn_conv2d(out_size, out_size, kernel_size = 3L, padding = 1L),
nn_relu()
)
},
forward = function(x){
self$conv_block(x)
}
)
# Define down_block
down_block <- nn_module(
classname = "down_block",
initialize = function(in_size, out_size) {
self$conv_block <- conv_block(in_size, out_size)
},
forward = function(x) {
self$conv_block(x)
}
)
# Define up_block
up_block <- nn_module(
classname = "up_block",
initialize = function(in_size, out_size) {
self$up <- nn_conv_transpose2d(in_size, out_size, kernel_size = 2L, stride = 2L)
self$conv_block <- conv_block(in_size, out_size)
},
forward = function(x, bridge) {
up <- self$up(x)
x <- torch_cat(list(up, bridge), dim = 2L)
self$conv_block(x)
}
)
# Define U-Net
unet <- nn_module(
name = "unet",
initialize = function(channels_in=1L, n_classes=1L, depth=3L, n_filters=4L) {
self$down_path <- nn_module_list()
prev_channels <- channels_in
for (i in 1:depth) {
self$down_path$append(down_block(prev_channels, 2^(n_filters + i - 1)))
prev_channels <- 2^(n_filters + i - 1)
}
self$up_path <- nn_module_list()
for (i in ((depth - 1):1)) {
self$up_path$append(up_block(prev_channels, 2^(n_filters + i - 1)))
prev_channels <- 2^(n_filters + i - 1)
}
self$last <- nn_conv2d(prev_channels, n_classes, kernel_size = 1L)
},
forward = function(x) {
blocks <- list()
for (i in 1:length(self$down_path)) {
x <- self$down_path[[i]](x)
if (i != length(self$down_path)) {
blocks[[i]] <- x
x <- nnf_max_pool2d(x, 2L)
}
}
for (i in 1:length(self$up_path)) {
x <- self$up_path[[i]](x, blocks[[length(blocks) - i + 1]])
}
self$last(x)
}
)
# Instantiate the model
model <- unet$new(channels_in = 1L, n_classes = 1L, depth = 3L,
n_filters = 4L)$to(device = device)
# Training the Model
# Define the training loop
# Optimizer
optimizer <- optim_adam(model$parameters, lr = 1e-4)
# Number of epochs
epochs <- 5
# Training loop
for (epoch in 1:epochs) {
cat("Epoch:", epoch, "\n")
model$train()
total_loss <- 0
coro::loop(for (batch in train_loader) {
optimizer$zero_grad()
# Get data and move to device
x <- batch[[1]]$to(device = device)
batch_size <- x$size(1)
# Sample random times t
t <- torch_randint(low = 1L, high = timesteps + 1L, size=list(batch_size),
dtype = torch_long(), device = device)
# Get corresponding alphas
alpha_t <- alphas_cumprod$index_select(1, t)$reshape(c(batch_size, 1, 1, 1))
sqrt_alpha_t <- torch_sqrt(alpha_t)
sqrt_one_minus_alpha_t <- torch_sqrt(1 - alpha_t)
# Sample noise
noise <- torch_randn_like(x)
# Forward diffusion (q)
x_t <- sqrt_alpha_t * x + sqrt_one_minus_alpha_t * noise
# Time embedding (simple scaling for demonstration)
t_norm <- t$to(dtype = x$dtype)$div(timesteps)
t_embed <- t_norm$unsqueeze(-1)$unsqueeze(-1)$unsqueeze(-1)
x_t_with_time <- x_t + t_embed
# Predict noise
noise_pred <- model(x_t_with_time)
# Loss
loss <- nnf_mse_loss(noise_pred, noise)
loss$backward()
optimizer$step()
total_loss <- total_loss + loss$item()
})
avg_loss <- total_loss / length(train_loader)
cat("Average Loss:", avg_loss, "\n")
}
## Epoch: 1
## Average Loss: 0.6119902
## Epoch: 2
## Average Loss: 0.4000889
## Epoch: 3
## Average Loss: 0.3632654
## Epoch: 4
## Average Loss: 0.3360836
## Epoch: 5
## Average Loss: 0.3107214
# Save current model
torch_save(model$state_dict(), paste0("DSPA_DiffModel_MNIST_epoch_", epoch, ".pt"))
torch_save(optimizer$state_dict(), paste0("DSPA_DiffModelOptimizer_MNIST_epoch_", epoch, ".pt"))
################### Continue retraining for another 10 epochs###############
# Set the number of additional epochs you want to train
# additional_epochs <- 10
#
# # Calculate the total number of epochs
# total_epochs <- 5 + additional_epochs # Assuming initial training was 5 epochs
#
# # Continue training from epoch 6 to total_epochs
# for (epoch in 6:total_epochs) {
# cat("Epoch:", epoch, "\n")
# model$train()
# total_loss <- 0
#
# coro::loop(for (batch in train_loader) {
# optimizer$zero_grad()
#
# # Get data and move to device
# x <- batch[[1]]$to(device = device)
# batch_size <- x$size(1)
#
# # Sample random times t
# t <- torch_randint(low = 1L, high = timesteps + 1L, size = list(batch_size), dtype = torch_long(), device = device)
#
# # Get corresponding alphas
# alpha_t <- alphas_cumprod$index_select(1, t)$reshape(c(batch_size, 1, 1, 1))
# sqrt_alpha_t <- torch_sqrt(alpha_t)
# sqrt_one_minus_alpha_t <- torch_sqrt(1 - alpha_t)
#
# # Sample noise
# noise <- torch_randn_like(x)
#
# # Forward diffusion (q)
# x_t <- sqrt_alpha_t * x + sqrt_one_minus_alpha_t * noise
#
# # Time embedding
# t_norm <- t$to(dtype = x$dtype)$div(timesteps)
# t_embed <- t_norm$unsqueeze(-1)$unsqueeze(-1)$unsqueeze(-1)
# x_t_with_time <- x_t + t_embed
#
# # Predict noise
# noise_pred <- model(x_t_with_time)
#
# # Loss
# loss <- nnf_mse_loss(noise_pred, noise)
# loss$backward()
# optimizer$step()
#
# total_loss <- total_loss + loss$item()
# })
#
# avg_loss <- total_loss / length(train_loader)
# cat("Average Loss:", avg_loss, "\n")
#
# # Optionally, save the model after each epoch
# torch_save(model$state_dict(), paste0("DSPA_DiffModel_MNIST_epoch_", epoch, ".pt"))
# torch_save(optimizer$state_dict(), paste0("DSPA_DiffModelOptimizer_MNIST_epoch_", epoch, ".pt"))
# }
# Epoch: 1 Average Loss: 0.6414872
# Epoch: 2 Average Loss: 0.3970883
# Epoch: 3 Average Loss: 0.3658508
# Epoch: 4 Average Loss: 0.3412773
# Epoch: 5 Average Loss: 0.3187902
# Epoch: 6 Average Loss: 0.2975965
# Epoch: 7 Average Loss: 0.2760923
# Epoch: 8 Average Loss: 0.2564409
# Epoch: 9 Average Loss: 0.2407181
# Epoch: 10 Average Loss: 0.228597
# Epoch: 11 Average Loss: 0.2178866
# Epoch: 12 Average Loss: 0.2069024
# Epoch: 13 Average Loss: 0.1947486
# Epoch: 14 Average Loss: 0.1814034
# Epoch: 15 Average Loss: 0.1720325
# Epoch: 16 Average Loss: 0.165519
# Epoch: 17 Average Loss: 0.1594901
# Epoch: 18 Average Loss: 0.1541408
# Epoch: 19 Average Loss: 0.1497078
# Epoch: 20 Average Loss: 0.1448885
# Epoch: 21 Average Loss: 0.1412742
# Epoch: 22 Average Loss: 0.138038
# Epoch: 23 Average Loss: 0.1358833
# Epoch: 24 Average Loss: 0.1329042
# Epoch: 25 Average Loss: 0.1312926
# Epoch: 26 Average Loss: 0.1288575
# Epoch: 27 Average Loss: 0.1274151
# Epoch: 28 Average Loss: 0.1246548
# Epoch: 29 Average Loss: 0.1230804
# Epoch: 30 Average Loss: 0.120948
# Epoch: 31 Average Loss: 0.1190694
# Epoch: 32 Average Loss: 0.1166119
# Epoch: 33 Average Loss: 0.1146551
# Epoch: 34 Average Loss: 0.113805
# Epoch: 35 Average Loss: 0.1124808
# Generating New Samples
# After training, we can generate new samples by reversing the diffusion process.
# Generate new samples
model$eval()
num_samples <- 16
img_size <- c(28, 28)
# Start from pure noise
x <- torch_randn(c(num_samples, 1, img_size[1], img_size[2]), device = device)
# Reverse diffusion process
for (timestep in seq(timesteps, 1, -1)) {
# Create a tensor of timesteps
t <- torch_full(size = c(num_samples), fill_value = timestep, dtype = torch_long(), device = device)
beta_t <- betas$index_select(1, torch_tensor(timestep, dtype = torch_long(), device = device))
sqrt_one_minus_alpha_t <- torch_sqrt(1 - alphas_cumprod$index_select(1, torch_tensor(timestep, dtype = torch_long(), device = device)))
sqrt_recip_alpha_t <- 1 / torch_sqrt(alphas$index_select(1, torch_tensor(timestep, dtype = torch_long(), device = device)))
# Time embedding
t_norm <- t$to(dtype = x$dtype)$div(timesteps)
t_embed <- t_norm$unsqueeze(-1)$unsqueeze(-1)$unsqueeze(-1)
x_t_with_time <- x + t_embed
# Predict noise
with_no_grad({
noise_pred <- model(x_t_with_time)
})
# Compute x_{t-1}
if (timestep > 1) {
noise <- torch_randn_like(x)
} else {
noise <- torch_zeros_like(x)
}
x <- sqrt_recip_alpha_t * (x - beta_t / sqrt_one_minus_alpha_t * noise_pred) + torch_sqrt(beta_t) * noise
}
# Move to CPU and detach
x <- x$cpu()$detach()
# Visualizing the Results
# Display the generated images.
# # Function to plot images in a grid
# plot_images <- function(images, nrow = 4, ncol = 4) {
# par(mfrow = c(nrow, ncol), mar = c(0, 0, 0, 0))
# for (i in 1:(nrow * ncol)) {
# img <- as.array(images[i, 1, , ])
# image(t(apply(1 - img, 2, rev)), axes = FALSE, col = gray.colors(256))
# }
# }
#
# # Plot the generated images
# plot_images(x, nrow = 4, ncol = 4)
# # Install and load plotly
# if (!require(plotly)) {
# install.packages("plotly")
# }
library(plotly)
# Function to plot images in a grid using plotly
plot_images <- function(images, nrow = 4, ncol = 4) {
total_images <- nrow * ncol
plots <- vector("list", total_images)
for (i in 1:total_images) {
img <- as.array(images[i, 1, , ])
img <- 1 - img # Invert colors for better visualization
img <- t(apply(img, 2, rev)) # Flip the image vertically
# Create a plotly heatmap for the image
p <- plot_ly(
z = img,
type = "heatmap",
colorscale = "Gray",
showscale = FALSE,
hoverinfo = "none"
) %>%
layout(
xaxis = list(
showticklabels = FALSE,
zeroline = FALSE,
showgrid = FALSE
),
yaxis = list(
showticklabels = FALSE,
zeroline = FALSE,
showgrid = FALSE
),
margin = list(l = 0, r = 0, b = 0, t = 0),
paper_bgcolor = 'rgba(0,0,0,0)',
plot_bgcolor = 'rgba(0,0,0,0)'
)
plots[[i]] <- p
}
# Arrange the plots in a grid using subplot
subplot(plots, nrows = nrow, shareX = TRUE, shareY = TRUE, margin = 0.01)
}
# Plot the generated images
plot_images(x, nrow = 4, ncol = 4)
This example demonstrates the end-to-end process of training and testing a diffusion generative AI model in R. The protocol includes data preparation, model definition, training, sampling, and visualization.
In the above example, we trained the diffusion model for only \(5\) epochs to demonstrate the pragmatic aspects of AI model training. In practice, more epochs are needed to obtain better models. In such situations, we can first train offline and save the diffusion model, and then load the pretrained model back into the interactive session for inference and synthetic generation of handwritten digits.
# Specify the file path where you want to save the model
model_save_path <- "diffusion_model_state.pt"
# Save the model's state dictionary
torch_save(model$state_dict(), model_save_path)
# Specify the file path where you want to save the optimizer state
optimizer_save_path <- "optimizer_state.pt"
# Save the optimizer's state dictionary
torch_save(optimizer$state_dict(), optimizer_save_path)
The function model$state_dict() retrieves a list of all the parameters and buffers (weights, biases, etc.) of the model. Saving the state dictionary allows training to be resumed later; to resume exactly where training left off, the optimizer’s state should be saved as well.
Here is how to load a pretrained model back into an active R session. If we need to continue the model training after a pause, we can also load any saved optimizer state.
# Load the necessary libraries
library(torch)
# Ensure you have the same model architecture code loaded
# (Assuming the 'unet' module is defined as before)
# Re-instantiate the model architecture
model <- unet$new(channels_in = 1L, n_classes = 1L, depth = 3L, n_filters = 4L)
# Load the model's state dictionary from the saved file
state_dict <- torch_load("diffusion_model_state.pt")
# Load the state dictionary into the model
model$load_state_dict(state_dict)
# Move the model to the appropriate device
device <- if (cuda_is_available()) torch_device("cuda") else torch_device("cpu")
model$to(device = device)
# Re-instantiate the optimizer with the model's parameters
optimizer <- optim_adam(model$parameters, lr = 1e-4)
# Load the optimizer's state dictionary from the saved file
optimizer_state_dict <- torch_load("optimizer_state.pt")
# Load the state dictionary into the optimizer
optimizer$load_state_dict(optimizer_state_dict)
Loading for Inference or Further Training
# Load the necessary libraries and define the model architecture
library(torch)
library(torchvision)
# [Include any other necessary library and model definitions]
# Re-instantiate the model architecture
model <- unet$new(channels_in = 1L, n_classes = 1L, depth = 3L, n_filters = 4L)
# Load the model's state dictionary
state_dict <- torch_load("diffusion_model_state.pt")
# Load the state dictionary into the model
model$load_state_dict(state_dict)
# Move the model to the appropriate device
device <- if (cuda_is_available()) torch_device("cuda") else torch_device("cpu")
model$to(device = device)
# If resuming training, re-instantiate and load the optimizer
optimizer <- optim_adam(model$parameters, lr = 1e-4)
optimizer_state_dict <- torch_load("optimizer_state.pt")
optimizer$load_state_dict(optimizer_state_dict)
# Continue training or perform inference ...
When loading the model, it’s crucial that the model architecture matches exactly the one used during training. All custom modules and functions (like unet, down_block, up_block, etc.) must be defined in the current R session before loading the state dictionary. If the diffusion model is trained on a GPU but is loaded on a CPU (or vice versa), we need appropriate device mapping
state_dict <- torch_load("diffusion_model_state.pt", map_location = device)
It’s also important to ensure that the versions of the torch and torchvision packages are the same when saving and loading the model; differences in package versions can sometimes lead to incompatibilities. Saving the entire model object using torch_save(model, "model.pt") is generally not recommended, since saving only the state dictionary is more flexible and less prone to errors caused by code changes.
Once a diffusion model is loaded back in, it can be used for generating new synthetic samples (handwritten digit images).
# Ensure the model is in evaluation mode
model$eval()
# Number of samples to generate
num_samples <- 16
img_size <- c(28, 28)
# Start from pure noise
x <- torch_randn(c(num_samples, 1, img_size[1], img_size[2]), device = device)
# Reverse diffusion process (same as before)
for (timestep in seq(timesteps, 1, -1)) {
t <- torch_full(size = c(num_samples), fill_value = timestep, dtype = torch_long(), device = device)
beta_t <- betas$index_select(1, torch_tensor(timestep, dtype = torch_long(), device = device))
sqrt_one_minus_alpha_t <- torch_sqrt(1 - alphas_cumprod$index_select(1, torch_tensor(timestep, dtype = torch_long(), device = device)))
sqrt_recip_alpha_t <- 1 / torch_sqrt(alphas$index_select(1, torch_tensor(timestep, dtype = torch_long(), device = device)))
# Time embedding
t_norm <- t$to(dtype = x$dtype)$div(timesteps)
t_embed <- t_norm$unsqueeze(-1)$unsqueeze(-1)$unsqueeze(-1)
x_t_with_time <- x + t_embed
# Predict noise
with_no_grad({
noise_pred <- model(x_t_with_time)
})
# Compute x_{t-1}
if (timestep > 1) {
noise <- torch_randn_like(x)
} else {
noise <- torch_zeros_like(x)
}
x <- sqrt_recip_alpha_t * (x - beta_t / sqrt_one_minus_alpha_t * noise_pred) + torch_sqrt(beta_t) * noise
}
# Move to CPU and detach
x <- x$cpu()$detach()
# Plot the generated images using your preferred method
plot_images(x, nrow = 4, ncol = 4)
For long diffusion model training sessions, we can periodically save checkpoints (e.g., every few epochs) to prevent data loss in case of interruptions.
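A minimal sketch of periodic checkpointing inside the epoch loop (the every-5-epochs interval is an arbitrary choice; the file names match those used above):
# Inside the epoch loop: save a checkpoint every few epochs
if (epoch %% 5 == 0) {
  torch_save(model$state_dict(),     paste0("DSPA_DiffModel_MNIST_epoch_", epoch, ".pt"))
  torch_save(optimizer$state_dict(), paste0("DSPA_DiffModelOptimizer_MNIST_epoch_", epoch, ".pt"))
}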
The U-Net architecture may not have sufficient capacity to model the complexity of the data. We can expand the U-Net by increasing its depth (adding more layers) or by increasing the number of filters in each layer.
# Increase the depth and number of filters
model <- unet$new(channels_in = 1L, n_classes = 1L, depth = 5L,
n_filters = 6L)$to(device = device)
Also, we can add attention layers to help the model focus on important features, and residual connections to improve gradient flow through the U-Net, as sketched below. Alternative network architectures to consider with diffusion models are discussed in DDPM.
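For instance, a residual variant of the convolutional block could look like the following sketch (illustrative only; the 1x1 projection matching channel counts is an assumption, and this block is not part of the model trained above).
# Sketch of a residual convolutional block (illustrative; not used in the trained model)
res_block <- nn_module(
  classname = "res_block",
  initialize = function(in_size, out_size) {
    self$conv1 <- nn_conv2d(in_size, out_size, kernel_size = 3L, padding = 1L)
    self$conv2 <- nn_conv2d(out_size, out_size, kernel_size = 3L, padding = 1L)
    # 1x1 projection so the skip connection matches the output channel count
    self$skip <- if (in_size != out_size) nn_conv2d(in_size, out_size, kernel_size = 1L) else nn_identity()
  },
  forward = function(x) {
    h <- nnf_relu(self$conv1(x))
    h <- self$conv2(h)
    nnf_relu(h + self$skip(x))   # the skip (residual) connection improves gradient flow
  }
)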
More epochs may be necessary for the model-training process to converge, especially with complex models and data. Increase the number of epochs to train the model (e.g., \(> 100\)) and monitor the loss to check for convergence. Consider implementing checkpointing; save model checkpoints to resume training if needed.
The learning rate may be too high or too low, causing the model to converge poorly. In such situations, adjust and experiment with the learning rate, e.g., 1e-4, 5e-5, or use learning rate schedulers.
# Use a learning rate scheduler (lr_reduce_on_plateau() is available in recent torch releases)
scheduler <- lr_reduce_on_plateau(optimizer, patience = 5, factor = 0.5)
Also, we can try alternative optimizers, e.g., AdamW or RMSprop, or implement gradient clipping to prevent exploding gradients.
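A brief sketch of both adjustments, assuming the optim_adamw() optimizer and the nn_utils_clip_grad_norm_() helper available in recent torch releases (max_norm = 1 is an arbitrary choice):
# AdamW optimizer with decoupled weight decay
optimizer <- optim_adamw(model$parameters, lr = 1e-4, weight_decay = 1e-5)
# Inside the training loop: clip gradients before the optimizer step
loss$backward()
nn_utils_clip_grad_norm_(model$parameters, max_norm = 1)
optimizer$step()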
Simple time embeddings may not provide sufficient temporal information to the model. We can instead use sinusoidal time embeddings, which are widely used in transformer models.
# Example of sinusoidal time embedding
sinusoidal_embedding <- function(timesteps, embedding_dim) {
  position <- torch_arange(0, timesteps - 1, dtype = torch_float(), device = device)
  div_term <- torch_exp(torch_arange(0, embedding_dim - 1, 2, dtype = torch_float(), device = device) *
                          (-log(10000) / embedding_dim))
  pe <- torch_zeros(timesteps, embedding_dim, device = device)
  # position$unsqueeze(2) has shape [timesteps, 1], so the product broadcasts to [timesteps, embedding_dim / 2]
  pe[, seq(1, embedding_dim, 2)] <- torch_sin(position$unsqueeze(2) * div_term)
  pe[, seq(2, embedding_dim, 2)] <- torch_cos(position$unsqueeze(2) * div_term)
  pe
}
Then, modify the U-Net to accept time embeddings and incorporate them into the network, possibly via concatenation or addition at various layers.
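One simple way to do this, sketched below under assumed layer sizes: project the time embedding with a linear layer and add it to the block's feature maps as a per-channel bias.
# Sketch: add a projected time embedding to a block's feature maps (illustrative)
time_conv_block <- nn_module(
  classname = "time_conv_block",
  initialize = function(in_size, out_size, embedding_dim) {
    self$conv     <- nn_conv2d(in_size, out_size, kernel_size = 3L, padding = 1L)
    self$time_mlp <- nn_linear(embedding_dim, out_size)   # projects the time embedding
  },
  forward = function(x, t_emb) {
    h <- nnf_relu(self$conv(x))
    # reshape to [batch, channels, 1, 1] so the embedding broadcasts over height and width
    h + self$time_mlp(t_emb)$unsqueeze(3)$unsqueeze(4)
  }
)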
Initially, we used a linear noise schedule, which may not be optimal. Alternative schedules include cosine or quadratic variance schedules.
# Cosine schedule example
s <- 0.008
steps <- timesteps
t <- torch_linspace(0, steps, steps + 1, device = device)
alphas_cumprod <- torch_cos(((t / steps) + s) / (1 + s) * (pi / 2))^2
alphas_cumprod <- alphas_cumprod / alphas_cumprod[1]
betas <- 1 - (alphas_cumprod[2:(steps + 1)] / alphas_cumprod[1:steps])
Also, we need to ensure that the beta values start small and gradually increase.
If the model does not see enough diversity in the data, we can use data augmentation, including random rotations, shifts, or scaling to the MNIST dataset to increase diversity.
# Define data augmentation transforms
transform <- function(x) {
  x <- torch_tensor(x)$unsqueeze(1)$to(dtype = torch_float())$div(255)
  x <- torchvision::transform_random_rotation(x, degrees = 15)
  x
}
The data should be normalized appropriately (mean \(0\), standard deviation \(1\)). Combining the native MNIST images with other handwritten digit datasets may be beneficial.
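A sketch of a transform that also standardizes the inputs, using commonly cited MNIST statistics (mean ≈ 0.1307, sd ≈ 0.3081) as assumed values; if this normalization is used, remember to invert it (multiply by the sd and add back the mean) before visualizing generated samples.
# Transform that scales to [0, 1] and then standardizes (MNIST mean/sd are assumed values)
transform <- function(x) {
  x <- torch_tensor(x)$unsqueeze(1)$to(dtype = torch_float())$div(255)
  (x - 0.1307) / 0.3081   # approximately zero mean and unit standard deviation
}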
The mean squared error (MSE) loss may not be sufficient, in which case, we should experiment with alternative Loss Functions, e.g., L1 loss or a combination of MSE and perceptual loss. We can also weigh the loss differently at various timesteps to focus on more challenging ones.
When the model appears to be overfitting or underfitting, we can modify the dropout rates or apply weight decay (e.g., L2 regularization) in the optimizer.
# Add weight decay to the optimizer
optimizer <- optim_adam(model$parameters, lr = 1e-4, weight_decay = 1e-5)
Incorporating batch normalization layers may help stabilize the model training process. The training loss can be tracked and plotted over epochs to check for convergence.
# Initialize once before the training loop
loss_values <- c()
# After each epoch
loss_values <- c(loss_values, avg_loss)
plot(loss_values, type = 'l', xlab = 'Epoch', ylab = 'Loss')
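Relatedly, a conv_block variant with batch normalization might look like the sketch below (illustrative; it is not the block used in the training run above).
# conv_block variant with batch normalization (sketch)
bn_conv_block <- nn_module(
  classname = "bn_conv_block",
  initialize = function(in_size, out_size) {
    self$block <- nn_sequential(
      nn_conv2d(in_size, out_size, kernel_size = 3L, padding = 1L),
      nn_batch_norm2d(out_size),
      nn_relu(),
      nn_conv2d(out_size, out_size, kernel_size = 3L, padding = 1L),
      nn_batch_norm2d(out_size),
      nn_relu()
    )
  },
  forward = function(x) {
    self$block(x)
  }
)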
Hyperparameter tuning, e.g., using Bayesian optimization may yield optimal parameter settings (computationally expensive). Tools like Optuna or Hyperopt may also be used to automate parameter tuning.
Small batch sizes may lead to noisy gradient estimates. When appropriate resources are available, we can use a larger batch size to stabilize training. If there is limited memory, we can accumulate gradients over multiple batches. Finally, the basic diffusion model may be enhanced by incorporating classifier guidance (when labels are available to guide the generation process using class information), exploring Denoising Diffusion Implicit Models (DDIM) for faster and potentially better generation, or by implementing Noise Scheduling Strategies to improve sample quality.
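For reference, the deterministic DDIM update (noise parameter \(\eta = 0\)) takes the form below; this is a sketch of the special case, and the DDIM paper covers the general stochastic version.
\[x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \left( \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}} \right) + \sqrt{1 - \bar{\alpha}_{t-1}}\, \epsilon_\theta(x_t, t) .\]
Because the update is deterministic, it removes the added noise \(z\) and allows skipping timesteps for faster generation.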
A few practical notes: wrap the sampling (inference) steps in with_no_grad() to avoid tracking gradients; make sure CUDA is installed if you plan to use a GPU; and confirm that both torch and torchvision are installed.
.Sampling Random Times \(t\):
t <- torch_randint(low = 1L, high = timesteps + 1L, size = list(batch_size), dtype = torch_long(), device = device)
The torch_randint() function generates random integers between low (inclusive) and high (exclusive). Therefore, we need to set high = timesteps + 1L to include the value timesteps.
Using index_select() to index alphas_cumprod with \(t\): alphas_cumprod is a \(1\)-dimensional tensor of size [timesteps], and \(t\) is a tensor of indices of size [batch_size]. We use index_select(dim, index) to select elements from alphas_cumprod at the indices specified in \(t\). The resulting tensor alpha_t has size [batch_size], which is reshaped to [batch_size, 1, 1, 1] for broadcasting in subsequent operations.
alpha_t <- alphas_cumprod$index_select(1, t)$reshape(c(batch_size, 1, 1, 1))
t <- torch_full(size = c(num_samples), fill_value = timestep, dtype = torch_long(), device = device)
Using index_select() to get beta_t and other variables: we use index_select() to get the values of beta_t, alphas_cumprod, and alphas at the current timestep. To create a tensor containing the current timestep, we use torch_tensor(timestep, dtype = torch_long(), device = device).
beta_t <- betas$index_select(1, torch_tensor(timestep, dtype = torch_long(), device = device))
sqrt_one_minus_alpha_t <- torch_sqrt(1 - alphas_cumprod$index_select(1, torch_tensor(timestep, dtype = torch_long(), device = device)))
sqrt_recip_alpha_t <- 1 / torch_sqrt(alphas$index_select(1, torch_tensor(timestep, dtype = torch_long(), device = device)))
Timestep indices must have dtype torch_long(). All tensors are moved to the correct device (CPU or CUDA/GPU) to prevent device-mismatch errors. The index_select() function is essential for indexing a tensor with another tensor of indices in torch for R.
Avoiding direct tensor indexing with [ ]: in torch for R, directly indexing a tensor with another tensor of indices using square brackets [ ] is not supported in the same way as in some other languages or libraries. Instead, use functions like index_select(), gather(), or take().
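A tiny, self-contained example of index_select() with assumed values (indices in torch for R are 1-based and must be torch_long()):
# Minimal index_select() example with assumed values
vals <- torch_tensor(c(0.10, 0.20, 0.30, 0.40))
idx  <- torch_tensor(c(2L, 4L), dtype = torch_long())
vals$index_select(1, idx)   # returns the 2nd and 4th entries: 0.20 and 0.40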
Finally, existing diffusion-model implementations in Python, R, or other languages can serve as useful inspiration for further refinements.