Fast Sampling of Diffusion Models with Exponential Integrator

January 3, 2023

Qinsheng Zhang·Yongxin Chen

Paper · Project Page




  • 2021-11-17 DEIS accelerates large scale text-to-image eDiff-I and achieves SOTA performance.


Update

  • BREAKING CHANGE: the v1.0 API changes substantially, as we add the ρRK-DEIS and ρAB-DEIS algorithms and more choices for time scheduling. If you are only interested in tAB-DEIS / iPNDM or the previous codebase, check v0.1.

Usage

# for PyTorch users
pip install "jax[cpu]"

If diffusion models are trained with continuous time

import jax.numpy as jnp
import jax_deis as deis

def eps_fn(x_t, scalar_t):
    vec_t = jnp.ones(x_t.shape[0]) * scalar_t
    return eps_model(x_t, vec_t)

# pytorch
# import torch as th
# import th_deis as deis
# def eps_fn(x_t, scalar_t):
#     vec_t = (th.ones(x_t.shape[0])).float().to(x_t) * scalar_t
#     with th.no_grad():
#         return eps_model(x_t, vec_t)

# mappings between t and alpha in VPSDE
# we provide popular linear and cos mappings
t2alpha_fn, alpha2t_fn = deis.get_linear_alpha_fns(beta_0=0.01, beta_1=20)

vpsde = deis.VPSDE(
    t2alpha_fn, 
    alpha2t_fn,
    sampling_eps, # sampling end time t_0
    sampling_T # sampling starting time t_T
)

sampler_fn = deis.get_sampler(
    # args for diffusion model
    vpsde,
    eps_fn,
    # args for timestamps scheduling
    ts_phase="t",  # supports "rho", "t", "log"
    ts_order=2.0,
    num_step=10,
    # deis choice
    method="t_ab",  # DEIS sampling algorithm: supports "rho_rk", "rho_ab", "t_ab", "ipndm"
    ab_order=3,  # must be greater than 0; used by "rho_ab" and "t_ab", ignored by other algorithms
    rk_method="3kutta",  # used by "rho_rk", ignored by other algorithms
)

sample = sampler_fn(noise)
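
To make the role of `t2alpha_fn` and `alpha2t_fn` concrete, here is a minimal sketch of what a linear-schedule mapping pair could look like. It assumes the standard VP-SDE convention $\alpha_t = \exp(-\int_0^t \beta(s)\,ds)$ with $\beta(s) = \beta_0 + (\beta_1 - \beta_0)s$; the helper name `make_linear_alpha_fns` is hypothetical and this is not necessarily how `deis.get_linear_alpha_fns` is implemented.

```python
import math

def make_linear_alpha_fns(beta_0, beta_1):
    """Hypothetical stand-in for a linear-schedule t <-> alpha mapping pair."""

    def t2alpha(t):
        # alpha(t) = exp(-integral_0^t beta(s) ds), beta(s) = beta_0 + (beta_1 - beta_0) * s
        return math.exp(-(beta_0 * t + 0.5 * (beta_1 - beta_0) * t ** 2))

    def alpha2t(alpha):
        # Solve 0.5 * (beta_1 - beta_0) * t^2 + beta_0 * t + log(alpha) = 0 for t >= 0
        disc = beta_0 ** 2 - 2.0 * (beta_1 - beta_0) * math.log(alpha)
        return (-beta_0 + math.sqrt(disc)) / (beta_1 - beta_0)

    return t2alpha, alpha2t

t2alpha, alpha2t = make_linear_alpha_fns(beta_0=0.01, beta_1=20.0)
alpha = t2alpha(0.3)     # signal level at t = 0.3
t_back = alpha2t(alpha)  # inverting the map should recover t = 0.3
```

The two functions are inverses of each other on $t \ge 0$, which is the property a sampler needs in order to move between the time and noise-level parameterizations.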

If diffusion models are trained with discrete time

#! by default the example assumes sampling
#! from t=len(discrete_alpha) - 1 to t=0
#! len(discrete_alpha) steps in total if we use delta_t = 1
vpsde = deis.DiscreteVPSDE(discrete_alpha)
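
For readers coming from DDPM-style models, the sketch below shows one way a `discrete_alpha` sequence could be built. It assumes that `discrete_alpha` denotes the cumulative product $\bar{\alpha}_t = \prod_{s \le t} (1 - \beta_s)$ of a linear beta schedule; that interpretation, the helper name `ddpm_alpha_bar`, and the default schedule values are assumptions for illustration and should be checked against the library and your trained model.

```python
def ddpm_alpha_bar(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative product of (1 - beta_t) for a linear beta schedule (assumed convention)."""
    alpha_bar, running = [], 1.0
    for i in range(num_steps):
        # Linearly interpolate beta between beta_start and beta_end
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        running *= 1.0 - beta
        alpha_bar.append(running)
    return alpha_bar

# One entry per discrete training timestep, decreasing from ~1 towards 0
discrete_alpha = ddpm_alpha_bar()
```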

A short derivation for DEIS

Exponential integrator in diffusion model

The key insight of exponential integrators is to exploit all of the mathematical structure present in the ODE, with the goal of making the discretization error as small as possible.

The mathematical structure in diffusion models includes the semilinear structure of the ODE and the analytic formulas for the drift and diffusion coefficients.

Below we present a short derivation of how the exponential integrator applies to diffusion models.

Forward SDE

$$dx = F_t x \, dt + G_t \, d\mathbf{w}$$

Backward ODE

$$dx = F_t x \, dt + 0.5 \, G_t G_t^T L_t^{-T} \epsilon(x, t) \, dt$$

where $L_t L_t^{T} = \Sigma_t$ and $\Sigma_t$ is the covariance of $p_{0t}(x_t | x_0)$.
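
As a concrete instance, consider the standard VP-SDE with noise schedule $\beta(t)$. The symbols $\beta(t)$ and $\alpha_t$ are notation assumed for this illustration and do not appear in the general derivation. In that case all coefficients are scalars:

$$F_t = -\tfrac{1}{2}\beta(t), \quad G_t = \sqrt{\beta(t)}, \quad \alpha_t = \exp\left(-\int_0^t \beta(\tau) \, d\tau\right), \quad \Sigma_t = (1 - \alpha_t) I, \quad L_t = \sqrt{1 - \alpha_t}$$

and the backward ODE becomes

$$dx = -\tfrac{1}{2}\beta(t) \, x \, dt + \frac{\beta(t)}{2\sqrt{1 - \alpha_t}} \, \epsilon(x, t) \, dt$$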

Exponential Integrator

We can remove the linear term of the semilinear structure with the exponential integrator by introducing a new variable $y$

$$y_t = \Psi(t) x_t, \quad \Psi(t) = \exp\left(-\int_0^{t} F_\tau \, d\tau\right)$$

The ODE then simplifies to

$$\dot{y}_t = 0.5 \, \Psi(t) G_t G_t^T L_t^{-T} \epsilon(x(y), t)$$

where $x(y)$ maps $y_t$ to $x_t$.
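
The benefit of treating the linear term exactly can be seen on a toy problem. The sketch below is an illustration under simplifying assumptions, not code from the library: it integrates the scalar semilinear ODE $\dot{x} = a x + c$ with constant $a$ and $c$ (so the "network" term is a constant), comparing plain Euler with a first-order exponential integrator. All function names are hypothetical.

```python
import math

def euler_step(x, h, a, c):
    # Plain Euler: discretizes both the linear term a*x and the remaining term c
    return x + h * (a * x + c)

def exp_euler_step(x, h, a, c):
    # Exponential Euler: the linear term is integrated exactly,
    # the remaining term c is held fixed over the step
    return math.exp(a * h) * x + (math.exp(a * h) - 1.0) / a * c

def exact(x0, t, a, c):
    # Closed-form solution of dx/dt = a*x + c
    return math.exp(a * t) * x0 + (math.exp(a * t) - 1.0) / a * c

def integrate(step, x0, t_end, n, a, c):
    x, h = x0, t_end / n
    for _ in range(n):
        x = step(x, h, a, c)
    return x

a, c, x0, t_end, n = -5.0, 1.0, 2.0, 1.0, 5
reference = exact(x0, t_end, a, c)
euler_err = abs(integrate(euler_step, x0, t_end, n, a, c) - reference)
exp_err = abs(integrate(exp_euler_step, x0, t_end, n, a, c) - reference)
```

Because the remaining term is constant in this toy setting, the exponential integrator incurs only floating-point error, while Euler with the same five steps does not. With a neural network in place of `c`, the only discretization error left is the one caused by approximating the network output over a step.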

Time scaling

We can go one step further when $F_t, G_t$ are scalars by rescaling time

$$\dot{v}_\rho = \epsilon(x(v), t(\rho))$$

where $y_t = v_\rho$ and $d\rho = 0.5 \, \Psi(t) G_t G_t^T L_t^{-T} \, dt$. Here $x(v)$ maps $v_\rho$ to $x_t$, and $t(\rho)$ maps $\rho$ to $t$.
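
For the standard VP-SDE this rescaling has a closed form. The following worked example assumes $F_t = -\tfrac{1}{2}\beta(t)$, $G_t = \sqrt{\beta(t)}$, $\alpha_t = \exp(-\int_0^t \beta(\tau) \, d\tau)$ and $L_t = \sqrt{1 - \alpha_t}$; $\beta(t)$ and $\alpha_t$ are notation introduced for this illustration. Then

$$\Psi(t) = \exp\left(\tfrac{1}{2}\int_0^t \beta(\tau) \, d\tau\right) = \frac{1}{\sqrt{\alpha_t}}, \qquad d\rho = \frac{\beta(t)}{2\sqrt{\alpha_t (1 - \alpha_t)}} \, dt$$

Using $\dot{\alpha}_t = -\beta(t) \alpha_t$, this integrates to

$$\rho(t) = \sqrt{\frac{1 - \alpha_t}{\alpha_t}}, \qquad v_\rho = \frac{x_t}{\sqrt{\alpha_t}}$$

so the rescaled time $\rho$ is the noise-to-signal ratio, with $\rho(0) = 0$.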

High order solver

By absorbing all of the mathematical structure, we arrive at the following ODE

$$\dot{v}_\rho = \epsilon(x(v), t(\rho))$$

Since the RHS is a neural network, we cannot simplify the ODE further without additional knowledge of this black-box function. We can then apply well-established ODE solvers, such as multistep and Runge-Kutta methods.
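
To illustrate what a multistep solver buys here, below is a minimal sketch of a generic two-step Adams-Bashforth method on a uniform grid in $\rho$. It is a textbook scheme applied to a toy right-hand side, not the library's ρAB-DEIS implementation; the function names and the toy `f` are assumptions for illustration.

```python
import math

def euler_solve(f, v0, rho0, rho1, n):
    # First-order baseline: one evaluation of f per step
    h = (rho1 - rho0) / n
    v, rho = v0, rho0
    for _ in range(n):
        v = v + h * f(v, rho)
        rho += h
    return v

def ab2_solve(f, v0, rho0, rho1, n):
    # Two-step Adams-Bashforth: still one new evaluation of f per step,
    # but the previous evaluation is reused to extrapolate f over the step
    h = (rho1 - rho0) / n
    v, rho = v0, rho0
    f_prev = None
    for _ in range(n):
        f_cur = f(v, rho)
        if f_prev is None:
            v = v + h * f_cur  # bootstrap the first step with Euler
        else:
            v = v + h * (1.5 * f_cur - 0.5 * f_prev)
        f_prev = f_cur
        rho += h
    return v

# Toy black-box right-hand side standing in for the network: dv/drho = -v
f = lambda v, rho: -v
reference = math.exp(-1.0)
euler_err = abs(euler_solve(f, 1.0, 0.0, 1.0, 20) - reference)
ab2_err = abs(ab2_solve(f, 1.0, 0.0, 1.0, 20) - reference)
```

Both solvers call `f` the same number of times, which matters when each call is a network forward pass; the multistep variant obtains higher accuracy by reusing past evaluations.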

Demo

Reference

@article{zhang2022fast,
  title={Fast Sampling of Diffusion Models with Exponential Integrator},
  author={Zhang, Qinsheng and Chen, Yongxin},
  journal={arXiv preprint arXiv:2204.13902},
  year={2022}
}