Updated: 2024-02-26 Mon 10:47

Diffusion and score-based generative modeling

Diffusion and score-based generative modeling

There are great blog posts of what diffusion and score matching is elsewhere, in particular, see Lilian Weng's literature review and the great exposition of Yang Song on learning score functions for generative modeling. Here I will mainly lean on the blog post of Yang Song and his and his collaborators paper Generative Modeling by Estimating Gradients of the Data Distribution (Song and Ermon 2020) as I find it very comprehensive and well-written.

Some of the sections are pretty technical, for the actual implementation you only need to

  • Read the Setup section,
  • Understand how the loss \(\ell(\theta; \sigma)\) defined in \ref{eq:simplified-score-matching-objective} is used to build the optimization objective \(\hat{\mathcal{E}}(\theta; (\sigma_{l})_{l=1}^{L})\) defined in \ref{eq:aggregated-final-empirical-risk} which we train on to produce the estimator \(\hat{\theta}_{n}\) which are the learned parameters of the score network \(s_{\theta}\),
  • Read the Generating samples section to understand how to generate samples using \(s_{\hat{\theta}_{n}}\),

and you should then be able to follow along with the implementation section.

Setup

To start with, we assume that we have a dataset of iid samples \((x_{i})_{i=1}^{n}\) sampled from some unknown data distribution \(p^{\ast}\), where the datapoints live in some space \(\mathcal{X}\) which we will take to be some Euclidean vector space (for example, \(\mathbb{R}^{D}\) for a vector or \(\mathbb{R}^{H \times W \times C}\) for an image with width \(W\), height \(H\) and \(C\) color channels). Everything is nice so we assume that \(p^{\ast}\) has a pdf and identify the distribution with this pdf (so \(p^{\ast}(x)\) is the density at \(x\)). The goal is to learn a model which would allow us to sample from \(p^{\ast}\). One way to do this would be to model \(p^{\ast}\) directly, but as fortune has it, it is enough to learn a model of the score function \(s^{\ast}(x) = \nabla_{x} \log p^{\ast}(x)\) to accomplish this.

Learning the score function using score matching allows for much easier training and modelling than trying to learn a model of \(p^{\ast}\) directly. This is due to not having to learn a properly normalized distribution but only up to a constant. If rewrite \(p^{\ast}(x) = \exp(-f^{\ast}(x))/Z^{\ast}\), the score function takes the form \(-\nabla_{x} f^{\ast}(x)\) since \[ \nabla_{x} \log p^{\ast}(x) = -\nabla_{x} f^{\ast}(x) - \nabla_{x}\log Z^{\ast} = -\nabla_{x} f^{\ast}(x) \] as \(Z^{\ast}\) is independent of \(x\).

Score matching

Score matching aim to minimize the least-squares objective

\begin{equation} \label{eq:lsq-score-matching-objective} \frac{1}{2}\mathbb{E}_{X \sim p^{\ast}}\|s_{\theta}(X) - s^{\ast}(X)\|^{2} \end{equation}

where \(s_{\theta}: \mathcal{X} \to \mathcal{X}\) is a model of the score function, for example a neural network. Of course, we don't know \(s^{\ast}\) so this objective is not very good, but it can be shown to be proportional to

\begin{equation} \label{eq:tr-score-matching-objective} \mathbb{E}_{X \sim p^{\ast}}\left[\mathrm{tr}\left(\nabla_{x} s_{\theta}(X))\right) + \frac{1}{2}\|s_{\theta}(X)\|_{2}^{2}\right]. \end{equation}

In practice, we replace the distribution \(p^{\ast}\) by the empirical version \(\hat{p}_{n}\) using the train dataset \((x_{i})_{i=1}^{n}\). When the input dimension is large the trace computation becomes too computational burdensome so we rely on other approximation. We will use denoising score matching, but there are other ways, in (Song and Ermon 2020) they also mention sliced score matching as an alternative.

To get to the point, denoising score matching replaces the distribution \(p^{\ast}\) with a smoothed version \(q_{\sigma}(x) = \mathbb{E}_{X' \sim p^{\ast}}\left[q_{\sigma}(x|X')\right]\) where \(q_{\sigma}\) is a some symmetric bell-curved distribution, for example a gaussian with standard deviation \(\sigma\) and mean \(X'\). Intuitively the scale parameter \(\sigma\) allow us to trade off some bias for variance by interpolating between the true (empirical) distribution as \(\sigma \to 0\) and a uniform distribution as \(\sigma \to \infty\)1, in addition to making training possible as it makes the resulting smoothed empirical distribution have full support on \(\mathcal{X}\) (so, it is never zero anywhere). Without this smoothing, \(\hat{p}_{n}\) will always be zero on points outside of the train set which comes with all kinds of problems. Choosing \(q_{\sigma}(x | x')\) to be an isotropic Gaussian pdf / distribution with covariance matrix \(\sigma I\) and mean \(x'\) simplifies objective \ref{eq:tr-score-matching-objective} to

\begin{equation} \label{eq:simplified-score-matching-objective} \mathcal{L}(\theta; \sigma) = \frac{1}{2}\mathbb{E}_{X \sim p^{\ast}}\mathbb{E}_{X' \sim q_{\sigma}(\cdot | X)}\left[\|s_{\theta}(X', \sigma) - (- (X' - X)/\sigma^{2})\|_{2}^{2}\right] \end{equation}

where both the risk \(\mathcal{L}\) and the score model \(s_{\theta}\) are now indexed by \(\sigma\). We may think of this as parameterizing a family of score models by \(\sigma\) for some fixed \(\theta\). Let's call the empirical risk \(\ell(\theta; \sigma)\) where we replace \(p^{\ast}\) with the empirical distribution \(\hat{p}_{n}\). The final objective defining the Noise Conditional Score Network above average losses over a geometrically spaced grid of scales \(\sigma\). For such a grid \((\sigma_{l})_{l=1}^{L}\) we have

\begin{equation} \label{eq:aggregated-final-empirical-risk} \hat{\mathcal{E}}(\theta; (\sigma_{l})_{l=1}^{L}) = \frac{1}{L}\sum_{l=1}^{L}\lambda(\sigma_{l})\ell(\theta; \sigma_{l}) \end{equation}

where \(\lambda\) is some weighing function which we will fix to be \(\lambda(\sigma) = \sigma^{2}\) according to the heuristic in (Song and Ermon 2020). Let us call the learned parameters \(\hat{\theta}_{n}\).

Generating samples

We can use Langevin dynamics to produce samples from the learned score model \(s_{\hat{\theta}_{n}}\). Usually, Langevin dynamics allow us to sample from some distribution \(p\) as long as we can evaluate the score function \(\nabla_{x} \log p(x)\). Fixing a step size (or more generally, a schedule) \(\eta\) and some prior distribution \(\pi\) we can sample an initial value \(x_{0}\) and iterate using

\begin{equation} \label{eq:langevin-dynamics} x_{t+1} = x_{t} + \frac{\eta}{2}\nabla_{x}\log p(x_{t}) + \sqrt{\eta}Z_{t} \end{equation}

where \(Z_{t}\)'s are sampled iid from a unit Gaussian. Replacing \(\nabla_{x}\log p(x)\) with \(s_{\hat{\theta}_{n}}(x)\) we can generate samples hopefully resembling those from \(p^{\ast}\).

More generally, for any procedure which produces samples from a distribution \(p\) using only the score function, we can plug-in \(s_{\hat{\theta}_{n}}\) which we've learned and produce samples, using the plugin-estimator method. This is pretty nice, we can tap into all the work which has been done in the field of MCMC2, for example Hamiltonian Monte-Carlo or NUTS. The decoupling of training and inference leads to many benefits, as we can repurpose \(s_{\hat{\theta}_{n}}\) for other downstream tasks.

Implementation

Imports

import functools
import math

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_style("white")

import jax.numpy as jnp
from jax import grad, jax, vmap, lax
from jax import random
from jax import value_and_grad
import jax.tree_util as jtu
import jax

import equinox as eqx
import optax
from jaxtyping import Array, Float, Int, PyTree

import tensorflow_datasets as tfds
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels
2024-02-05 13:35:57.262064: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-05 13:35:57.262106: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-05 13:35:57.263175: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-05 13:35:58.023617: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT

Tractable mixture models

For some very simple models such as mixtures of tractable base models or using bijectors we don't even need to learn the score function since it's available to us in closed form. The simplest way to get some intuition for this is to visualize the the log-probability function \(\log p(x)\) using for example level-sets and the vector field corresponding to \(s(x)\) in 2 dimensions.

def plot_logdistribution(fig, ax, distribution, xlim=(-1.0, 1.0), ylim=(-1.0, 1.0), n_contour=100, n_quiver=10):
    # Define the grid for contour
    x = np.linspace(*xlim, n_contour)
    y = np.linspace(*ylim, n_contour)
    X, Y = np.meshgrid(x, y)
    XY = np.stack([X.ravel(), Y.ravel()], axis=-1)

    # Compute the log-distribution
    Z = distribution.log_prob(XY).reshape(n_contour, n_contour)
    cont = ax.contour(X, Y, Z)
    plt.colorbar(cont, ax=ax)

    # Compute the gradients
    x = np.linspace(*xlim, n_quiver)
    y = np.linspace(*ylim, n_quiver)
    X, Y = np.meshgrid(x, y)
    XY = np.stack([X.ravel(), Y.ravel()], axis=-1)
    grads = vmap(grad(distribution.log_prob))(XY)
    grad_X = grads[:, 0].reshape(n_quiver, n_quiver)
    grad_Y = grads[:, 1].reshape(n_quiver, n_quiver)
    ax.quiver(X, Y, grad_X, grad_Y)
    return fig, ax

We simply show the level sets and quiver plot (the vector field) of the log-distribution and the score function

key = random.PRNGKey(0)  # Use a different key for different runs

# Define a 2-component Gaussian Mixture model
num_components = 2
component_means = [(0.5, 0.5), (-0.5, -0.5)]
sd = 0.4
component_sds = [(sd, sd), (sd, sd)]
p1 = 0.5
component_probs = [p1, 1 - p1]

mixture_dist = tfd.Categorical(probs=component_probs)
component_dist = tfd.MultivariateNormalDiag(loc=component_means, scale_diag=component_sds)
mixture_model = tfd.MixtureSameFamily(
    mixture_distribution=mixture_dist,
    components_distribution=component_dist,
    name="MoG"
)

fig, ax = plt.subplots(figsize=(8, 6))
fig, ax = plot_logdistribution(fig, ax, mixture_model)
Quiver plot of the score function of a mixture of Gaussians
Figure 1: The score function points towards the means of the component means (the peaks)

Generating samples

We already know from the previous section on generating samples how to do this, and the implementation is straightforward.

Let's quickly enable plotting the distribution. We will use this as a background for the evolving particle systems according to the Langevin dynamics.

def plot_distribution(fig, ax, distribution, xlim=(-1.0, 1.0), ylim=(1.0, 1.0), n_contour=100):
    # Define the grid for contour
    x = np.linspace(*xlim, n_contour)
    y = np.linspace(*ylim, n_contour)
    X, Y = np.meshgrid(x, y)
    XY = np.stack([X.ravel(), Y.ravel()], axis=-1)

    # Compute the distribution
    Z = distribution.prob(XY).reshape(n_contour, n_contour)
    cont = ax.contour(X, Y, Z)
    return fig, ax

We define the update step (return tuple due to using lax.scan later) and evolve a particle over many steps, lax.scan simply makes this efficient.

def update_x(x, z, distribution, step_size):
    g = grad(distribution.log_prob)(x)
    xp1 = x + (step_size / 2) * g + jnp.sqrt(step_size) * z
    return xp1, xp1

step_size = 0.01
num_steps = 200
key = random.PRNGKey(0)
z_key, x0_key, key = random.split(key, 3)
z = random.normal(z_key, shape=(num_steps, 2))
x0 = random.normal(x0_key, shape=(2,)) * 0.5
update_fun = functools.partial(update_x, distribution=mixture_model, step_size=step_size)
final, result = lax.scan(update_fun, x0, z)

Let's look at the result in this case. To see the path of the particle more clearly I'll outline the path by simply drawing a line between each point in result.

fig, ax = plt.subplots(figsize=(6, 6))
fig, ax = plot_distribution(fig, ax, mixture_model, xlim=(-1.5, 1.5), ylim=(-1.5, 1.5))
ax.plot(result[:, 0], result[:, 1], marker=".", linewidth=1.0)
Mixture of Gaussians Langevin dynamics of a particle
Figure 2: Path of a particle following the Langevin dynamics of the mixture model

Finally, let's make a video for goodies. The video is the above plot shown in time, following the particle according to the Langevin dynamics.

# We will animate this using the FuncAnimation class from matplotlib
from matplotlib.animation import FuncAnimation
fig, ax = plt.subplots(figsize=(6, 6))
fig, ax = plot_distribution(fig, ax, mixture_model, xlim=(-1.5, 1.5), ylim=(-1.5, 1.5))

# Initialize the line plot
line, = ax.plot([], [], marker='.', linewidth=1.0)
# Initialize the particle positions
positions = result

# Function to update the line plot
def update(frame):
    # Update the line plot data
    line.set_data(positions[:frame+1, 0], positions[:frame+1, 1])
    return line,

# Create the FuncAnimation
animation = FuncAnimation(fig, update, frames=len(positions),
                          interval=100, repeat=False)

# Save using ffmpeg
animation.save("mog-langevin-dynamics.mp4", writer="ffmpeg", dpi=200)

Learning the score function

The reason we didn't have to learn the score function in Tractable mixture models was because we restricted ourselves to a distribution with a tractable score function. In reality this is seldom the case, and even if we could do it in theory, it may be too computationally expensive to do it directly. Additionally, If we have a set of points which we interpret as an empirical distribution then the score function is not even well-defined as there is no density. We have to resort to learning it in some way.

First, we will use the MNIST dataset where we view an image as a discrete distribution by normalizing the pixel intensities over the total intensity of all the pixels in the image. Since each pixel is a value between 0 and 1, we can view this as a distribution over pixel coordinates. To make this point clear, we assume that the underlying space is a 2d cartesian square, \(\mathcal{X} = [0, 1]^{2}\)3, with each pixel coordinate being normalized to be between 0 and 1. So, an image is a collection of coordinate pairs and pixel intensity values, in the case of MNIST which is \(28 \times 28\) we have pixel coordinates \((i, j)\) where \(i, j \in \{(l + 1/2) / 28\}_{l=0}^{27}\) and the corresponding pixel intensities \(I(i, j) \in [0, 1]\). With this we have an empirical distribution where \(\hat{p}(i, j) = I(i, j) / \sum_{i', j'}I(i', j') \propto I(i, j)\).

Behold, the first image to the MNIST training dataset!

# Get an image from mnist
import torchvision

mnist = torchvision.datasets.MNIST("~/data", download=True)
mnist_images = mnist.data.numpy()
image = mnist_images[0]
image = image.astype(float) / 255.0

def create_sample_fn(image):
    """Generate a function that samples from the image distribution"""
    def sample(num_samples, key):
        h, w = image.shape
        # Note that random.categorical takes as inputs logits which is why we do not have to normalize
        return jnp.array(
            [divmod(x.item(), w) for x in random.categorical(logits=jnp.log(image.ravel()), key=key, shape=(num_samples,))]
        )
    return sample

fig, ax = plt.subplots()
im = ax.imshow(image, cmap="gray")
plt.colorbar(im)
ax.axis("off")
5-mnist-imshow.webp

Let's check the histogram when we sample many times according to the distribution defined by the image, we should get something similar as the sample size becomes large. The histogram function will rotate the image though.

key = jax.random.PRNGKey(sum(ord(c) for c in "five"))
sample = create_sample_fn(image)
x = sample(10000, key)
fig, ax = plt.subplots(figsize=(4, 4))
h = ax.hist2d(x[:, 0], x[:, 1], cmap="gray")
ax.axis("off")
5-mnist-hist2d.webp

Now we set up the training. The architecture here is a combination of things

  1. We use the insights from (Tancik et al. 2020) which roughly says that using a pre-processing fourier feature map before the MLP is helpful for learning high-frequency mappings for coordinate based inputs. We add a residual connection here, so that the input to the MLP is jnp.concatenate(f_layer(x), x).
  2. The RFLayer has noise-level specific parameters alpha, beta which linearly transform the random features and we learn one such transformation for each noise level (the rest of the architecture is shared, like the MLP and the original random feature mappings).
  3. We freeze the random feature parameters B_cos, B_sin by following the guide on how to freeze layer in the equinox documents.
  4. The rest of the training is done using the objective \(\hat{\mathcal{E}}(\theta; (\sigma_{l})_{l=1}^{L})\).
class RFLayer(eqx.Module):
    """Random Feature layer with learnable linear output transformations alpha, beta"""
    B_cos: jax.Array
    B_sin: jax.Array
    alpha: jax.Array
    beta: jax.Array
    num_noise_levels: int
    sigma: float

    def __init__(self, in_size: int, num_rf: int, num_noise_levels: int, key, sigma: float = 1.0):
        cos_key, sin_key = random.split(key, 2)
        self.B_cos = random.normal(cos_key, (num_rf, in_size)) * sigma
        self.B_sin = random.normal(sin_key, (num_rf, in_size)) * sigma
        self.sigma = sigma
        self.num_noise_levels = num_noise_levels
        self.alpha = jnp.ones(num_noise_levels)
        self.beta = jnp.zeros(num_noise_levels)

    def __call__(self, x: jax.Array, noise_level_idx: int) -> jax.Array:
        rf_features = jnp.concatenate(
            [jnp.cos(2 * math.pi * self.B_cos @ x), jnp.sin(2 * math.pi * self.B_sin @ x)], axis=-1
        )
        return self.alpha[noise_level_idx] * rf_features + self.beta[noise_level_idx]

class Model(eqx.Module):
    rf_layer: RFLayer
    mlp: eqx.nn.MLP

    def __init__(self, in_size: int, num_rf: int, width_size: int, depth: int, out_size: int, num_noise_levels: int, key):
        self.rf_layer = RFLayer(in_size, num_rf, num_noise_levels, key)
        self.mlp = eqx.nn.MLP(in_size=num_rf * 2 + 2,
                              width_size=width_size,
                              depth=depth,
                              out_size=out_size,
                              activation=jax.nn.softplus,
                              key=key)

    def __call__(self, x: jax.Array, noise_level_idx: int) -> jax.Array:
        x -= 0.5
        x = jnp.concatenate((self.rf_layer(x, noise_level_idx), x)) #Residual connection
        return self.mlp(x)

# Define the objective function
def one_sample_loss(model, x, sigmas, key):
    # Sample one gaussian for each noise level
    perturbations = random.normal(key, (sigmas.shape[0], x.shape[0])) * jnp.expand_dims(sigmas, 1)
    x_bars = x + perturbations
    # Predict over all noise levels
    scores_pred = vmap(model)(x_bars, jnp.arange(sigmas.shape[0]))
    scores = -(x_bars - x) / jnp.expand_dims(sigmas ** 2, 1)
    # Vectorized version of (x_bar[i] - x) / sigma[i] ** 2
    # mean(sigmas[i]**2 * mse(score_pred[i], scores[i]) for i in range(len(sigmas))))
    result = jnp.mean(jnp.square(scores_pred - scores).mean(-1) * sigmas ** 2)
    return result

def loss(diff_model, static_model, xs, sigmas, keys):
    """Objective function, we separeate the parameters into active and frozen parameters"""
    model = eqx.combine(static_model, diff_model)
    batch_loss = vmap(one_sample_loss, (None, 0, None, 0))
    return jnp.mean(batch_loss(model, xs, sigmas, keys))

def train(
        model: eqx.Module,
        filter_spec: PyTree,
        sample,
        optim: optax.GradientTransformation,
        steps: int,
        batch_size: int,
        print_every: int,
        sigmas: Float[Array, "..."],
        key
) -> eqx.Module:
    @eqx.filter_jit
    def make_step(
            model: eqx.Module,
            xs: Float[Array, "batch_size 2"],
            opt_state: PyTree,
            keys: Float[Array, "batch_size"],
    ):
        diff_model, static_model = eqx.partition(model, filter_spec)
        loss_value, grads = eqx.filter_value_and_grad(loss)(diff_model, static_model, xs, sigmas, keys)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    original_model = model
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    for step in range(steps):
        *loss_keys, sample_key, key = random.split(key, batch_size + 2)
        loss_keys = jnp.stack(loss_keys)
        xs = sample(batch_size, sample_key) / 27
        model, opt_state, loss_value = make_step(model, xs, opt_state, loss_keys)
        if step % print_every == 0:
            print(f"Step {step}, Loss {loss_value}")

    return model

Now let's train it. We squint and choose some good hyperparameters and pray to the ML-gods for an auspicious training run (actually I did some hand-tuning).

sigmas = jnp.geomspace(0.0001, 1, 30, endpoint=True)
DEPTH = 3
WIDTH_SIZE = 128
NUM_RF = 256
BATCH_SIZE = 128
STEPS = 5 * 10 ** 4
PRINT_EVERY = 5000

model = Model(in_size=2,
              num_rf=NUM_RF,
              width_size=WIDTH_SIZE,
              depth=DEPTH,
              out_size=2,
              num_noise_levels=len(sigmas),
              key=random.PRNGKey(0))

LEARNING_RATE = 1e-3
optim = optax.adam(LEARNING_RATE)

# The filter spec is a pytree of the same shape as the parameters
# True and False represent whether this part of the pytree will be updated
# using the optimizer by splitting the parameters into diff_model and static_model
filter_spec = jtu.tree_map(lambda x: True if isinstance(x, jax.Array) else False, model)
filter_spec = eqx.tree_at(
    lambda tree: (tree.rf_layer.B_cos, tree.rf_layer.B_sin),
    filter_spec,
    replace=(False, False),
)
model = train(model, filter_spec, sample, optim, STEPS, BATCH_SIZE, PRINT_EVERY, sigmas, key)
Step 0, Loss 1.0177688598632812
Step 5000, Loss 0.7470076084136963
Step 10000, Loss 0.6700457334518433
Step 15000, Loss 0.6010410785675049
Step 20000, Loss 0.5470178127288818
Step 25000, Loss 0.5063308477401733
Step 30000, Loss 0.47549256682395935
Step 35000, Loss 0.4591177701950073
Step 40000, Loss 0.4523712992668152
Step 45000, Loss 0.43943890929222107

Let's visualize the vector field for this new model by repurposing the plot_logdistribution function to just plot the vector field. Since we don't have an actual density we will not plot the level

def plot_vector_field(fig, ax, score_fun, xlim=(0.0, 1.0), ylim=(0.0, 1.0), n_quiver=10):
    # Compute the gradients
    x = np.linspace(*xlim, n_quiver)
    y = np.linspace(*ylim, n_quiver)
    X, Y = np.meshgrid(x, y)
    XY = np.stack([X.ravel(), Y.ravel()], axis=-1)
    grads = vmap(score_fun)(XY)
    grad_X = grads[:, 0].reshape(n_quiver, n_quiver)
    grad_Y = grads[:, 1].reshape(n_quiver, n_quiver)
    ax.quiver(X, Y, grad_X, grad_Y)
    return fig, ax

fig, ax = plt.subplots(3, 3, figsize=(3 * 2, 3 * 2))
for axis, i in zip(ax.ravel(), range(0, 30, 3)):
    axis.axis('off')
    axis.set_aspect('equal')
    axis.set_title(f"noise level {i}: {sigmas[i]:.2f}")
    plot_vector_field(fig, axis, functools.partial(model, noise_level_idx=i), n_quiver=15)

plt.tight_layout()
5-mnist-vector-fields.webp
Figure 3: The score function for all noise levels used to train the model. It seems like the sweetspot is around 18 (I choose 17 after inspecting all noise levels manually).

Let's see what we have learned. We define the update step (return tuple due to using lax.scan later)

@eqx.filter_jit
def update_x(x, z, model, step_size):
    g = model(x)
    xp1 = x + (step_size / 2) * g + jnp.sqrt(step_size) * z
    return xp1, xp1

and evolve a particle over many steps, lax.scan simply makes this efficient

step_size = 0.001
num_steps = 400_000
key = random.PRNGKey(0)
z_key, x0_key, key = random.split(key, 3)
z = random.normal(z_key, shape=(num_steps, 2))
x0 = jnp.ones(2,) * 0.5
score_model = functools.partial(model, noise_level_idx=17)
update_fun = functools.partial(update_x, model=score_model, step_size=step_size)
final, result = lax.scan(update_fun, x0, z)

Let's look at this. Since we sample so many particles, let's just plot a 2d histogram of this (choosing noise_level_idx being 17 but other indices in the vicinity should work too). Note that this has a finer resolution than the original mnist images which are \(27 \times 27\)

fig, ax = plt.subplots(figsize=(6, 6))
h = ax.hist2d(result[:, 0], result[:, 1], cmap="gray", bins=(50, 50))
ax.axis("off")
5-mnist-from-samples.webp

Conclusion

This was a great way to learn jax and how diffusion works. Looking back I think it may be overkill to do this on images as distributions as I did above, learning the distribution directly may be better in this case and faster in this case. I like the fact that the score model generalizes from the grid points \((i/27, j/27)_{i, j}^{27}\) to any tuple of points \((i, j)_{i, j \in [0, 1]}\) which is pretty cool and makes me wonder if you can use this to create a way to combine images of different resolutions as long as the aspect ratio is the same.

Reference

Song, Yang, and Stefano Ermon. 2020. “Generative Modeling by Estimating Gradients of the Data Distribution.” arXiv. http://arxiv.org/abs/1907.05600.
Tancik, Matthew, Pratul P. Srinivasan, Ben Mildenhall, Sara Fridovich-Keil, Nithin Raghavan, Utkarsh Singhal, Ravi Ramamoorthi, Jonathan T. Barron, and Ren Ng. 2020. “Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains.” arXiv. http://arxiv.org/abs/2006.10739.
1

Feels like there should be some way of looking at this through a regularization lense where \(\sigma\) takes the role as the regularization strength in traditional supervised learning such as Ridge Regression.

2

Of which I know very little.

3

Although when we convolve the inputs with Gaussians we will have that any point in \(\mathbb{R}^{2}\) will have positive probability, albeit maybe very small.