MNIST Digits Classification with PyTorch, PyTorch Lightning and PyTorch Ignite

What is Classification?

Classification is a type of supervised learning. The machine learning model learns to predict which of n categories an input data point belongs to; these categories are also called classes, targets, or labels. The task can be formulated as y = f(x), where the learning algorithm learns a mapping function f from a real-valued input variable x (e.g. pixel intensities) to a discrete output variable y in {0, 1, 2, ..., n-1}.

Examples include spam detection, binary classifiers such as cancer classification, and the classification part of object detection.
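To make y = f(x) concrete: a neural-network classifier typically outputs one score per class, and the predicted label y is the index of the highest score. A minimal pure-Python sketch (the scores below are made up for illustration):

```python
def predict_class(scores):
    """Return the class label (index) with the largest score, i.e. the argmax."""
    return max(range(len(scores)), key=lambda c: scores[c])

# Hypothetical scores for the 10 digit classes 0..9
scores = [0.1, -1.2, 0.3, 2.5, 0.0, -0.7, 1.1, 0.2, -0.4, 0.9]
print(predict_class(scores))  # 3
```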


Working on MNIST digit classification is the "Hello World" of deep learning in computer vision: it is the most basic and essential dataset to start with if you want to go deeper into image classification and object detection.

The MNIST database of handwritten digits contains 60,000 training and 10,000 test images, with classes ranging from 0 to 9. Each image has shape 1x28x28: a single grayscale channel of 28x28 pixels.


Show me the Code


In [1]:
PyTorch Installation

Easy to install. Use the following command line: pip install torch torchvision

In [2]:
! pip install torch torchvision
Training an image classifier using PyTorch

  1. Initialize the hyperparameters and arguments.
  2. Load and apply preprocessing in training and validation data splits.
  3. Create your own custom CNN model.
  4. Define a loss function and optimizer function with initial hyperparameters.
  5. Train the neural network model with the train set.
  6. Validate the neural network model with the validation set.
  7. Print and Save the best model based on the highest achieved validation accuracy.

1. Initialize the hyperparameters and arguments.

It is good practice to define your hyperparameters and other arguments in one place, so they are easy to change later. The hyperparameters used here are:


  1. Learning rate (lr)
  2. Momentum
  3. Weight decay
  4. Batch Size
  5. Number of Epochs
  6. lr scheduler

along with other arguments such as the random seed, the number of GPUs, and the path where the best model is saved.

In [10]:
import os
import torch
args = {
        'num_gpus': 1,
        'ckpt_dir': 'ckpt/',
        'dataset': 'mnist',
        'epochs': 5,
        'train_batch_size': 32,
        'test_batch_size': 128,
        'lr' : 0.01,
        'lr_schedule': 20,
        'gamma': 0.7,
        'momentum': 0.9,
        'nesterov': False,
        'weight_decay': 5e-4,
        'no-cuda': True,
        'seed': 13
}

# Use the GPU only if it is available and not disabled with 'no-cuda'
use_cuda = not args['no-cuda'] and torch.cuda.is_available()

# Fix the random seed to make runs reproducible
torch.manual_seed(args['seed'])

# Runs on systems with and without a GPU
device = torch.device("cuda" if use_cuda else "cpu")

# Create the directory where model checkpoints are saved
if not os.path.exists(args['ckpt_dir']):
    os.makedirs(args['ckpt_dir'])
print("Model Save Path:", args['ckpt_dir'])
Model Save Path: ckpt/

2. Load and apply preprocessing in training and validation data splits.

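For computer vision tasks we use a package called torchvision. torchvision loads the train and validation sets as PIL images, and we can optionally apply various preprocessing steps for better learning. transforms.ToTensor converts each image into a float tensor with values in [0, 1], and transforms.Normalize then standardizes it with the MNIST mean (0.1307) and standard deviation (0.3081), so the inputs have roughly zero mean and unit variance and lie in about [-0.42, 2.82].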
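The arithmetic behind these two transforms can be checked by hand. A minimal pure-Python sketch of what ToTensor and Normalize do to a single pixel (the helper names are illustrative, not torchvision functions):

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def to_tensor_scale(pixel):
    """Like ToTensor: map an 8-bit pixel value in [0, 255] to [0.0, 1.0]."""
    return pixel / 255.0

def normalize(value, mean=MNIST_MEAN, std=MNIST_STD):
    """Like Normalize: subtract the mean, then divide by the standard deviation."""
    return (value - mean) / std

for pixel in (0, 255):
    x = to_tensor_scale(pixel)
    print(pixel, round(normalize(x), 4))
# 0 -0.4242
# 255 2.8215
```

Black background pixels therefore become about -0.42 and white strokes about 2.82.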

In [4]:
# Data loading & preprocessing

import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# On a GPU, pass these extra keyword arguments to the DataLoaders
kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

# Convert PIL images to tensors in [0, 1], then standardize with the MNIST mean and std
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

# Train DataLoader
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset,
    batch_size=args['train_batch_size'], shuffle=True, **kwargs)

# Validation (Test) DataLoader
test_dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset,
    batch_size=args['test_batch_size'], shuffle=False, **kwargs)

3. Create your own custom CNN model.

A basic CNN model consists of a sequence of convolutional, batch-normalization, ReLU, dropout, and max-pooling layers followed by a few dense (fully connected) layers. An example model is shown below:


In [5]:
#Or write your own Custom Model 
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, in_channels=1, out_channels=10):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.bn2 = nn.BatchNorm2d(64)

        self.maxpool = nn.MaxPool2d(2)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.dropout1(x)

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout2(x)
        logits = self.fc2(x)
        return logits

model = Net(in_channels=1, out_channels=10).to(device)
print(model)

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout1): Dropout2d(p=0.25, inplace=False)
  (dropout2): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
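The 9216 input features of fc1 follow from the layer shapes: each 3x3 convolution with stride 1 and no padding shrinks each spatial dimension by 2, the 2x2 max pool halves it, and conv2 outputs 64 channels (batch norm, ReLU, and dropout keep the shape unchanged). A small pure-Python sketch of that arithmetic, using the standard output-size formula with dilation 1:

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

def maxpool_out(size, kernel=2, stride=None):
    """Spatial output size of max pooling; stride defaults to the kernel size, as in nn.MaxPool2d."""
    stride = stride or kernel
    return (size - kernel) // stride + 1

side = 28                          # MNIST images are 28x28
side = conv2d_out(side, 3)         # conv1: 28 -> 26
side = conv2d_out(side, 3)         # conv2: 26 -> 24
side = maxpool_out(side, 2)        # maxpool: 24 -> 12
flat_features = 64 * side * side   # 64 channels from conv2, flattened
print(side, flat_features)  # 12 9216
```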

4. Define a loss function and optimizer function with initial hyperparameters.

Loss function: Cross Entropy Loss

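$$\mathcal{L}(y, \hat{y}) = -\sum_{c=0}^{n-1} y_c \log \hat{y}_c$$

where ŷ_c is the probability the model assigns to class c (the softmax of its output logits), y_c is 1 for the target class and 0 for all other classes, and n = 10 for MNIST. nn.CrossEntropyLoss takes the raw logits, applies the log-softmax internally, and by default averages the loss over the batch.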
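For a one-hot target, the sum above reduces to -log of the softmax probability of the correct class. A pure-Python sketch for a single sample (a hand-written illustration, not PyTorch's implementation), computed directly from logits with the usual max-subtraction trick for numerical stability:

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

print(round(cross_entropy([0.0, 0.0], 0), 4))  # 0.6931, i.e. log 2 for a 50/50 guess
print(cross_entropy([10.0, 0.0, 0.0], 0))      # close to 0: confident and correct
print(cross_entropy([0.0, 10.0], 0))           # about 10: confident and wrong
```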

Optimizer: Stochastic Gradient Descent with the hyperparameters defined above.

In [6]:
# Cross Entropy loss and Stochastic Gradient Descent optimizer with lr scheduling
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'], nesterov=args['nesterov'],weight_decay=args['weight_decay'])
scheduler = StepLR(optimizer, step_size=args['lr_schedule'], gamma=args['gamma'])
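To see what this optimizer and scheduler do with the chosen hyperparameters, here is a hand-written sketch for a single scalar parameter. It follows the update rule described in the torch.optim.SGD documentation for nesterov=False and no dampening (weight decay added to the gradient, then a momentum buffer), and the StepLR rule lr = lr0 · gamma^(epoch // step_size). The constant gradient of 0.5 is made up for illustration.

```python
def sgd_momentum_step(param, grad, buf, lr=0.01, momentum=0.9, weight_decay=5e-4):
    """One SGD update for a scalar parameter (no Nesterov, no dampening)."""
    g = grad + weight_decay * param                  # L2 weight decay enters via the gradient
    buf = g if buf is None else momentum * buf + g   # momentum buffer, initialized on the first step
    return param - lr * buf, buf

def step_lr(base_lr, epoch, step_size=20, gamma=0.7):
    """Learning rate in a given epoch under StepLR."""
    return base_lr * gamma ** (epoch // step_size)

p, buf = 1.0, None
for _ in range(2):
    p, buf = sgd_momentum_step(p, grad=0.5, buf=buf)
print(p)

print([step_lr(0.01, e) for e in range(5)])  # constant for the 5 epochs trained here
print(step_lr(0.01, 20))                     # first decay would happen at epoch 20
```

With step_size=20 and only 5 epochs, the learning rate never decays in this notebook.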

5. Train the neural network model with the train set.

Write a function for training using our train dataloader, model, loss function and optimizer.

Note: During training, we need to call model.train() and, for every batch, optimizer.zero_grad() before feeding inputs to the model; otherwise gradients from previous batches accumulate.

In [7]:
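def train(train_loader, model, device, criterion, optimizer):
    model.train()  # enable dropout and batch-norm statistics updates
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()               # clear gradients from the previous batch
        outputs = model(inputs)             # forward pass
        loss = criterion(outputs, targets)
        loss.backward()                     # backpropagate
        optimizer.step()                    # update the weights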

6. Validate the neural network model with the validation set.

Write a function that evaluates the performance of your model by printing the accuracy after every epoch.

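Note: During validation, we call model.eval() so that dropout is disabled and batch normalization uses its running statistics, and we wrap the loop in torch.no_grad() so that no gradients are computed. The parameters stay fixed because no optimizer step is taken.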
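A small sketch of what the two switches do, using a toy Linear + Dropout stack instead of Net so that it runs on its own:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(2, 4)

layer.train()                        # training mode: dropout randomly zeroes activations
out_train = layer(x)
print(out_train.requires_grad)       # True: autograd records the graph

layer.eval()                         # evaluation mode: dropout becomes the identity
with torch.no_grad():                # no graph is built, no gradients are stored
    out_eval_1 = layer(x)
    out_eval_2 = layer(x)
print(out_eval_1.requires_grad)      # False
print(torch.equal(out_eval_1, out_eval_2))  # True: outputs are deterministic in eval mode
```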

In [8]:
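# Evaluating
def eval(test_loader, model, device, best_acc):
    model.eval()  # dropout off, batch norm uses running statistics
    correct, total = 0, 0
    with torch.no_grad():  # no gradient tracking during evaluation
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)  # index of the highest logit
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    # `epoch` is the loop variable of the training loop below
    print('Epoch:', epoch, 'Accuracy: %f %%' % (100 * correct / total), 'best_accuracy:', best_acc)
    return float(100 * correct / total)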

7. Print and Save the best model based on the highest achieved validation accuracy


  1. Training Function
  2. Evaluation Function
  3. Print the current accuracy and the best validation accuracy for each epoch.
  4. Save the current and the best model in the ckpt directory.
In [11]:

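best_acc = 0
start_epoch = 0
for epoch in range(start_epoch, args['epochs']):
    train(train_loader, model, device, criterion, optimizer)
    scheduler.step()  # decay the learning rate every args['lr_schedule'] epochs
    acc = eval(test_loader, model, device, best_acc)
    if acc > best_acc:
        best_acc = acc
        bestepoch = epoch
        # Save the best model so far
        torch.save(model.state_dict(), os.path.join(args['ckpt_dir'], 'best_model_epoch' + '.pth.tar'))
    # Save the current (latest) model
    torch.save(model.state_dict(), os.path.join(args['ckpt_dir'], 'final_model_epoch' + '.pth.tar'))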
Epoch: 0 Accuracy: 98.670000 % best_accuracy: 0
Epoch: 1 Accuracy: 98.910000 % best_accuracy: 98.67
Epoch: 2 Accuracy: 98.880000 % best_accuracy: 98.91
Epoch: 3 Accuracy: 98.840000 % best_accuracy: 98.91
Epoch: 4 Accuracy: 99.130000 % best_accuracy: 98.91
In [18]:
print('Best Epoch: {}, Best Accuracy: {:.2f}'.format(bestepoch, best_acc))
Best Epoch: 4, Best Accuracy: 99.13
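After training, the best checkpoint can be reloaded for inference. A minimal sketch, assuming the checkpoint file stores a state_dict; the small nn.Linear stands in for Net and a temporary directory stands in for ckpt/ so that the snippet runs on its own:

```python
import os
import tempfile

import torch
import torch.nn as nn

net = nn.Linear(4, 2)                 # stand-in for the trained Net
ckpt_dir = tempfile.mkdtemp()         # stand-in for args['ckpt_dir']
path = os.path.join(ckpt_dir, 'best_model_epoch.pth.tar')

# Save only the learned tensors (weights, biases, and any buffers)
torch.save(net.state_dict(), path)

# Rebuild the same architecture, then load the saved tensors into it
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path, map_location='cpu'))
restored.eval()                       # switch to inference mode before predicting

same = all(torch.equal(a, b) for a, b in zip(net.state_dict().values(),
                                              restored.state_dict().values()))
print(same)  # True
```

Saving the state_dict rather than the whole model object keeps the checkpoint independent of the exact Python class path, at the cost of having to rebuild the architecture before loading.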

PyTorch Lightning: Rapid Research Framework for PyTorch



  1. The lightweight PyTorch wrapper for ML researchers.
  2. Scale your models.
  3. Write less boilerplate.
  4. Supports Python 3.6 on Linux (PyTorch 1.1, 1.2, 1.3, and 1.4 at the time of writing)

Lightning Installation

Easy to use. Simply run pip install pytorch-lightning to install it.

In [12]:
! pip install pytorch-lightning
Lightning Speed Code

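In the PyTorch Lightning framework, every project goes into a LightningModule.

A LightningModule is a standard interface: you implement a set of methods, and the Lightning Trainer calls them at the right points of training. Not all of them are required; this example implements the nine listed below.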

It involves:

  1. Model definition and variable initialization: __init__

  2. Computation, i.e. the forward pass: forward

  3. Training loop step, including the loss computation: training_step

  4. Validation loop step: validation_step

  5. Model performance evaluation after each epoch: validation_epoch_end

  6. Loading and preprocessing the dataset: prepare_data

  7. Training loader: train_dataloader

  8. Validation loader: val_dataloader

  9. Optimizer(s) and learning-rate scheduling: configure_optimizers
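The idea of "implement the hooks, let the Trainer call them" can be illustrated with a toy loop. The sketch below is hypothetical and heavily simplified: ToyModule and toy_fit are invented for illustration, do not reproduce Lightning's internals, and skip backpropagation, the sanity check, logging, and device handling. It only records the order in which the hooks run.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule that records which hook ran."""

    def __init__(self):
        self.calls = []

    def prepare_data(self):
        self.calls.append('prepare_data')

    def train_dataloader(self):
        return [(1, 1), (2, 0)]          # two fake (input, target) batches

    def val_dataloader(self):
        return [(3, 1)]                  # one fake validation batch

    def configure_optimizers(self):
        self.calls.append('configure_optimizers')
        return None

    def training_step(self, batch, batch_idx):
        self.calls.append('training_step')
        return {'loss': 0.0}

    def validation_step(self, batch, batch_idx):
        self.calls.append('validation_step')
        return {'val_loss': 0.0}

    def validation_epoch_end(self, outputs):
        self.calls.append('validation_epoch_end')


def toy_fit(module, max_epochs=1):
    """Hypothetical trainer loop: calls the hooks in a plausible order."""
    module.prepare_data()
    module.configure_optimizers()
    for _ in range(max_epochs):
        for i, batch in enumerate(module.train_dataloader()):
            module.training_step(batch, i)
            # a real trainer would call backward() and optimizer.step() here
        outputs = [module.validation_step(b, i)
                   for i, b in enumerate(module.val_dataloader())]
        module.validation_epoch_end(outputs)


m = ToyModule()
toy_fit(m)
print(m.calls)
```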

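Now we build our own code for MNIST handwritten digit classification in a futuristic PyTorch-based research framework called PyTorch Lightning.

Note: For consistency we use the same model architecture as in the PyTorch script. A few hyperparameters differ: the Lightning version uses a training batch size of 64, a learning rate of 1e-3, and no learning-rate scheduler. The code also follows the early PyTorch Lightning API; recent releases log metrics with self.log and no longer support validation_epoch_end (the closest replacement is the on_validation_epoch_end hook).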

In [13]:
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torch.optim as optim
from torchvision import transforms
import pytorch_lightning as pl
from collections import OrderedDict

class LightNet(pl.LightningModule):

    def __init__(self, in_channels=1, out_channels=10):
        super(LightNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.bn2 = nn.BatchNorm2d(64)

        self.maxpool = nn.MaxPool2d(2)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.dropout1(x)

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout2(x)
        logits = self.fc2(x)
        return logits
    # define the loss function
    def criterion(self, logits, targets):
        return F.cross_entropy(logits, targets)

    # process inside the training loop
    def training_step(self, train_batch, batch_idx):
        inputs, targets = train_batch
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, targets)

        #inbuilt tensorboard for logs
        tensorboard_logs = {'train_loss': loss}

        return {'loss': loss, 'log': tensorboard_logs}   

    # process inside the validation loop
    def validation_step(self, val_batch, batch_idx):
        inputs, targets = val_batch
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, targets)

        # Accuracy calculation
        pred = outputs.argmax(dim=1)  # index of the highest logit = predicted class
        incorrect = (pred != targets).sum()
        err = incorrect.item() / targets.numel()
        val_acc = torch.tensor(1.0 - err)

        return {'val_loss': loss, 'val_acc': val_acc}    

    #return average loss and accuracy after every epoch
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()  
        Accuracy = 100 * avg_acc.item()
        tensorboard_logs = {'val_loss': avg_loss, 'avg_val_acc': avg_acc}
        print('Val Loss:', round(avg_loss.item(),2), 'Val Accuracy: %f %%' % Accuracy) 
        return {'avg_val_loss': avg_loss, 'progress_bar': tensorboard_logs}

    # Load the datasets and transform PIL images into tensors standardized with the MNIST mean and std
    def prepare_data(self):
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = MNIST('data', train=True, download=True, transform=transform)
        test_dataset = MNIST('data', train=False, download=True, transform=transform)
        self.mnist_train, self.mnist_val = train_dataset, test_dataset

    # Create the train loader (reshuffled every epoch)
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64, shuffle=True, num_workers=2)

    # Create the validation loader
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=128, num_workers=2)

    # Can return multiple optimizers and scheduling algorithms
    # Here we use Stochastic Gradient Descent
    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=1e-3, momentum=0.9, nesterov=False, weight_decay=5e-4)
        return optimizer
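validation_epoch_end above reports the plain mean of per-batch accuracies. With 10,000 test images and batch_size=128, the loader yields 78 full batches and one final batch of 16 images, so that small batch weighs as much as a full one and the reported value can differ slightly from the accuracy over the whole test set. (Similarly, the first "Val Accuracy: 7.421875 %" line in the output below is consistent with Lightning's pre-training sanity check, which by default runs two validation batches: 19 correct out of 256.) A pure-Python sketch with hypothetical per-batch counts:

```python
def mean_of_batch_accuracies(batch_correct, batch_sizes):
    """What validation_epoch_end computes: an unweighted mean of per-batch accuracies."""
    accs = [c / n for c, n in zip(batch_correct, batch_sizes)]
    return sum(accs) / len(accs)

def dataset_accuracy(batch_correct, batch_sizes):
    """Accuracy over the whole dataset: total correct / total samples."""
    return sum(batch_correct) / sum(batch_sizes)

sizes = [128] * 78 + [16]          # 10,000 test images with batch_size=128
# Hypothetical counts: 127 correct in every full batch, only 8 in the small last batch
correct = [127] * 78 + [8]

print(round(mean_of_batch_accuracies(correct, sizes), 4))
print(round(dataset_accuracy(correct, sizes), 4))
```

Weighting each batch accuracy by its batch size (or summing correct predictions, as the plain PyTorch eval function does) gives the exact dataset accuracy.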

Trainer: Run Lightning Code Here

Now comes the mystical part:

Feed the LightningModule to the Trainer and see the marvel:

In [14]:
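# Most basic trainer with good defaults; in the Lightning version used here it trains on the CPU
# unless GPUs are requested (see "GPU available: True, used: False" below).
# It also prints a model summary, comparable to the Keras summary, for free.
model = LightNet()

# Many features can be set in the Trainer, such as the number of epochs, GPUs, and clusters used.
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model)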
GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name     | Type        | Params
0 | conv1    | Conv2d      | 320   
1 | bn1      | BatchNorm2d | 64    
2 | relu     | ReLU        | 0     
3 | conv2    | Conv2d      | 18 K  
4 | bn2      | BatchNorm2d | 128   
5 | maxpool  | MaxPool2d   | 0     
6 | dropout1 | Dropout2d   | 0     
7 | dropout2 | Dropout2d   | 0     
8 | fc1      | Linear      | 1 M   
9 | fc2      | Linear      | 1 K   
Val Loss: 2.3 Val Accuracy: 7.421875 %

Val Loss: 0.09 Val Accuracy: 97.438687 %
Val Loss: 0.06 Val Accuracy: 98.101264 %
Val Loss: 0.05 Val Accuracy: 98.378164 %
Val Loss: 0.04 Val Accuracy: 98.605615 %
Val Loss: 0.04 Val Accuracy: 98.674840 %
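The parameter counts in the model summary (320, 18 K, 1 M, ...) can be checked by hand: a Conv2d layer has out·in·k·k weights plus one bias per output channel, a Linear layer has in·out weights plus one bias per output, and a BatchNorm2d layer has a learnable scale and shift per channel (its running statistics are buffers, not parameters). A pure-Python sketch:

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Weights (out x in x k x k) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + out_ch

def linear_params(in_features, out_features):
    """Weight matrix (out x in) plus one bias per output feature."""
    return in_features * out_features + out_features

def batchnorm_params(channels):
    """Learnable scale (weight) and shift (bias) per channel."""
    return 2 * channels

counts = {
    'conv1': conv2d_params(1, 32, 3),
    'bn1': batchnorm_params(32),
    'conv2': conv2d_params(32, 64, 3),
    'bn2': batchnorm_params(64),
    'fc1': linear_params(9216, 128),
    'fc2': linear_params(128, 10),
}
for name, n in counts.items():
    print(name, n)
print('total', sum(counts.values()))
```

fc1 alone holds about 98% of the roughly 1.2 million parameters, which is typical when a large flattened feature map feeds a dense layer.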


Now let's look at another high-level PyTorch-based framework: PyTorch Ignite.


Ignite Installation

It is also pretty smooth to set up Ignite on your system. Run the following command and start coding: pip install pytorch-ignite

In [28]:
! pip install pytorch-ignite
Ignite Setup

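Note: Ignite requires a similar model and data setup, but it does not provide a standard module class like LightningModule. Instead, training and evaluation are driven by engines, and you attach your own functions (handlers) to events such as the end of an iteration or an epoch.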
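The event-handler idea can be illustrated without Ignite itself. The MiniEngine below is a hypothetical, minimal sketch in the spirit of ignite.engine.Engine; it is not Ignite's implementation or API, and events are plain strings here. A process function runs once per batch, and registered handlers fire when an event occurs:

```python
from collections import defaultdict

class MiniEngine:
    """Hypothetical event-driven loop, loosely modeled on Ignite's Engine."""

    def __init__(self, process_fn):
        self.process_fn = process_fn            # called once per batch, like an update function
        self.handlers = defaultdict(list)
        self.state = {'epoch': 0, 'iteration': 0, 'output': None}

    def on(self, event):
        """Decorator that registers a handler for an event name."""
        def register(fn):
            self.handlers[event].append(fn)
            return fn
        return register

    def fire(self, event):
        for fn in self.handlers[event]:
            fn(self)

    def run(self, data, max_epochs=1):
        for _ in range(max_epochs):
            self.state['epoch'] += 1
            for batch in data:
                self.state['iteration'] += 1    # counts across epochs, not per epoch
                self.state['output'] = self.process_fn(self, batch)
                self.fire('ITERATION_COMPLETED')
            self.fire('EPOCH_COMPLETED')
        return self.state


engine = MiniEngine(lambda eng, batch: sum(batch))  # toy "loss": the sum of the batch
log = []

@engine.on('ITERATION_COMPLETED')
def log_iteration(eng):
    log.append(('iter', eng.state['iteration'], eng.state['output']))

@engine.on('EPOCH_COMPLETED')
def log_epoch(eng):
    log.append(('epoch', eng.state['epoch']))

engine.run([[1, 2], [3, 4]], max_epochs=2)
print(log)
```

Because the iteration counter keeps growing across epochs, a per-epoch position has to be derived with a modulo, which is what log_training_loss does in the Ignite code below.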

In [39]:
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST

# ignite pre-defined functions
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# Here we use a tqdm progress bar instead of a TensorBoard logger to keep the code simple to grasp
from tqdm import tqdm

# Define a custom model
class Net(nn.Module):
    def __init__(self, in_channels=1, out_channels=10):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.bn2 = nn.BatchNorm2d(64)

        self.maxpool = nn.MaxPool2d(2)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, out_channels)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout2(x)
        logits = self.fc2(x)
        return logits

# Create a function for data loading (train and validation loader)
def get_data_loaders(train_batch_size, val_batch_size):
    data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    train_loader = DataLoader(MNIST(download=True, root="datai", transform=data_transform, train=True),
                              batch_size=train_batch_size, shuffle=True)

    val_loader = DataLoader(MNIST(download=False, root="datai", transform=data_transform, train=False),
                            batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader

""" run function --> train and valid dataloaders --> define model --> define device --> optimizer --> loss func -->
 pass into an Ignite trainer --> use an Ignite evaluator for accuracy"""
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.cross_entropy, device=device)

    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'nll': Loss(F.cross_entropy)},
                                            device=device)

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(
        initial=0, leave=False, total=len(train_loader),
        desc=desc.format(0)
    )

    # Update the progress bar every `log_interval` iterations
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    # After each epoch, evaluate on the training set
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

    # After each epoch, evaluate on the validation set and reset the progress bar
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        pbar.n = pbar.last_print_n = 0

    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()

Run: Ignite Your Code

In [ ]:
run(32, 32, 5, 0.02, 0.5, 10)

ITERATION - loss: 0.12: 100%|█████████▉| 1870/1875 [00:46<00:00, 61.65it/s]Training Results - Epoch: 1 Avg accuracy: 0.98 Avg loss: 0.07

ITERATION - loss: 0.06: 1%| | 10/1875 [00:49<18:05, 1.72it/s] Validation Results - Epoch: 1 Avg accuracy: 0.98 Avg loss: 0.06

ITERATION - loss: 0.01: 100%|█████████▉| 1870/1875 [01:35<00:00, 60.16it/s]Training Results - Epoch: 2 Avg accuracy: 0.99 Avg loss: 0.04

ITERATION - loss: 0.03: 1%| | 10/1875 [01:38<18:10, 1.71it/s] Validation Results - Epoch: 2 Avg accuracy: 0.99 Avg loss: 0.04

ITERATION - loss: 0.14: 100%|█████████▉| 1870/1875 [02:24<00:00, 62.32it/s]Training Results - Epoch: 3 Avg accuracy: 0.99 Avg loss: 0.03

ITERATION - loss: 0.01: 1%| | 10/1875 [02:26<18:15, 1.70it/s] Validation Results - Epoch: 3 Avg accuracy: 0.99 Avg loss: 0.04

ITERATION - loss: 0.05: 100%|█████████▉| 1870/1875 [03:13<00:00, 59.78it/s]Training Results - Epoch: 4 Avg accuracy: 0.99 Avg loss: 0.03

ITERATION - loss: 0.22: 1%| | 10/1875 [03:16<18:19, 1.70it/s] Validation Results - Epoch: 4 Avg accuracy: 0.99 Avg loss: 0.04

ITERATION - loss: 0.04: 100%|█████████▉| 1870/1875 [04:02<00:00, 62.03it/s]Training Results - Epoch: 5 Avg accuracy: 0.99 Avg loss: 0.02

Thank You