
Latent Diffusion Models for the MNIST Dataset using Autoencoders 🌀


Welcome to this project exploring Diffusion Models on the MNIST dataset! 🚀

This repository focuses on generating and reconstructing handwritten digits by integrating:

- An Autoencoder with Channel Attention Blocks (CABs) that compresses digits into a compact latent space
- A Denoising Diffusion Probabilistic Model (DDPM) that learns to denoise those latent codes step by step



Overview

This project reconstructs MNIST digits by encoding them into a latent space with an autoencoder and progressively denoising latent codes with a diffusion model.

Highlights:

- An autoencoder with Channel Attention Blocks (CABs) that learns compact, informative latent codes
- A U-Net-based DDPM trained to denoise in that latent space
- Visualizations of reconstructions, t-SNE latent projections, the training loss curve, and the step-by-step denoising process


Autoencoder with CABs

Channel Attention Block (CAB)

CABs refine feature maps through channel attention: global average pooling squeezes each channel to a single descriptor, a small bottleneck of 1×1 convolutions turns those descriptors into per-channel weights, and a sigmoid gate rescales the input channel-wise:

```python
import torch
import torch.nn as nn

class CALayer(nn.Module):
    """Channel Attention layer: re-weights feature channels using global context."""
    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        # Squeeze: global average pooling collapses H x W to one descriptor per channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excite: a bottleneck of 1x1 convs produces per-channel weights in (0, 1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)   # (B, C, 1, 1) channel descriptors
        y = self.conv_du(y)    # (B, C, 1, 1) attention weights
        return x * y           # rescale each channel of the input
```
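For context, CALayer is typically wrapped in a residual Channel Attention Block, as in MPRNet-style restoration networks. Below is a minimal sketch of that common conv → activation → conv → channel-attention → residual pattern, reusing the CALayer above; the exact block in this repository may differ:

```python
import torch.nn as nn

class CAB(nn.Module):
    """Channel Attention Block: two convolutions, channel attention, and a residual path.
    A common design sketch; the repository's exact block may differ."""
    def __init__(self, n_feat, kernel_size=3, reduction=16, bias=False):
        super(CAB, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat, kernel_size, padding=kernel_size // 2, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feat, n_feat, kernel_size, padding=kernel_size // 2, bias=bias),
            CALayer(n_feat, reduction, bias),  # re-weight channels using global context
        )

    def forward(self, x):
        return x + self.body(x)  # residual connection keeps the block easy to optimize
```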

Diffusion Model (DDPM)

U-Net Architecture

The core of the diffusion model is a U-Net that predicts the noise added to a latent code at each timestep. Following the standard DDPM recipe, it is enhanced with:

- Sinusoidal timestep embeddings injected into the network (sketched below)
- Skip connections between matching encoder and decoder resolutions
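As a concrete example of one such component, the sinusoidal timestep embedding from the original DDPM can be computed as below. This is a minimal sketch; the function name and embedding details are illustrative, not necessarily those used in this repository:

```python
import math
import torch

def timestep_embedding(t, dim):
    """Sinusoidal embedding of diffusion timesteps (assumes dim is even).
    t: (B,) integer tensor of timesteps; returns a (B, dim) float tensor."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]                    # (B, half)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, dim)
```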


Results and Visualizations

The project features multiple visual outputs that highlight the training and performance of the model.

1. Reconstruction Performance

A side-by-side comparison of original and reconstructed images; high SSIM and PSNR scores indicate faithful reconstructions.

Visualization: Reconstruction Performance
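Both metrics can be computed with scikit-image; a minimal sketch, assuming grayscale images stored as float arrays in [0, 1]:

```python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def reconstruction_metrics(original, reconstructed):
    """PSNR and SSIM between two grayscale images scaled to [0, 1]."""
    psnr = peak_signal_noise_ratio(original, reconstructed, data_range=1.0)
    ssim = structural_similarity(original, reconstructed, data_range=1.0)
    return psnr, ssim

# Illustrative usage with a random 28x28 image and a slightly noisy copy
img = np.random.rand(28, 28).astype(np.float32)
noisy = np.clip(img + 0.05 * np.random.randn(28, 28).astype(np.float32), 0.0, 1.0)
print(reconstruction_metrics(img, noisy))
```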


2. Latent Space Visualization

Projection of the latent space with t-SNE, shown both for a single batch and for the full test dataset.

Latent Space with Labels (One Batch)
Latent Space (Full Test Dataset)
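A projection like this can be produced with scikit-learn's TSNE; a sketch assuming `latents` is an (N, D) array of flattened latent codes and `labels` holds the matching digit classes:

```python
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_latent_tsne(latents, labels):
    """Project latent codes to 2-D with t-SNE and color each point by its digit class."""
    coords = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(latents)
    scatter = plt.scatter(coords[:, 0], coords[:, 1], c=labels, cmap="tab10", s=5)
    plt.colorbar(scatter, label="digit")
    plt.title("t-SNE projection of the latent space")
    plt.show()
```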

3. Training Loss Curve

Tracking the loss of the diffusion model over epochs.

Loss Plot: Training Loss
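Plotting the curve is straightforward with matplotlib; the values below are illustrative placeholders for the per-epoch losses logged during training:

```python
import matplotlib.pyplot as plt

epoch_losses = [0.92, 0.41, 0.28, 0.21, 0.17]  # placeholder values, not real results

plt.plot(epoch_losses)
plt.xlabel("epoch")
plt.ylabel("noise-prediction MSE")
plt.title("Diffusion model training loss")
plt.show()
```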


4. Denoising Visualization (Step by Step)

Images progress from noisy states (left) to denoised outputs (right), demonstrating the stepwise denoising process.

Sample Visualization: Denoising Process
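A strip like this can be assembled with torchvision by snapshotting intermediate states during sampling. A sketch, assuming `snapshots` is a list of (B, 1, H, W) tensors ordered from noisy to clean:

```python
import torch
from torchvision.utils import make_grid, save_image

def save_denoising_strip(snapshots, path="denoising.png"):
    """Lay out denoising states left (noisy) to right (clean), one row per sample."""
    strip = torch.stack(snapshots, dim=1).flatten(0, 1)  # (B * steps, 1, H, W)
    save_image(make_grid(strip, nrow=len(snapshots)), path)
```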


Formulas and Key Concepts

The project applies unconditional diffusion in the style of classic DDPMs, but operates in the autoencoder's latent space rather than pixel space. Below is a simplified breakdown of the key concepts:

Forward Process (Noising):

$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$

Where:

- $x_0$ is the clean latent code and $x_t$ is its noised version at timestep $t$
- $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$ come from the noise schedule $\beta_t$
- $\epsilon \sim \mathcal{N}(0, I)$ is standard Gaussian noise
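In code, the forward process is a single vectorized expression. A minimal sketch, assuming `z0` holds clean latent codes and using a common linear beta schedule:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)          # linear noise schedule (a common DDPM default)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)  # cumulative product of alpha_t = 1 - beta_t

def q_sample(z0, t):
    """Draw z_t ~ q(z_t | z_0) with the closed-form forward process."""
    noise = torch.randn_like(z0)
    ab = alpha_bar[t].view(-1, 1, 1, 1)        # broadcast over (B, C, H, W)
    zt = ab.sqrt() * z0 + (1.0 - ab).sqrt() * noise
    return zt, noise                           # the sampled noise is the training target

# usage: t = torch.randint(0, T, (z0.shape[0],)); zt, eps = q_sample(z0, t)
```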

Reverse Process (Denoising):

$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t) \right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)$

Here $\epsilon_\theta(x_t, t)$ is the U-Net's noise prediction and the $\sigma_t z$ term injects fresh noise at every step except the last. Iterating from $t = T$ down to $t = 1$ recovers a clean latent code, which the decoder maps back to an image.
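Putting the reverse step into code gives the standard DDPM ancestral sampling loop. A sketch where `model(z, t)` stands in for the U-Net's noise prediction and $\sigma_t = \sqrt{\beta_t}$ (a standard choice); names are illustrative:

```python
import torch

@torch.no_grad()
def p_sample_loop(model, shape, betas):
    """Start from pure noise z_T and apply the reverse update down to t = 0."""
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    z = torch.randn(shape)                                   # z_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        t_batch = torch.full((shape[0],), t, dtype=torch.long)
        eps = model(z, t_batch)                              # predicted noise
        coef = (1.0 - alphas[t]) / (1.0 - alpha_bar[t]).sqrt()
        z = (z - coef * eps) / alphas[t].sqrt()
        if t > 0:
            z = z + betas[t].sqrt() * torch.randn(shape)     # sigma_t * z, skipped at t = 0
    return z                                                 # clean latent; decode to pixels
```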


Future Directions

Here are a few ideas to extend this project:

- Condition the diffusion model on digit labels for class-conditional generation
- Swap the DDPM sampler for DDIM to reduce the number of denoising steps
- Scale the pipeline to richer datasets such as Fashion-MNIST or CIFAR-10


🚀 Happy Training! 🧠