Diffusion Theory

June 11, 2023

This note accompanies the 01-Diffusion-Sandbox notebook (Open in Colab), which includes more visualizations of the diffusion process!

Diffusion models solve a task similar to GANs (and other generative model types, such as VAEs or Normalizing Flows): they attempt to approximate some probability distribution $q(x)$ of a given domain and, most importantly, provide a way to sample from that distribution, $x \sim q(x)$.

This is achieved by optimizing some parameters $\theta$ (represented as a neural network) that define a probability distribution $p_\theta(x)$. The training objective is for $p_\theta$ to produce samples $x$ similar to those drawn from the true underlying distribution $q(x)$.

:bulb: What's different from GANs?

  • GANs produce samples from a latent vector in a single forward pass through the Generator network. The likelihood of the produced samples is controlled by the Discriminator network, which is trained to distinguish between $x \sim q(x)$ and $x \sim p_\theta(x)$.
  • Diffusion models use a single network that is applied sequentially to converge to an approximation of a real sample $x \sim q(x)$ over several estimation steps. Consequently, the model input and output are generally of the same dimensionality.

:wrench: Mechanism of the Denoising Diffusion Process

The denoising diffusion process consists of a chain of steps in two directions, corresponding to the destruction and the creation of information in the sample.

:point_right: Forward Process

With access to a sample at time step $t$, one can make an estimate of the next sample in the forward process, defined by the true distribution $q$:

$$q(x_{t}|x_{t-1})\tag{1}$$

Quite often, what is available are the samples at time step $0$ (meaning clean samples), and it is then useful to choose operations that allow an easy and efficient formulation of:

$$q(x_{t}|x_{0})\tag{2}$$

So far, the most common choice for the forward process has been Gaussian, which is easy to compute and convenient in various respects:

$$q(x_{t}|x_{t-1}) = \mathcal{N}(\sqrt{1-\beta_t}x_{t-1}, \beta_t I)\tag{3}$$

The notation above simply means that the previous sample is scaled down by a factor of $\sqrt{1-\beta_t}$ and additional Gaussian noise (sampled from a zero-mean, unit-variance Gaussian and scaled by $\sqrt{\beta_t}$) is added.
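
As an illustration, a single forward step per (3) might look like the following minimal sketch (assuming a tensor `x_prev` and a scalar `beta_t` taken from some noise schedule):

```python
import torch

def forward_step(x_prev, beta_t):
    """One forward (noising) step x_{t-1} -> x_t, following Eq. (3)."""
    noise = torch.randn_like(x_prev)  # zero-mean, unit-variance Gaussian noise
    return (1.0 - beta_t) ** 0.5 * x_prev + beta_t ** 0.5 * noise
```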

Furthermore, the $0\to t$ step can also be easily defined as:

$$q(x_{t}|x_{0}) = \mathcal{N}(\sqrt{\bar{\alpha}_t}x_0, (1-\bar{\alpha}_t) I) \tag{4}$$

where $\alpha_t = 1-\beta_t$ and

$$\bar{\alpha}_t=\prod_{i=1}^{t}\alpha_i \tag{5}$$
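
Putting (4)-(5) together, the $0\to t$ jump can be sketched roughly as below; the `betas` schedule values here are illustrative assumptions, not values prescribed by this note:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # beta_t schedule (assumed, for illustration)
alphas = 1.0 - betas                         # alpha_t = 1 - beta_t
alpha_bars = torch.cumprod(alphas, dim=0)    # alpha_bar_t, Eq. (5)

def q_sample(x0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in a single step, Eq. (4)."""
    if noise is None:
        noise = torch.randn_like(x0)
    a_bar = alpha_bars[t]
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
```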

:point_left: Reverse Process

The reverse process is designed to restore the information in the sample, which allows generating a new sample from the distribution. Generally, it starts at some high time step $t$ (very often at $t=T$, the end of the diffusion chain, where the probability distribution is extremely close to a pure Gaussian) and attempts to approximate the distribution of the previous sample at $t-1$:

$$p_\theta(x_{t-1}|x_t)$$

If the diffusion steps are small enough, the reverse of a Gaussian forward process can also be approximated by a Gaussian:

$$p_\theta(x_{t-1}|x_t) = \mathcal{N}(\mu_\theta(x_t,t),\Sigma_\theta(x_t,t))\tag{6}$$

The reverse process is often parameterized with a neural network $\theta$, a good candidate for approximating complex transformations. In many cases, a standard deviation $\sigma_t$ that is independent of $x_t$ can be used:

$$p_\theta(x_{t-1}|x_t) = \mathcal{N}(\mu_\theta(x_t,t),\sigma_t^2 I)\tag{7}$$

:steam_locomotive: DDPM: Denoising Diffusion Probabilistic Model

DDPM is one of the first popular approaches to denoising diffusion. It generates samples by following the reverse process through all $T$ steps of the diffusion chain.

When it comes to parameterizing the mean $\mu_\theta(x_t,t)$ of the reverse process distribution, the network can do one of the following:

  1. Predict it directly as $\mu_\theta(x_t,t)$
  2. Predict the original $t=0$ sample $x_0$, where
$$\tilde{\mu}_\theta = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t \tag{8}$$
  3. Predict the normal noise sample $\epsilon$ (from a unit-variance distribution) that has been added to the sample $x_0$, where
$$x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}(x_t-\sqrt{1-\bar{\alpha}_t}\epsilon) \tag{9}$$

The third option, where the network predicts $\epsilon$, appears to be the most common, and it is what DDPM does. It yields a new equation for $\tilde{\mu}_\theta$, expressed in terms of $x_t$ and $\epsilon$:

$$\tilde{\mu}_\theta = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\left(\frac{1}{\sqrt{\bar{\alpha}_t}}(x_t-\sqrt{1-\bar{\alpha}_t}\epsilon)\right) + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t \tag{10}$$

and hence

$$\tilde{\mu}_\theta =\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\right) \tag{11}$$

...which is the key equation for DDPM used for sampling.

Training

Training a model tasked with predicting the noise $\epsilon$ is quite straightforward.

At each training step:

  1. Use the forward process to generate a sample $x_t \sim q(x_t|x_0)$ for a $t$ sampled uniformly from $[1,T]$:
    1. Sample a time step $t$ from a uniform distribution $t \sim \mathcal{U}(1,T)$
    2. Sample $\epsilon$ from a standard Gaussian $\epsilon \sim \mathcal{N}(0,1)$
    3. Compute the noisy input sample $x_t$ for training via $x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon$
  2. Compute the noise approximation $\hat{\epsilon}_t=p_\theta(x_t,t)$ using the model with parameters $\theta$
  3. Minimize the error between $\epsilon$ and $\hat{\epsilon}_t$ by optimizing the parameters $\theta$ (a code sketch follows below)
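
A minimal training-step sketch, reusing the hypothetical `q_sample` helper and schedule tensors from above and assuming a noise-prediction network `model(x_t, t)` plus an optimizer (both placeholder names):

```python
import torch
import torch.nn.functional as F

def train_step(model, optimizer, x0):
    """One DDPM training step: add noise, predict it, regress with MSE."""
    t = torch.randint(1, T, (1,)).item()     # 1a. sample a time step (index into the schedule)
    noise = torch.randn_like(x0)             # 1b. epsilon ~ N(0, I)
    x_t = q_sample(x0, t, noise)             # 1c. noisy input via Eq. (4)
    noise_pred = model(x_t, t)               # 2.  epsilon_hat = p_theta(x_t, t)
    loss = F.mse_loss(noise_pred, noise)     # 3.  ||epsilon - epsilon_hat||^2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```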

Sampling

Generation begins at $t=T$ by sampling $x_T \sim \mathcal{N}(0,1)$, the last step in the diffusion process, which is modelled by a standard Gaussian.

Then, until $t=0$ is reached, the network predicts the noise in the sample, $\tilde{\epsilon}=p_\theta(x_t,t)$, and then approximates the mean of the process at $t-1$ using:

$$\tilde{\mu}_\theta =\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\tilde{\epsilon}\right)\tag{12}$$

Hence, the next sample at $t-1$ is drawn from the Gaussian distribution below:

$$x_{t-1} \sim \mathcal{N}(\tilde{\mu}_\theta,\sigma_t^2 I)\tag{13}$$

...until $x_0$ is reached, at which point only the mean $\tilde{\mu}_\theta$ is taken as the output.
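
The whole sampling loop could then be sketched as below, again reusing the hypothetical schedule tensors and `model` from the earlier snippets; the choice $\sigma_t = \sqrt{\beta_t}$ is one common option, not the only one:

```python
import torch

@torch.no_grad()
def ddpm_sample(model, shape):
    """Walk the reverse chain from t = T-1 down to t = 0, Eqs. (12)-(13)."""
    x = torch.randn(shape)                                    # x_T ~ N(0, I)
    for t in reversed(range(T)):
        eps = model(x, t)                                     # predicted noise
        mean = (x - betas[t] / (1.0 - alpha_bars[t]).sqrt() * eps) / alphas[t].sqrt()  # Eq. (12)
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)  # add sigma_t * epsilon, Eq. (13)
        else:
            x = mean                                          # at t = 0, keep only the mean
    return x
```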

:bullettrain_front: (Sampling Faster) DDIM: Denoising Diffusion Implicit Model

Warning: if you look up the original DDIM paper, you will see the symbol $\alpha_t$ used for $\bar{\alpha}_t$. In this note, no such notation change is made, for the sake of consistency.

The DDPM reverse process navigates the diffusion chain of $T$ steps in reverse order. However, as shown in (9), the reverse process involves an approximation of the clean sample $x_0$.

If we substitute $t-1$ for $t$ in (4):

$$q(x_{t-1}|x_{0}) = \mathcal{N}(\sqrt{\bar{\alpha}_{t-1}}x_0, (1-\bar{\alpha}_{t-1}) I)\tag{14}$$

which yields

$$x_{t-1} \leftarrow \sqrt{\bar{\alpha}_{t-1}}x_0 + \sqrt{1-\bar{\alpha}_{t-1}} \epsilon_{t-1}\tag{15}$$

...and, based on a specific $\epsilon_t$ estimated at step $t$, it can be rewritten as:

$$x_{t-1} \leftarrow \sqrt{\bar{\alpha}_{t-1}}x_0 + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} \epsilon_{t} + \sigma_t \epsilon\tag{16}$$

Generally, $\sigma_t$ is set to:

$$\sigma_t^2 = \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t\tag{17}$$

Further, we can introduce a new parameter η\eta to control the magnitude of the stochastic component:

$$\sigma_t^2 = \eta \tilde{\beta}_t \tag{18}$$

As found in the original DDIM paper, setting $\eta=0$ appears to be particularly beneficial when fewer steps of the reverse process are applied, and that specific type of process is known as the Denoising Diffusion Implicit Model (DDIM). The above formulation remains consistent with DDPM when $\eta=1$.

:flashlight: So, how can the diffusion chain be navigated in the reverse direction with fewer steps? First, a shorter sequence of $S$ steps is defined as a subset $\{\tau_1, \tau_2, ..., \tau_S\}$ of the original time steps of the forward process. Sampling is then based on (16).

At each step (see the sketch below):

  1. Predict $x_0$
  2. Compute the direction pointing towards the current $x_t$
  3. (If not DDIM) inject some noise for the stochastic behaviour
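
A rough sketch of one possible implementation of this accelerated reverse process is given below; it reuses the hypothetical `alpha_bars` schedule and noise-prediction `model` from the earlier snippets, and `taus` is an ascending list of the retained time steps:

```python
import torch

@torch.no_grad()
def ddim_sample(model, shape, taus, eta=0.0):
    """Reverse process over a subsequence of steps `taus`, per Eqs. (16)-(18)."""
    x = torch.randn(shape)                                             # start from pure noise
    for i in reversed(range(1, len(taus))):
        t, t_prev = taus[i], taus[i - 1]
        a_bar, a_bar_prev = alpha_bars[t], alpha_bars[t_prev]
        eps = model(x, t)                                              # predict the noise
        x0_pred = (x - (1.0 - a_bar).sqrt() * eps) / a_bar.sqrt()      # 1. predict x_0, Eq. (9)
        beta_t = 1.0 - a_bar / a_bar_prev                              # effective beta for this jump
        beta_tilde = (1.0 - a_bar_prev) / (1.0 - a_bar) * beta_t       # Eq. (17)
        sigma = (eta * beta_tilde).sqrt()                              # Eq. (18)
        dir_xt = (1.0 - a_bar_prev - sigma ** 2).sqrt() * eps          # 2. direction towards x_t
        x = a_bar_prev.sqrt() * x0_pred + dir_xt                       # Eq. (16), deterministic part
        if eta > 0:
            x = x + sigma * torch.randn_like(x)                        # 3. stochastic component
    return x
```

With `eta=0` this reduces to the deterministic DDIM update, while with `eta=1` and the full step sequence it should behave like the DDPM sampler above.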

It can generally be assumed that DDIM:

  • Offers better sample quality at fewer steps
  • Allows for a deterministic mapping between the starting noise $x_T$ and the generated sample $x_0$
  • Performs worse than DDPM for large numbers of steps (such as 1000)