import math
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torch
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
= "models/8gaussian-moons"
savedir =True) os.makedirs(savedir, exist_ok
Conditional Flow Matching
This notebook is a self-contained example of conditional flow matching, taken from the GitHub repo https://github.com/atong01/conditional-flow-matching.
In this notebook, we show how to map from a source distribution \(q_0\) to a target distribution \(q_1\) using:
* Conditional Flow Matching (CFM), which
  * is equivalent to the basic (non-rectified) formulation of “Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow” (Liu et al. 2023);
  * is similar to “Stochastic Interpolants” (Albergo et al. 2023) with a non-variance-preserving interpolant;
  * is similar to “Flow Matching” (Lipman et al. 2023) but conditions on both source and target.
* Optimal Transport CFM (OT-CFM), which directly optimizes for dynamic optimal transport.
Note that the flow matching used here is different from the Flow Matching losses used in Generative Flow Networks: here we regress directly against continuous flows, rather than matching inflows and outflows.
# Implement some helper functions
def eight_normal_sample(n, dim, scale=1, var=1):
    # Sample n points from a mixture of eight Gaussians whose centers lie on
    # the unit circle (scaled by `scale`), with isotropic noise.
    m = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(dim), math.sqrt(var) * torch.eye(dim)
    )
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
        (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
        (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
        (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
    ]
    centers = torch.tensor(centers) * scale
    noise = m.sample((n,))
    multi = torch.multinomial(torch.ones(8), n, replacement=True)
    data = []
    for i in range(n):
        data.append(centers[multi[i]] + noise[i])
    data = torch.stack(data)
    return data
def sample_moons(n):
    x0, _ = generate_moons(n, noise=0.2)
    return x0 * 3 - 1


def sample_8gaussians(n):
    return eight_normal_sample(n, 2, scale=5, var=0.1).float()
class MLP(torch.nn.Module):
    def __init__(self, dim, out_dim=None, w=64, time_varying=False):
        super().__init__()
        self.time_varying = time_varying
        if out_dim is None:
            out_dim = dim
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim + (1 if time_varying else 0), w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, out_dim),
        )

    def forward(self, x):
        return self.net(x)
class GradModel(torch.nn.Module):
    def __init__(self, action):
        super().__init__()
        self.action = action

    def forward(self, x):
        x = x.requires_grad_(True)
        grad = torch.autograd.grad(torch.sum(self.action(x)), x, create_graph=True)[0]
        return grad[:, :-1]
class torch_wrapper(torch.nn.Module):
    """Wraps model to torchdyn compatible format."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, args=None):
        # torchdyn passes a scalar t; broadcast it to a per-sample column and
        # concatenate it with x so the time-varying MLP sees [x, t].
        return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))
def plot_trajectories(traj):
    # black: prior samples, olive: flow trajectories, blue: final samples
    n = 2000
    plt.figure(figsize=(6, 6))
    plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black")
    plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive")
    plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue")
    plt.legend(["Prior sample z(S)", "Flow", "z(0)"])
    plt.xticks([])
    plt.yticks([])
    plt.show()
Conditional Flow Matching
First we implement the basic conditional flow matching. As in the paper, we have \[ \begin{align} z &= (x_0, x_1) \\ q(z) &= q(x_0)q(x_1) \\ p_t(x | z) &= \mathcal{N}(x | t * x_1 + (1 - t) * x_0, \sigma^2) \\ u_t(x | z) &= x_1 - x_0 \end{align} \] When \(\sigma = 0\) this is equivalent to zero steps of rectified flow. We find that a small \(\sigma\) helps to regularize the problem, but your mileage may vary.
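Concretely, each training step draws a pair \((x_0, x_1)\), a time \(t\), and a noisy point on the straight-line path between them, then regresses the model's output on \(u_t(x | z) = x_1 - x_0\), the time derivative of the mean path. As a minimal sketch of that sampling step (the helper name sample_conditional_path is ours, not part of the repo; the training loop below inlines exactly these lines):

def sample_conditional_path(x0, x1, sigma):
    # Illustrative helper (name is ours): draw t ~ U(0, 1), a sample from
    # p_t(x | z) = N(t * x1 + (1 - t) * x0, sigma^2 I), and the conditional
    # target u_t(x | z) = x1 - x0.
    t = torch.rand(x0.shape[0], 1)
    mu_t = t * x1 + (1 - t) * x0
    x = mu_t + sigma * torch.randn_like(mu_t)
    ut = x1 - x0
    return t, x, ut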
%%time
sigma = 0.1
dim = 2
batch_size = 256
model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())

start = time.time()
for k in range(20000):
    optimizer.zero_grad()

    t = torch.rand(batch_size, 1)
    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)

    mu_t = t * x1 + (1 - t) * x0
    sigma_t = sigma
    x = mu_t + sigma_t * torch.randn(batch_size, dim)
    ut = x1 - x0
    vt = model(torch.cat([x, t], dim=-1))
    loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    optimizer.step()

    if (k + 1) % 5000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end

node = NeuralODE(
    torch_wrapper(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
)
with torch.no_grad():
    traj = node.trajectory(
        sample_8gaussians(1024),
        t_span=torch.linspace(0, 1, 100),
    )
plot_trajectories(traj)
torch.save(model, f"{savedir}/cfm_v1.pt")
5000: loss 8.896 time 14.94
10000: loss 8.825 time 15.87
15000: loss 8.178 time 17.13
20000: loss 8.456 time 19.18
CPU times: user 3min 35s, sys: 26.3 s, total: 4min 1s
Wall time: 1min 7s
Optimal Transport Conditional Flow Matching
Next we implement optimal transport conditional flow matching. As in the paper, here we have \[ \begin{align} z &= (x_0, x_1) \\ q(z) &= \pi(x_0, x_1) \\ p_t(x | z) &= \mathcal{N}(x | t * x_1 + (1 - t) * x_0, \sigma^2) \\ u_t(x | z) &= x_1 - x_0 \end{align} \] where \(\pi\) is the joint distribution given by an exact optimal transport plan. We first sample random \(x_0, x_1\), then resample according to the optimal transport matrix as computed with the Python Optimal Transport (POT) package. We use the 2-Wasserstein distance with an \(L^2\) ground distance for equivalence with dynamic optimal transport.
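As a sketch, the resampling step in the loop below can be read as a standalone helper (the name ot_resample is ours and mirrors the loop body, not a library API):

def ot_resample(x0, x1):
    # Illustrative helper (name is ours): compute an exact OT plan between the
    # two uniform empirical minibatch distributions under a (normalized)
    # squared-Euclidean cost, then resample pairs (x0[i], x1[j]) with
    # probability proportional to the plan's mass.
    a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
    M = torch.cdist(x0, x1) ** 2
    M = M / M.max()  # normalize the cost for numerical stability
    pi = pot.emd(a, b, M.detach().cpu().numpy())
    p = pi.flatten() / pi.sum()
    choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=x0.shape[0])
    i, j = np.divmod(choices, pi.shape[1])
    return x0[i], x1[j]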
%%time
sigma = 0.1
dim = 2
batch_size = 256
model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())

start = time.time()
a, b = pot.unif(batch_size), pot.unif(batch_size)
for k in range(20000):
    optimizer.zero_grad()

    t = torch.rand(batch_size, 1)
    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)

    # Resample x0, x1 according to transport matrix
    M = torch.cdist(x0, x1) ** 2
    M = M / M.max()
    pi = pot.emd(a, b, M.detach().cpu().numpy())
    # Sample random interpolations on pi
    p = pi.flatten()
    p = p / p.sum()
    choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=batch_size)
    i, j = np.divmod(choices, pi.shape[1])
    x0 = x0[i]
    x1 = x1[j]

    # Calculate regression loss
    mu_t = x0 * (1 - t) + x1 * t
    sigma_t = sigma
    x = mu_t + sigma_t * torch.randn(batch_size, dim)
    ut = x1 - x0
    vt = model(torch.cat([x, t], dim=-1))
    loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    optimizer.step()

    if (k + 1) % 5000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end

node = NeuralODE(
    torch_wrapper(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
)
with torch.no_grad():
    traj = node.trajectory(
        sample_8gaussians(1024),
        t_span=torch.linspace(0, 1, 100),
    )
plot_trajectories(traj)
torch.save(model, f"{savedir}/otcfm_v1.pt")
5000: loss 0.138 time 76.86
10000: loss 0.103 time 75.88
15000: loss 0.217 time 81.70
20000: loss 0.114 time 86.51
CPU times: user 17min 42s, sys: 1min 52s, total: 19min 34s
Wall time: 5min 21s
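The 2-Wasserstein distance with \(L^2\) ground cost mentioned above can also double as an evaluation metric. The snippet below is an illustrative addition (not in the original notebook; the names gen, ref, and w2_squared are ours) that compares the flow's endpoint samples against fresh target samples using POT's exact solver:

# Illustrative evaluation (not in the original notebook): squared 2-Wasserstein
# distance between the flow's endpoint samples and fresh target samples.
gen = traj[-1]                            # samples at t = 1 from the trajectory above
ref = sample_moons(gen.shape[0])          # fresh target samples
M = (torch.cdist(gen, ref) ** 2).numpy()  # squared-Euclidean ground cost
w2_squared = pot.emd2(pot.unif(len(gen)), pot.unif(len(ref)), M)
print(f"squared 2-Wasserstein to target: {w2_squared:0.3f}")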