Exploring Variational Autoencoders (VAEs) for Image Compression

A full-code tutorial using Comet ML

Boluwatife Victor O.
Heartbeat

Black and white grid of MNIST handwritten numbers with a distorted lens
photo credit: Tensorflow.org

Introduction

In the era of big data, image compression has become essential for reducing storage and transmission costs while preserving acceptable image quality. Traditional formats exploit redundancies in pixel values: JPEG uses lossy compression, while PNG and GIF are lossless. Although these techniques are effective, they are limited in how much they can compress an image without a visible loss of quality.

Deep learning-based compression techniques have emerged as a viable alternative to traditional methods. One of these is the Variational Autoencoder (VAE), a generative model that learns a lower-dimensional latent representation of an image. This guide explores how VAEs can be used for image compression, explains how they work, and shows how to implement a VAE in PyTorch. We will also use Comet ML to track and log the model's performance.

What are Variational Autoencoders (VAEs)?

Generally, autoencoders are a class of neural networks that can learn to compress and decompress data using an iterative optimization process. They consist of two parts: an encoder and a decoder. The encoder maps the input data to a lower-dimensional representation, while the decoder maps the lower-dimensional representation back to the original data. Autoencoders can be trained in an unsupervised manner, meaning they don’t require labeled data.
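To make the encoder/decoder split concrete, here is a minimal sketch of a plain (deterministic) autoencoder in PyTorch. The class name `TinyAutoencoder` and the layer sizes (784 → 128 → 32 for flattened 28x28 images) are illustrative assumptions, not the model used later in this tutorial.

```python
import torch
import torch.nn as nn


class TinyAutoencoder(nn.Module):
    """Hypothetical deterministic autoencoder for flattened 28x28 images."""

    def __init__(self, input_dim=784, code_dim=32):
        super().__init__()
        # Encoder: compress each input to a single point in a code_dim-dimensional space
        self.encoder = nn.Sequential(nn.Linear(input_dim, 128), nn.ReLU(), nn.Linear(128, code_dim))
        # Decoder: map the code back to the input space
        self.decoder = nn.Sequential(nn.Linear(code_dim, 128), nn.ReLU(), nn.Linear(128, input_dim))

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code), code


model = TinyAutoencoder()
x = torch.rand(4, 784)                   # a batch of 4 fake flattened images
recon, code = model(x)
loss = nn.functional.mse_loss(recon, x)  # unsupervised: the input is its own target
loss.backward()
print(code.shape, recon.shape)           # torch.Size([4, 32]) torch.Size([4, 784])
```

Training repeats this forward/backward pass with an optimizer; no labels are involved at any point.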

A Variational Autoencoder (VAE) is a type of autoencoder that uses probabilistic methods to learn a compressed representation of the input data. Like any autoencoder, a VAE uses an encoder to transform high-dimensional input data into a lower-dimensional representation and a decoder to reconstruct the original data from that representation.

In the article Understanding Variational Autoencoders (VAEs), Joseph Rocca defines a Variational Autoencoder (VAE) as:

“an autoencoder whose training is regularized to avoid overfitting and ensure that the latent space has good properties that enable generative process.”

Unlike traditional autoencoders, which encode an input as a single point, VAEs encode each input as a probability distribution over the latent space. For each input, the encoder outputs the mean and variance of a Gaussian distribution; a latent vector is sampled from that Gaussian, and the decoder maps the sample back to data space.

Rocca describes the training procedure over the latent space as follows:

  1. The input is encoded as a distribution over a latent space.
  2. A point from the latent space is sampled from that distribution.
  3. The sampled point is decoded, and the reconstruction error can be computed.
  4. Finally, the reconstruction error is backpropagated through the network.
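The four steps above map directly onto code. The sketch below runs one training iteration on random data with a deliberately tiny linear encoder and decoder (hypothetical stand-ins, not the convolutional model defined later). Sampling uses the reparameterization trick, z = mu + sigma * eps, so gradients flow back to the encoder through the sampled point. As in the loss used later in this tutorial, a KL regularization term is added to the reconstruction error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim = 2
encoder = nn.Linear(784, 2 * latent_dim)  # hypothetical toy encoder: outputs [mu, logvar]
decoder = nn.Linear(latent_dim, 784)      # hypothetical toy decoder: outputs pixel logits
optimizer = torch.optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

x = torch.rand(8, 784)  # 8 fake flattened images with pixels in [0, 1]

# 1. Encode the input as a distribution: a diagonal Gaussian N(mu, sigma^2)
mu, logvar = encoder(x).chunk(2, dim=1)

# 2. Sample a latent point with the reparameterization trick
eps = torch.randn_like(mu)
z = mu + torch.exp(0.5 * logvar) * eps

# 3. Decode the sample and compute the reconstruction error (plus the KL regularizer)
recon_logits = decoder(z)
recon_error = nn.functional.binary_cross_entropy_with_logits(recon_logits, x, reduction='sum')
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_error + kl

# 4. Backpropagate through the decoder, the sample, and the encoder, then update the weights
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because eps is drawn independently of the parameters, the randomness sits outside the gradient path; sampling z directly from N(mu, sigma^2) would block gradients to the encoder.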
Flow chart demonstrating the difference between a deterministic autoencoder and a probabilistic variational autoencoder
source: Understanding Variational Autoencoders (VAEs) by Joseph Rocca

Because a VAE learns a probability distribution over the latent representation rather than a single point per input, you can generate new data by sampling latent vectors and decoding them. This has made VAEs useful for tasks such as image generation and enhancing low-resolution images.

Implementing VAE for Image Compression

In image compression, we can use a VAE to learn a compressed representation of an image. The encoder maps the image to a lower-dimensional latent vector, and the decoder maps that vector back to an approximation of the original image. Only the latent vector needs to be stored or transmitted; the receiver runs the decoder to reconstruct the image.
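A quick back-of-the-envelope calculation shows what "compressed" means here. The sketch below compares the raw size of one 28x28 8-bit grayscale image (784 bytes) with the size of its latent vector, assuming the latent values are stored as 32-bit or 16-bit floats with no further quantization or entropy coding. The function name is hypothetical.

```python
def compression_ratio(latent_size, bytes_per_value, height=28, width=28):
    """Raw 8-bit image size divided by the size of a latent vector of floats."""
    raw_bytes = height * width  # one byte per grayscale pixel
    latent_bytes = latent_size * bytes_per_value
    return raw_bytes / latent_bytes


for latent_size in (128, 20, 2):
    print(latent_size,
          round(compression_ratio(latent_size, 4), 2),  # float32 latents
          round(compression_ratio(latent_size, 2), 2))  # float16 latents
```

For a 20-dimensional latent space, a float32 code takes 80 bytes, about 9.8 times smaller than the raw image; halving the latent size or the bytes per value doubles the ratio, at the cost of reconstruction quality.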

Let’s explore how to implement VAEs for image compression using PyTorch and the MNIST dataset.

Requirements

For this tutorial, you need the following:

  1. Basic knowledge of Python and deep learning.
  2. PyTorch (with torchvision) and Comet ML: we will use PyTorch to implement the VAE and Comet to track experiments. Please see the Comet docs to understand how to integrate Comet with PyTorch. We also use Matplotlib to visualize results.
  3. The MNIST dataset: 70,000 small 28x28 grayscale images of handwritten digits from 0 to 9, split into 60,000 training images and 10,000 test images. torchvision downloads it automatically in the code below.
  4. A text editor: Visual Studio Code or Sublime Text will do.

Implementation

Install and import libraries

We will start by setting up Comet ML, a platform for tracking, comparing, and reproducing machine learning experiments, visualizing results, and collaborating. We will use Comet to log the VAE's training loss over several epochs and its loss on the held-out MNIST test set.

First, create an account on Comet ML and install the required packages (the leading ! runs the command from a Jupyter notebook; drop it in a terminal):

!pip install comet_ml torch torchvision matplotlib

Import the required libraries for the project:

from comet_ml import Experiment  # Comet recommends importing comet_ml before torch
import matplotlib.pyplot as plt   # used later to visualize reconstructions
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

Next, create a Comet Experiment:

experiment = Experiment(api_key="your_api_key", project_name="vae-project")

Replace your_api_key with your Comet ML API key (you can find it on the settings page of your Comet ML account), and replace vae-project with the name of your project. Alternatively, set the COMET_API_KEY environment variable and omit the api_key argument.

Load and Preprocess Data

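We will load the MNIST dataset and scale the pixel values to the range [-1, 1]: ToTensor() converts each image to a tensor with values in [0, 1], and Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5. The data is downloaded to ./data on first use.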

transform = transforms.Compose([
    transforms.ToTensor(),                 # pixels -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1]
])
train_data = MNIST(root='./data', train=True, transform=transform, download=True)
test_data = MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)
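The Normalize step computes (x - mean) / std per channel. This small plain-Python sketch (with hypothetical helper names) shows how mean = std = 0.5 maps the [0, 1] output of ToTensor to [-1, 1], and how to undo the mapping, which is handy when you want to display or save images in their original range.

```python
MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize above


def normalize(pixel):
    """Map a ToTensor pixel value in [0, 1] to [-1, 1]."""
    return (pixel - MEAN) / STD


def denormalize(value):
    """Invert normalize(): map [-1, 1] back to [0, 1]."""
    return value * STD + MEAN


print([normalize(p) for p in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]
```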

Define the Model

Let's define the Encoder and Decoder as separate modules and combine them to form the complete VAE model. The Encoder takes the input image and outputs the mean and log-variance of a Gaussian distribution over the latent space. The Decoder takes a sample from that distribution and generates the reconstructed image.

The forward method of the VAE model passes an input image through the Encoder to obtain the mean and log-variance of the distribution. It then samples a latent vector with the reparameterization trick, z = mu + sigma * eps with eps drawn from a standard normal distribution, which keeps the sampling step differentiable with respect to mu and sigma. Finally, it passes the sample through the Decoder to obtain the reconstructed image. We will define the VAE model using PyTorch as follows:

class Encoder(nn.Module):
    def __init__(self, latent_size=20):
        super().__init__()
        # Two stride-2 convolutions halve the resolution twice: 28x28 -> 14x14 -> 7x7
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(7 * 7 * 64, 256)
        self.fc21 = nn.Linear(256, latent_size)  # mean of the latent Gaussian
        self.fc22 = nn.Linear(256, latent_size)  # log-variance of the latent Gaussian

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = x.flatten(start_dim=1)  # (batch, 7 * 7 * 64)
        x = nn.functional.relu(self.fc1(x))
        mu = self.fc21(x)
        logvar = self.fc22(x)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_size=20):
        super().__init__()
        self.fc1 = nn.Linear(latent_size, 256)
        self.fc2 = nn.Linear(256, 7 * 7 * 64)
        # Stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28
        self.conv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv2 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, z):
        z = nn.functional.relu(self.fc1(z))
        z = nn.functional.relu(self.fc2(z))
        z = z.view(-1, 64, 7, 7)
        z = nn.functional.relu(self.conv1(z))
        return torch.tanh(self.conv2(z))  # outputs in [-1, 1], like the normalized inputs

class VAE(nn.Module):
    def __init__(self, latent_size=20):
        super().__init__()
        self.encoder = Encoder(latent_size)
        self.decoder = Decoder(latent_size)

    def forward(self, x):
        mu, logvar = self.encoder(x)
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        recon_x = self.decoder(z)
        return recon_x, mu, logvar
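The flattening size 7 * 7 * 64 in the encoder only works if the convolutions shrink each 28x28 image to 7x7: a 3x3 convolution with stride 1 and padding 1 keeps the size at 28x28, so two downsampling steps of stride 2 are needed. The sketch below checks this with the output-size formulas from the PyTorch Conv2d and ConvTranspose2d documentation (assuming dilation 1); the helper names are hypothetical.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel=3, stride=1, padding=1, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Stride 1, padding 1: a 3x3 convolution preserves the size
print(conv_out(conv_out(28)))  # 28

# Stride 2 halves the size twice: 28 -> 14 -> 7, matching 7 * 7 * 64
print(conv_out(conv_out(28, stride=2), stride=2))  # 7

# A stride-1 transposed convolution also preserves the size
print(conv_transpose_out(7))  # 7

# Stride 2 with output_padding=1 grows 7 -> 14 -> 28
print(conv_transpose_out(conv_transpose_out(7, stride=2, output_padding=1),
                         stride=2, output_padding=1))  # 28
```

The output_padding=1 term resolves the ambiguity of stride-2 upsampling, since both 13 and 14 map to 7 under the forward convolution formula.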

Define the loss function

The loss function consists of two terms: the binary cross-entropy (BCE) loss, which measures the difference between the original image and the reconstructed image, and the Kullback-Leibler (KL) divergence, which measures how far the encoded distribution is from the standard normal distribution. BCE expects predictions and targets in [0, 1], but our inputs are normalized to [-1, 1] and the decoder ends in tanh, so the loss first maps both to [0, 1] with (v + 1) / 2.

Let's define the loss function below:

def vae_loss(recon_x, x, mu, logvar):
    # BCE needs predictions and targets in [0, 1]: map both from [-1, 1] to [0, 1]
    recon_01 = (recon_x.reshape(-1, 784) + 1) / 2
    x_01 = (x.reshape(-1, 784) + 1) / 2
    BCE = nn.functional.binary_cross_entropy(recon_01, x_01, reduction='sum')
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I), summed over the batch
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

The total loss is the sum of the BCE loss and the KL divergence loss. The BCE term pushes the reconstructed image to be similar to the original image, while the KL term keeps the distribution learned by the encoder close to the standard normal distribution.
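The KLD line in vae_loss is the closed-form KL divergence between a diagonal Gaussian N(mu, sigma^2) and the standard normal N(0, I), written in terms of logvar = log sigma^2. As a sanity check, the plain-Python sketch below (hypothetical function names, one latent dimension) compares that formula with a Monte Carlo estimate of E_q[log q(z) - log p(z)].

```python
import math
import random


def kl_closed_form(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)) for one dimension, same formula as in vae_loss."""
    return -0.5 * (1 + logvar - mu ** 2 - math.exp(logvar))


def log_normal(z, mean, var):
    """Log-density of a 1-D Gaussian with the given mean and variance."""
    return -0.5 * (math.log(2 * math.pi * var) + (z - mean) ** 2 / var)


def kl_monte_carlo(mu, logvar, n=100_000, seed=0):
    """Average of log q(z) - log p(z) over samples z drawn from q."""
    rng = random.Random(seed)
    std = math.exp(0.5 * logvar)
    total = 0.0
    for _ in range(n):
        z = rng.gauss(mu, std)
        total += log_normal(z, mu, std ** 2) - log_normal(z, 0.0, 1.0)
    return total / n


print(kl_closed_form(1.0, 0.0))            # 0.5
print(round(kl_monte_carlo(1.0, 0.0), 2))  # close to 0.5
```

The KL term is zero only when mu = 0 and logvar = 0, that is, when the encoder outputs exactly the standard normal prior; any deviation is penalized.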

Initialize the model and define the optimizer

To initialize the VAE model, we instantiate a VAE object with the desired size of the latent space and move it to the available device (a GPU if one is present, otherwise the CPU). We also define the optimizer that will update the model parameters during training. In this example, we use the Adam optimizer with a learning rate of 0.001.

Here is the code to initialize the model and optimizer:

# Define hyperparameters
latent_size = 20  # dimensionality of the latent vector (controls the compression)

# Use a GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model and move it to device
model = VAE(latent_size=latent_size).to(device)

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

Train the model

Let's train the model by looping over the epochs and setting the model to training mode. For each batch, we move the data to the device, reset the gradients, and pass the batch through the model to obtain the reconstructed images and the mean and log-variance of the latent distribution. We then compute the VAE loss, backpropagate the gradients, and update the model weights with the optimizer, accumulating the loss so we can report a running average.

num_epochs = 10     # example value; adjust as needed
log_interval = 100  # example value: log every 100 batches

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    num_seen = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        num_seen += len(data)
        if batch_idx % log_interval == 0:
            # vae_loss sums over the batch, so divide by images seen for a per-image average
            current_loss = train_loss / num_seen
            print('Epoch: {} [{}/{} ({:.0f}%)]\tTraining Loss: {:.6f}'.format(
                epoch, num_seen, len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), current_loss))
            experiment.log_metric('train_loss', current_loss,
                                  step=epoch * len(train_loader) + batch_idx)

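Every log_interval batches, the loop prints the running average training loss to the console and logs it to Comet with experiment.log_metric. The step value, epoch * len(train_loader) + batch_idx, combines the epoch number and batch index so that the loss curve in Comet continues smoothly across epochs.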
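With 60,000 training images and batch_size=128, the loader yields ceil(60000 / 128) = 469 batches per epoch, the last holding only 96 images. The hypothetical helpers below sketch the step arithmetic, and show why dividing a summed loss by the number of images (rather than the number of batches) gives a per-image average that stays comparable with the test loss, which is divided by len(test_loader.dataset).

```python
import math

NUM_TRAIN = 60_000
BATCH_SIZE = 128
BATCHES_PER_EPOCH = math.ceil(NUM_TRAIN / BATCH_SIZE)  # what len(train_loader) returns


def global_step(epoch, batch_idx, batches_per_epoch=BATCHES_PER_EPOCH):
    """One monotonically increasing step across epochs, as passed to log_metric."""
    return epoch * batches_per_epoch + batch_idx


def batch_sizes(num_items=NUM_TRAIN, batch_size=BATCH_SIZE):
    """Batch sizes produced by a DataLoader with the default drop_last=False."""
    full, rest = divmod(num_items, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


def per_image_average(batch_loss_sums, sizes):
    """Average loss per image when each batch loss is a sum over its images."""
    return sum(batch_loss_sums) / sum(sizes)


sizes = batch_sizes()
print(BATCHES_PER_EPOCH, sizes[-1], global_step(2, 10))  # 469 96 948
```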

Evaluate the model

Next, we evaluate the trained VAE model on the MNIST test set and visualize some reconstructions. We set the model to evaluation mode and loop over the test data without tracking gradients, moving each batch to the device and passing it through the model to obtain the reconstructed images and the mean and log-variance of the distribution. We then compute the VAE loss and add it to the total test loss.

model.eval()
test_loss = 0.0
with torch.no_grad():
    for data, _ in test_loader:
        data = data.to(device)
        recon_batch, mu, logvar = model(data)
        test_loss += vae_loss(recon_batch, data, mu, logvar).item()

# vae_loss sums over each batch, so this is the average loss per test image
test_loss /= len(test_loader.dataset)
print('Test Loss: {:.6f}'.format(test_loss))
experiment.log_metric('test_loss', test_loss)

# Visualize originals and reconstructions from the last test batch (16 images)
num_images = 8
for i in range(num_images):
    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(data[i][0].cpu().numpy(), cmap='gray')
    axs[0].set_title('Original Image')
    axs[1].imshow(recon_batch[i][0].cpu().numpy(), cmap='gray')
    axs[1].set_title('Reconstructed Image')
    experiment.log_figure(figure_name=f"Reconstructed Image {i+1}", figure=fig)
    plt.close(fig)  # free memory once the figure has been logged

We then divide the total test loss by the number of test images to get the average loss per image and log it to the console and to Comet using the experiment object. Finally, we use Matplotlib to visualize a few original images from the last test batch next to their reconstructions.

For each example, we create a figure with two subplots, one for the original image and one for the reconstructed image, and log it to Comet with the log_figure method of the experiment object.
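Logging eight separate figures works, but a single grid is often easier to compare at a glance. The sketch below shows one way to build it, assuming the images arrive as NumPy arrays of shape (N, 28, 28) (for tensors, call .cpu().numpy() first). The function name is hypothetical, and the random arrays only stand in for real images.

```python
import matplotlib.pyplot as plt
import numpy as np


def comparison_grid(originals, reconstructions, num_images=8):
    """Return a 2 x num_images figure: originals on top, reconstructions below."""
    fig, axs = plt.subplots(2, num_images, figsize=(1.5 * num_images, 3.2))
    for i in range(num_images):
        for row, images in enumerate((originals, reconstructions)):
            axs[row, i].imshow(images[i], cmap="gray")
            axs[row, i].axis("off")
    axs[0, 0].set_title("Original", loc="left")
    axs[1, 0].set_title("Reconstructed", loc="left")
    fig.tight_layout()
    return fig


rng = np.random.default_rng(0)
fake = rng.uniform(-1, 1, size=(8, 28, 28))  # stand-in data in the normalized [-1, 1] range
fig = comparison_grid(fake, fake * 0.9)
```

The returned figure can be passed to experiment.log_figure(figure_name="reconstructions", figure=fig) and then closed with plt.close(fig).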

Note that the reconstructed images will look similar to the originals but with some loss of detail, because each 784-pixel image must pass through a much smaller latent vector. The degree of compression is controlled by the size of the latent space, which you can adjust through the latent_size setting.

Here is an example of the reconstructed images generated by the VAE model:

A black and white grid showing the original handwritten numbers 1 and 9 from the MNIST data set, and their reconstruction images, as generated from the variational autoencoder
reconstructed images generated by the VAE

As we can see, the reconstructed images are not perfect, but they capture the main features of the original images. You can increase the degree of compression by decreasing the size of the latent space, but this may result in further loss of detail in the reconstructed images.
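To use a trained VAE as a codec, a common choice is to skip sampling at inference time: store the encoder's mean mu as the compressed code (optionally cast to 16-bit floats) and run only the decoder to reconstruct. The sketch below illustrates this with a tiny, untrained linear model; TinyVAE, compress, and decompress are hypothetical names, and with the tutorial's model you would call model.encoder(images) and model.decoder(codes) instead.

```python
import torch
import torch.nn as nn


class TinyVAE(nn.Module):
    """Hypothetical stand-in exposing encode/decode like the tutorial's Encoder/Decoder."""

    def __init__(self, latent_size=20):
        super().__init__()
        self.encoder = nn.Linear(784, 2 * latent_size)  # produces [mu, logvar]
        self.decoder = nn.Linear(latent_size, 784)

    def encode(self, x):
        mu, logvar = self.encoder(x.flatten(start_dim=1)).chunk(2, dim=1)
        return mu, logvar

    def decode(self, z):
        return torch.tanh(self.decoder(z)).view(-1, 1, 28, 28)


@torch.no_grad()
def compress(model, images):
    """Encode deterministically (use mu, no sampling) and store as float16."""
    mu, _ = model.encode(images)
    return mu.to(torch.float16)


@torch.no_grad()
def decompress(model, codes):
    """Decode stored codes back to images in [-1, 1]."""
    return model.decode(codes.to(torch.float32))


model = TinyVAE(latent_size=20).eval()
images = torch.rand(4, 1, 28, 28) * 2 - 1  # fake batch in [-1, 1]
codes = compress(model, images)
restored = decompress(model, codes)
print(codes.shape, codes.dtype, codes.numel() * codes.element_size())  # torch.Size([4, 20]) torch.float16 160
```

Using mu instead of a random sample makes compression deterministic: the same image always yields the same code.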

Conclusion

This guide explored Variational Autoencoders (VAEs) for image compression. We implemented a VAE model using PyTorch and trained it on the MNIST dataset. We used Comet to track and log the model’s performance during training and evaluation.

We showed how a VAE can compress images by mapping them to a lower-dimensional latent space and then reconstructing them from this space. The same framework can be adapted to other datasets and applications by changing the architecture, the latent size, and the reconstruction loss.

I encourage you to experiment with different datasets, architectures, and hyperparameters and explore the capabilities and limitations of VAEs for image compression.

Resources

  1. Diederik P. Kingma and Max Welling (2019), “An Introduction to Variational Autoencoders” || Foundations and Trends in Machine Learning.
  2. Joseph Rocca and Baptiste Rocca (2021), “Understanding Variational Autoencoders (VAEs)” || Towards Data Science.
  3. Comet Docs
  4. MNIST dataset
  5. Gaussian distribution || ScienceDirect

Editor’s Note: Heartbeat is a contributor-driven online publication and community dedicated to providing premier educational resources for data science, machine learning, and deep learning practitioners. We’re committed to supporting and inspiring developers and engineers from all walks of life.

Editorially independent, Heartbeat is sponsored and published by Comet, an MLOps platform that enables data scientists & ML teams to track, compare, explain, & optimize their experiments. We pay our contributors, and we don’t sell ads.


A technical writer who embodies the finesse of Art, science, & persuasion. Expert in cryptocurrency/web3, AI, & software development.