Introduction to Generative AI: Vision Series (1/3) - VAE


Introduction


Generative AI is currently experiencing an unprecedented surge in popularity and everyday use. In this series, I’ll dissect some of the most popular algorithms used for image generation. These algorithms form the backbone of production systems like Stable Diffusion and DALL-E 3. In particular, the Variational Autoencoder, Generative Adversarial Networks, and diffusion models have become the de facto industry standard. This series will first develop the theory behind each algorithm, and then dive into a practical implementation of each in Python and PyTorch. First on our list is the Variational Autoencoder.


The Variational Autoencoder


Originally proposed in 2013, the variational autoencoder (VAE) has since become a ubiquitous tool, used in image generation, anomaly detection, data analysis, and compression. I’ll provide a rough overview of the original paper, going through some of the math behind the VAE while skipping past some of the more technical details.


Setup

Consider a dataset $X$ consisting of $N$ i.i.d. samples of some random variable $x$. A core assumption behind the VAE is that $X$ is generated by some random process involving another continuous random variable $z$. In other words, $X$ is generated using two distributions: a prior $p_{\theta}(z)$ and a conditional distribution $p_{\theta}(x \mid z)$. We do not know anything about the parameters $\theta$ of these distributions or the latent values $z$.

The main goal is to efficiently find Maximum-Likelihood (ML) or Maximum a Posteriori (MAP) estimates of the parameters $\theta$ (if you’re unfamiliar with Maximum-Likelihood estimation, I found this lecture from UC Berkeley to be super helpful). It would also be nice to be able to quickly infer the values of the latent variable $z$ for a given input $x$. This is useful if we’d like to encode $x$ in a smaller dimension or simply represent $x$ with fewer features. An added touch would be a way to quickly sample from the marginal distribution over $x$, which is useful in tasks that require generating new samples resembling $X$ or some specialized data manipulation.
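To make this setup concrete, here is a toy sketch of such a two-step generative process. It is purely illustrative; the dimensions, the standard normal prior, and the linear Gaussian observation model are all assumptions for this example, not the VAE we'll build.

import torch

# Toy latent-variable generative process: first sample z ~ p(z), then x ~ p(x | z).
latent_dim, data_dim = 2, 5
W = torch.randn(data_dim, latent_dim)      # hypothetical "decoder" weights

z = torch.randn(latent_dim)                # z ~ p(z) = N(0, I)
x = W @ z + 0.1 * torch.randn(data_dim)    # x ~ p(x | z) = N(Wz, 0.1^2 I)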


Approach

So what is the approach for tackling this trifecta?

To start, let’s define a new distribution $q_{\phi}(z \mid x)$ that will approximate the true posterior $p_{\theta}(z \mid x)$, which is intractable. The distribution $q_{\phi}(z \mid x)$ can be considered an encoder, since it produces distributions over the possible latent representations $z$ for a particular value of $x$. Similarly, we can refer to the distribution $p_{\theta}(x \mid z)$ as a decoder, since we can use it to generate values of $x$ from $z$.

We want $q_{\phi}(z \mid x)$ to match $p_{\theta}(z \mid x)$ as closely as possible. In other terms, we want the Kullback-Leibler divergence between $q_{\phi}(z \mid x)$ and $p_{\theta}(z \mid x)$ to be as small as possible. We can derive this term for a particular $x^{(i)}$. Using $\mathbb{E}_q$ to refer to $\mathbb{E}_{q_{\phi}(z \mid x^{(i)})}$, this can be written as
$$
\begin{align*}
D_{KL}\left(q_{\phi}(z \mid x^{(i)}) \,\|\, p_{\theta}(z \mid x^{(i)})\right) &= -\mathbb{E}_{q}\left[\log \frac{p_{\theta}(z \mid x^{(i)})}{q_{\phi}(z \mid x^{(i)})}\right]\\
&= -\mathbb{E}_{q}\left[\log \frac{p_{\theta}(x^{(i)} \mid z)\, p_{\theta}(z)}{q_{\phi}(z \mid x^{(i)})\, p_{\theta}(x^{(i)})}\right]\\
&= -\mathbb{E}_{q}\left[\log \frac{p_{\theta}(x^{(i)} \mid z)\, p_{\theta}(z)}{q_{\phi}(z \mid x^{(i)})}\right] + \mathbb{E}_{q}\left[\log p_{\theta}(x^{(i)})\right]\\
&= -\mathbb{E}_{q}\left[\log \frac{p_{\theta}(x^{(i)}, z)}{q_{\phi}(z \mid x^{(i)})}\right] + \log p_{\theta}(x^{(i)})
\end{align*}
$$
Since the Kullback-Leibler divergence is always non-negative, we can rearrange the above equation into the following expression
$$
\log p_{\theta}(x^{(i)}) \geq \mathbb{E}_{q}\left[\log \frac{p_{\theta}(x^{(i)}, z)}{q_{\phi}(z \mid x^{(i)})}\right]
$$
The term on the right is called the variational lower bound or sometimes evidence lower bound (ELBO) simply because it is a lower bound on the log-likelihood of our evidence. We could also derive a different expression for the lower bound if we keep going.
$$
\begin{align*}
\log p_{\theta}(x^{(i)}) &\geq \mathbb{E}_{q}\left[\log \frac{p_{\theta}(x^{(i)}, z)}{q_{\phi}(z \mid x^{(i)})}\right]\\
&= \mathbb{E}_{q}\left[\log \frac{p_{\theta}(x^{(i)} \mid z)\, p_{\theta}(z)}{q_{\phi}(z \mid x^{(i)})}\right]\\
&= \mathbb{E}_{q}\left[\log \frac{p_{\theta}(z)}{q_{\phi}(z \mid x^{(i)})}\right] + \mathbb{E}_{q}\left[\log p_{\theta}(x^{(i)} \mid z)\right]\\
&= -D_{KL}\left(q_{\phi}(z \mid x^{(i)}) \,\|\, p_{\theta}(z)\right) + \mathbb{E}_{q}\left[\log p_{\theta}(x^{(i)} \mid z)\right]
\end{align*}
$$
Our goal is to maximize the log-likelihood of our evidence, which from the above we can see corresponds to maximizing the ELBO.


The reparameterization trick

Everything above is fine, but there’s still a big problem: we can’t actually optimize the objective we’ve constructed as-is. There’s a lot of theoretical work studying estimators like this, but the short version for our purposes is that the naive Monte Carlo gradient estimator for this term has too much variance to be useful as a gradient signal for our optimization. Furthermore, because of the sampling step, our current objective isn’t actually differentiable with respect to $\phi$.

Here’s where the magic happens. We can use the fact that we get to choose $q_{\phi}(z \mid x)$ to approximate our posterior: we can pick a $q$ whose samples can be represented as a differentiable transformation $g$ of some auxiliary noise variable $\epsilon$. In other words,
$$
\bar{z} \sim q_{\phi}(z \mid x), \qquad \bar{z} = g_{\phi}(\epsilon, x), \qquad \epsilon \sim p(\epsilon)
$$
Why is this useful? Well, for a smart choice of $q$, $g$, and $p(\epsilon)$, we can effectively reduce the variance of the gradient estimate of the ELBO to the point that it’s tractable for our purposes. We can also now differentiate with respect to $\phi$ to obtain a gradient for our parameter updates.

There are many choices of $q$, $g$, and $p(\epsilon)$ that work here, but the one most commonly seen is representing the approximate posterior as $\mathcal{N}(\mu, \sigma^2)$, where we can then use the properties of the Gaussian distribution to rewrite a sample $z \sim \mathcal{N}(\mu, \sigma^2)$ as $z = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$. This is known as the reparameterization trick. We’re essentially moving the “non-differentiable” part into another independent distribution $p(\epsilon)$ so that we can compute proper gradients for our model parameters $\phi$.
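As a quick sanity check that this actually restores differentiability, here's a tiny PyTorch sketch with made-up numbers. It shows that a sample drawn as mu + sigma * eps carries gradients back to mu and the log-variance, which is exactly what we need for training.

import torch

# Reparameterized sample: the randomness lives entirely in eps, so gradients
# flow through mu and log_var (illustrative values, not a trained model).
mu = torch.tensor([0.5], requires_grad=True)
log_var = torch.tensor([-1.0], requires_grad=True)

eps = torch.randn(1)                     # eps ~ N(0, 1), no gradient needed
z = mu + torch.exp(0.5 * log_var) * eps  # z = mu + sigma * eps

z.sum().backward()
print(mu.grad, log_var.grad)             # both gradients are populated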

Recall that we derived two different ways to express the ELBO; both require estimation by sampling. However, since the second expression contains the KL divergence term, which can often be computed analytically, it will typically exhibit less variance. Writing that second form as a function of the parameters, we have
$$
\mathcal{L}(\theta, \phi; x^{(i)}) = -D_{KL}\left(q_{\phi}(z \mid x^{(i)}) \,\|\, p_{\theta}(z)\right) + \mathbb{E}_{q}\left[\log p_{\theta}(x^{(i)} \mid z)\right]
$$
Using the trick, we can now write a Monte Carlo estimator for this expression over a set of $L$ samples, resulting in
$$
\widetilde{\mathcal{L}}(\theta, \phi; x^{(i)}) = -D_{KL}\left(q_{\phi}(z \mid x^{(i)}) \,\|\, p_{\theta}(z)\right) + \frac{1}{L}\sum_{l=1}^{L}\log p_{\theta}(x^{(i)} \mid z^{(i, l)})
$$
where $z^{(i, l)} = g_{\phi}(\epsilon^{(i, l)}, x^{(i)})$ and $\epsilon^{(i, l)} \sim p(\epsilon)$.

We can identify two main components here. The first is the KL loss, which encourages the approximate posterior $q_{\phi}(z \mid x^{(i)})$ to stay close to the prior $p_{\theta}(z)$. The other term is often referred to as the reconstruction loss, as it ensures the model can faithfully reconstruct the original input from the latent space. This final expression is tractable and differentiable with respect to $\phi$, exactly as we wanted.


Bringing it all together

We now have a differentiable and tractable objective which we can use to optimize the parameters $\phi$ and $\theta$ of our models, meaning we have all the tools we need to construct a variational autoencoder. Before we construct our objective in PyTorch, we first need an analytical version of the KL divergence term. I’ll skip the derivation here (this blog post does a good job of explaining it), but if you assume that both distributions are multivariate Gaussians, i.e. the encoder $q_{\phi}(z \mid x^{(i)}) = \mathcal{N}(z; \mu^{(i)}, \sigma^{2(i)}\mathbb{I})$ and the prior $p_{\theta}(z) = \mathcal{N}(z; \mathbf{0}, \mathbb{I})$, you can rewrite
$$
-D_{KL}\left(q_{\phi}(z \mid x^{(i)}) \,\|\, p_{\theta}(z)\right) = \frac{1}{2} \sum_{j=1}^{J} \left(1 + \log \left((\sigma_j^{(i)})^2\right) - (\mu_j^{(i)})^2 - (\sigma_j^{(i)})^2\right)
$$

In practice, we usually have the encoder network output vectors for $\mu$ and $\log \sigma^2$ instead of $\sigma$ directly, because it's more numerically stable. Here, my model outputs mu, logvar, and a batch of reconstructions batch_hat.

		
batch_hat, mu, logvar = vae(batch)
	

Now, with logvar $= \log \sigma^2$ and mu $= \mu$, you can write the KL divergence as

		
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
	

The extra negative sign is there because we have to minimize the negative of our objective in order to maximize our original objective.

Now onto the reconstruction loss. The authors of the original paper used $L=1$, simplifying the term to a single evaluation of $\log p_{\theta}(x^{(i)} \mid z^{(i, l)})$. I’ll skip the derivation here again, but if we choose to assume that $p_{\theta}(x \mid z)$ follows a Bernoulli distribution, this term becomes the negative binary cross entropy, meaning we can minimize the standard PyTorch BCE loss to achieve our original objective.

		
import torch.nn.functional as F
...
reproduction_loss = F.binary_cross_entropy(batch_hat, batch)
	

Our final loss is then simply

		
loss = reproduction_loss + kl_loss
	
Encoder and Decoder

Let’s now define our encoder and decoder models as deep neural networks in PyTorch. Based on our previous derivations, our first requirement is that the encoder needs to output the two vectors $\mu$ and $\log \sigma^2$. For our purposes, let’s assume that the input to the encoder is a 128 by 128 by 3 image.

For images, it’s often best practice to use convolutional layers. Let’s define a helper function that’ll implement a sequential module for a convolutional block.

		
def conv_block(self, input, output, kernel=4, stride=2, pad=1):
	return nn.Sequential(
		nn.Conv2d(input, output, kernel, stride, pad, bias=False),
		nn.BatchNorm2d(output),
		nn.LeakyReLU(),
	)
	

I use BatchNorm for normalization and LeakyReLU for the activation function after each Conv2d layer. Now, we can define our encoder model.

		
nf = 64
dim_mults = (1, 2, 4, 8, 16)
hidden_dims = [nf * mult for mult in dim_mults]

self.encoder = nn.Sequential(
	*[
		self.conv_block(in_f, out_f)
		for in_f, out_f in zip([channel_size] + hidden_dims[:-1], hidden_dims)
	]
)
	

Basically, we’re computing the size of each hidden layer in hidden_dims based on the multipliers we specified in dim_mults and our starting filter size of 64. We then string together all the sequential blocks created by our helper function from the resulting list of input and output dimensions. The shape arithmetic is worked out below.
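To make the shapes concrete, here's the arithmetic for the 128 by 128 by 3 input we assumed earlier. This is just a quick sanity check, not part of the model code.

# With nf = 64 and dim_mults = (1, 2, 4, 8, 16), the channel widths are:
hidden_dims = [64 * mult for mult in (1, 2, 4, 8, 16)]
print(hidden_dims)  # [64, 128, 256, 512, 1024]

# Each conv_block uses stride 2, so the 128x128 input is halved five times:
# 128 -> 64 -> 32 -> 16 -> 8 -> 4
# leaving a (batch, 1024, 4, 4) feature map, which is why the linear heads
# below use in_features = hidden_dims[-1] * 4 * 4.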

Finally, we need two Linear layers to take the latent embedding from the encoder and convert it into $\mu$ and $\log \sigma^2$, which are vectors of size latent_size.

		
self.d_max = hidden_dims[-1]
self.mu = nn.Linear(self.d_max * 4 * 4, latent_size)
self.logvar = nn.Linear(self.d_max * 4 * 4, latent_size)
	

The decoder model is essentially symmetric to the encoder, apart from the input: where the encoder outputs two vectors, the decoder takes in just one. We need a layer to transform that input back into a feature map.

		
self.embed = nn.Linear(latent_size, self.d_max * 4 * 4)
	

There is some research that looks into asymmetric VAEs, where the encoder and decoder model are different sizes or even different types of models, but in most cases the symmetric VAE is satisfactory. Our decoder is essentially the mirror image of the encoder.

		
self.decoder = nn.Sequential(
	*[
		self.conv_transpose_block(in_f, out_f)
		for in_f, out_f in zip(reversed(hidden_dims[1:]), reversed(hidden_dims[:-1]))
	],
	nn.ConvTranspose2d(nf, 3, 4, 2, 1, bias=False),
	nn.Sigmoid(),
)
	

At the end, we have one final ConvTranspose2d to upsample back to 3 output channels, followed by a Sigmoid so the final output is scaled between 0 and 1. For binary cross entropy loss to work, we have to make sure both our inputs and our outputs are scaled between 0 and 1.
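One piece isn't shown above: the conv_transpose_block helper that the decoder uses. My best guess at a minimal version, mirroring conv_block (the actual code in the repo may differ), is:

def conv_transpose_block(self, input, output, kernel=4, stride=2, pad=1):
	# mirror of conv_block: each block doubles the spatial resolution
	return nn.Sequential(
		nn.ConvTranspose2d(input, output, kernel, stride, pad, bias=False),
		nn.BatchNorm2d(output),
		nn.LeakyReLU(),
	)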

There’s one missing step so far: how do we take a sample from the encoder to pass to the decoder? We need to implement the reparameterization trick we discussed.

		
def reparameterize(self, mu, logvar):
	std = torch.exp(0.5 * logvar)
	eps = torch.randn_like(std, device=device)
	return mu + std * eps
	

This function essentially implements $g_\phi(\epsilon, x^{(i)})$ from earlier, with $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbb{I})$.

Finally, we can put this all together.

		
def encode(self, input):
	embed = self.encoder(input)
	embed = torch.flatten(embed, start_dim=1)
	mu, logvar = self.mu(embed), self.logvar(embed)
	sample = self.reparameterize(mu, logvar)
	return sample, mu, logvar

def decode(self, input):
	# squeeze() handles latents shaped (batch, latent_size, 1, 1), like fixed_latent below
	embed = self.embed(input.squeeze())
	embed = embed.view(-1, self.d_max, 4, 4)
	return self.decoder(embed)

def forward(self, input):
	sample, mu, logvar = self.encode(input)
	out = self.decode(sample)
	return out, mu, logvar
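For reference, here's one way the fragments above might fit together into a single nn.Module. The constructor arguments (channel_size, latent_size, nf) and their defaults are my assumptions; the actual class in the repo may be organized differently.

import torch
from torch import nn

class VariationalAutoEncoder(nn.Module):
	def __init__(self, channel_size=3, latent_size=512, nf=64):
		super().__init__()
		dim_mults = (1, 2, 4, 8, 16)
		hidden_dims = [nf * mult for mult in dim_mults]
		self.d_max = hidden_dims[-1]

		# encoder: five downsampling conv blocks, then linear heads for mu and logvar
		self.encoder = nn.Sequential(
			*[
				self.conv_block(in_f, out_f)
				for in_f, out_f in zip([channel_size] + hidden_dims[:-1], hidden_dims)
			]
		)
		self.mu = nn.Linear(self.d_max * 4 * 4, latent_size)
		self.logvar = nn.Linear(self.d_max * 4 * 4, latent_size)

		# decoder: linear embedding back to a 4x4 feature map, then upsampling blocks
		self.embed = nn.Linear(latent_size, self.d_max * 4 * 4)
		self.decoder = nn.Sequential(
			*[
				self.conv_transpose_block(in_f, out_f)
				for in_f, out_f in zip(reversed(hidden_dims[1:]), reversed(hidden_dims[:-1]))
			],
			nn.ConvTranspose2d(nf, channel_size, 4, 2, 1, bias=False),
			nn.Sigmoid(),
		)

	# conv_block, conv_transpose_block, reparameterize, encode, decode, and forward
	# are the methods defined in the snippets above.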
	

Generating Art


Now that we have the base network complete as well as our loss function written in PyTorch, we can put together a quick training loop and dataset to (hopefully) generate some new cool images!

Most posts about VAEs online use the MNIST dataset; for this series, I’ll use some more challenging datasets to better represent the difficulties of training VAEs in the wild. All the experiments below use images from the Japanese Woodblock Print database.

ukiyo-e sample

Let’s try and generate images like the one above!

I’ll skip over defining the dataset and dataloader since they’re usually pretty simple, and instead focus on the training loop for the VAE. If you’d like to see that code, you can check out the full repo here; a rough sketch is also included below.
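For a rough idea of what the data loading might look like, here's a minimal sketch. It assumes the images are arranged in folders compatible with torchvision's ImageFolder, and the path data/ukiyo-e is hypothetical; the actual repo may load the data differently.

import torch
from torchvision import datasets, transforms

# Resize everything to the 128x128 resolution the encoder expects; ToTensor
# scales pixel values to [0, 1], which the BCE reconstruction loss requires.
transform = transforms.Compose([
	transforms.Resize((128, 128)),
	transforms.ToTensor(),
])

dataset = datasets.ImageFolder("data/ukiyo-e", transform=transform)  # hypothetical path
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

ImageFolder returns (image, label) pairs, which is why the training loop below unpacks each batch as batch, _ = batch.

Assuming we’ve packaged the code above into a VariationalAutoEncoder module, we can initialize our model and optimizer.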

		
vae = VariationalAutoEncoder()
vae.to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-6, betas=(0.5, 0.999))
	

When training VAEs, I find it helpful to define a fixed latent batch at the beginning of training, to use as a comparison throughout training and see how the VAE improves at generating from the same latents $z$. I’ll also add an array to track loss statistics throughout training.

		
fixed_latent = torch.randn(64, latent_size, 1, 1, device=device)
losses = []
	

Now we can define our actual training loop. We’ll iterate through our dataloader in batches for each epoch during training, performing gradient updates on the parameters $\phi$ and $\theta$ of our VAE along the way.

		
for epoch in range(n):  # n = number of training epochs
	for i, batch in enumerate(dataloader, 0):
		batch, _ = batch
		batch = batch.to(device)

		vae.zero_grad()
		batch_hat, mu, logvar = vae(batch)
		reproduction_loss = F.binary_cross_entropy(batch_hat, batch)
		kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
		loss = reproduction_loss + kl_loss
		loss.backward()
		optimizer.step()

		losses.append((loss.item(), reproduction_loss.item(), kl_loss.item()))
	

I usually like to track the progress of the VAE at fixed intervals by reconstructing a held-out batch of images and generating a batch from the fixed latent we defined earlier. You can compare these batches across checkpoints to get a sense of how the VAE improved.

		
if (i % 1000 == 0):
	vae.eval()
	# sample_batch: a fixed batch of real images set aside earlier to track reconstructions
	sample_batch_hat, _, _ = vae(sample_batch)
	# generate from the fixed latents so checkpoints are directly comparable
	fake = vae.decode(fixed_latent).detach().cpu()
	vae.train()
	

And we’re ready to train! In my experiments, I ran the training for 25 epochs, using a learning rate of 1e-6, a batch size of 32, and a latent size of 512. Let’s see how it did. For each image below, I’ve paired a generated image from our VAE with an image from the dataset that I thought looked similar.

example1
Mitate Murasaki Shikibu
example2
toyokuni-iii-ejiri-and-fuchu
example1
waseda

So, the generated images aren’t exactly what we were going for; they’re more like simplified or abstract versions of the original artwork. We can see similarities between some of the dataset images and the generated art above, but the VAE doesn’t seem to capture the finer features in the generated works.

Let’s compare how the VAE reconstructed images at different iterations to see if the encoder-decoder aspect of the full pipeline works as intended.

comp-reconstruction

As you can see, the VAE does pretty well on reconstruction but isn’t as good at generating new images from scratch. This is a problem that could be approached with some hyper-parameter tuning or other additions that I talk about below, but it’s also possible that this dataset is too complicated for the VAE to model.

In the same vein, let’s take a look at how the VAE improved on the same batch of latent inputs (fixed_latent) throughout training.

fixed latent animation

As you can see, the VAE improves pretty quickly at the beginning and focuses on finer details towards the end. You can see a reduction in overall noise and corresponding improvement in image quality as we approach the end.

For completeness, let’s also visualize a few interpolations of the generated images over the latent space $z$. We’ll generate two random latent vectors and interpolate between them, generating an image at each step. The code for interpolating between two points is pretty standard, but I’ve also included some plotting code below in case it’s useful.

		
import numpy as np
import matplotlib.pyplot as plt

def interpolate_points(p1, p2, n_steps=8):
	# linearly interpolate between p1 and p2 (linspace includes the endpoints)
	ratios = np.linspace(0, 1, num=n_steps)
	vectors = []
	for ratio in ratios:
		v = (1.0 - ratio) * p1 + ratio * p2
		vectors.append(v)
	return vectors

def plot_generated(examples):
	for i in range(len(examples)):
		plt.subplot(1, len(examples), 1 + i)
		plt.axis('off')
		plt.imshow(examples[i])
	plt.subplots_adjust(wspace=0, hspace=0)
	plt.savefig("interp.png", bbox_inches='tight')

# pick two random points in the latent space to interpolate between
z0 = torch.randn((512, 1, 1), device=device)
z1 = torch.randn((512, 1, 1), device=device)

# get the interpolated points (the endpoints are repeated at the edges,
# since linspace already includes ratios 0 and 1)
zs = interpolate_points(z0, z1)
zs = [z0] + zs + [z1]

# run the points through the decoder
image_interps = [vae.decode(z).squeeze() for z in zs]
image_interps = [image.permute(1, 2, 0).detach().cpu().numpy() for image in image_interps]

# plot the generated images
plot_generated(image_interps)
	

And we have our results below:

interpolation1

interpolation 2


Tips and Tricks


Training a VAE can be difficult, as we have seen. Fortunately, there are some small, easy-to-implement tricks we can add to our implementation that can significantly improve the overall results.

  1. One quick trick to try is switching out the BCE reconstruction loss for an MSE loss. I’ll skip the derivation here again, but if we choose to assume that $p_{\theta}(x \mid z)$ follows a normal distribution, we can actually minimize the mean-squared error (MSE) instead. If your data distribution more closely matches a Gaussian than a Bernoulli like earlier, this can help your VAE optimize a bit better. The swap is a one-liner, shown below.
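    Assuming the same batch_hat and batch as in the training loop above, the swap looks like this:

    reproduction_loss = F.mse_loss(batch_hat, batch)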

  2. One problem that often occurs when training VAEs is that the model can over-prioritize matching the learned posterior to the prior, which can result in little to no variety in the generated images. There are a few ways to counteract this, but one popular approach is the $\beta$-VAE. In a $\beta$-VAE, the KL divergence term in the loss function is weighted by a parameter $\beta$, which can be tuned to focus the VAE on either the reconstruction loss or on matching the posterior to the prior. You can think of this as weighting which part of the problem you want the VAE to focus on. In PyTorch, this is as simple as changing our loss to

    				
    loss = reproduction_loss + beta * kl_loss
    			

    In practice, I’ve found lower values of $\beta$ (such as 0.0001) to work better, since the KL divergence term tends to be much larger than the reconstruction term. Different datasets will require different values of $\beta$, so this approach requires some tuning.

  3. Another related approach is called annealing. In line with the $\beta$-VAE approach, the idea behind annealing is to introduce a schedule for the value of $\beta$ over the course of training. In practice, it usually works well to first train the VAE to focus on the reconstruction loss (i.e. $\beta = 0$), and then slowly increase $\beta$ throughout training until the VAE weights both terms equally (i.e. $\beta = 1$). In PyTorch, this is again pretty simple:

    				
    beta_schedule = torch.linspace(0, beta, len(dataloader))  # ramps beta up over one epoch
    loss = reproduction_loss + beta_schedule[i] * kl_loss
    			

Even with all the tricks above, it’s possible that the VAE is simply not good enough at image generation for your dataset. One thing I’ve learned from building models like this is to not get too attached to the approach itself; sometimes a given model just isn’t capable of the results you’re looking for. VAEs have largely fallen out of the state of the art for complex image generation, although they are still used in combination with other algorithms in some cases, as in Stable Diffusion.

The main reason is that VAEs tend to struggle to produce fine features, since they try to represent the entire data distribution at once, which leads to blurry and unrefined images, as we have seen. Even though they work well for reconstruction and compression, VAEs may simply not be enough for robust image generation on complex datasets.

Some research shows that scaling VAEs up or designing smarter latent representations can improve performance and combat issues like posterior collapse, which I may cover in future posts. For the next part of this series, though, I’ll cover a different image generation algorithm: the generative adversarial network, or GAN.

If you’d like to check out the full code for this series, you can visit my public repository.