MNIST Image classification

This script implements a model to classify the MNIST dataset. The model mainly consists of convolutional layers and pooling layers with a few dense layers at the end. As the script is only for demonstration purposes, only the first n = 512 datapoints are used to make the training faster. For a full example, change the parameter n to 60000. If n is increased, more epochs may need to be added and other hyperparameters tuned.

  • ImageClassification
  • True label: 7 | Predicted label: 7, True label: 2 | Predicted label: 2, True label: 1 | Predicted label: 1, True label: 0 | Predicted label: 0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz

       0/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0s/step
 8396800/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
torch.Size([409, 1, 28, 28]) torch.Size([409, 10]) torch.Size([103, 1, 28, 28]) torch.Size([103, 10]) torch.Size([512, 1, 28, 28]) torch.Size([512, 10])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])
Model summary:
Input - Output: ((1, 28, 28))
Conv2D - (Input, Output): ((1, 28, 28), (32, 26, 26)) - Parameters: 320
    ReLU - Output: ((32, 26, 26))
MaxPooling2D - Output: ((32, 13, 13))
Layer normalisation - Output: ((32, 13, 13)) - Parameters: 64
Conv2D - (Input, Output): ((32, 13, 13), (32, 11, 11)) - Parameters: 9248
    ReLU - Output: ((32, 11, 11))
MaxPooling2D - Output: ((32, 5, 5))
Group normalisation - Output: ((32, 5, 5)) - Parameters: 64
Dropout - Output: ((32, 5, 5)) - Keep probability: 0.8
Flatten - (Input, Output): ((32, 5, 5), 800)
Dense - (Input, Output): (800, 200) - Parameters: 160200
    ReLU - Output: (200)
Dense - (Input, Output): (200, 10) - Parameters: 2010
    Softmax - Output: (10)
Total number of parameters: 171906
Epoch: 1 - Metrics: {'loss': '1.8260', 'accuracy': '0.4841', 'val_loss': '2.0189', 'val_accuracy': '0.2718'}
Epoch: 2 - Metrics: {'loss': '1.5050', 'accuracy': '0.6284', 'val_loss': '1.6021', 'val_accuracy': '0.5437'}
Epoch: 3 - Metrics: {'loss': '1.1816', 'accuracy': '0.7311', 'val_loss': '1.2676', 'val_accuracy': '0.6699'}
Epoch: 4 - Metrics: {'loss': '0.9141', 'accuracy': '0.8191', 'val_loss': '0.9703', 'val_accuracy': '0.7573'}
Epoch: 5 - Metrics: {'loss': '0.7508', 'accuracy': '0.8289', 'val_loss': '0.7857', 'val_accuracy': '0.8350'}
Epoch: 6 - Metrics: {'loss': '0.5937', 'accuracy': '0.8557', 'val_loss': '0.6401', 'val_accuracy': '0.8932'}
Epoch: 7 - Metrics: {'loss': '0.5111', 'accuracy': '0.8753', 'val_loss': '0.5740', 'val_accuracy': '0.8544'}
Epoch: 8 - Metrics: {'loss': '0.4598', 'accuracy': '0.8778', 'val_loss': '0.5058', 'val_accuracy': '0.8835'}
Epoch: 9 - Metrics: {'loss': '0.4097', 'accuracy': '0.8924', 'val_loss': '0.4158', 'val_accuracy': '0.8835'}
Epoch: 10 - Metrics: {'loss': '0.3876', 'accuracy': '0.8802', 'val_loss': '0.4120', 'val_accuracy': '0.9029'}
Epoch: 11 - Metrics: {'loss': '0.3642', 'accuracy': '0.8875', 'val_loss': '0.3654', 'val_accuracy': '0.8835'}
Epoch: 12 - Metrics: {'loss': '0.3309', 'accuracy': '0.9022', 'val_loss': '0.3424', 'val_accuracy': '0.8932'}
Epoch: 13 - Metrics: {'loss': '0.2965', 'accuracy': '0.9193', 'val_loss': '0.3313', 'val_accuracy': '0.8932'}
Epoch: 14 - Metrics: {'loss': '0.2704', 'accuracy': '0.9242', 'val_loss': '0.3272', 'val_accuracy': '0.9029'}
Epoch: 15 - Metrics: {'loss': '0.2581', 'accuracy': '0.9364', 'val_loss': '0.3633', 'val_accuracy': '0.9029'}
Epoch: 16 - Metrics: {'loss': '0.2316', 'accuracy': '0.9389', 'val_loss': '0.3745', 'val_accuracy': '0.8932'}
Epoch: 17 - Metrics: {'loss': '0.2222', 'accuracy': '0.9364', 'val_loss': '0.3456', 'val_accuracy': '0.9029'}
0.787109375

import torch
import matplotlib.pyplot as plt
import tensorflow as tf

from DLL.DeepLearning.Model import Model
from DLL.DeepLearning.Layers import Dense, Conv2D, Flatten, MaxPooling2D, Reshape
from DLL.DeepLearning.Layers.Regularisation import Dropout, BatchNorm, GroupNorm, InstanceNorm, LayerNorm
from DLL.DeepLearning.Layers.Activations import ReLU, SoftMax
from DLL.DeepLearning.Losses import CCE
from DLL.DeepLearning.Optimisers import SGD, ADAM
from DLL.DeepLearning.Callbacks import EarlyStopping
from DLL.DeepLearning.Initialisers import Xavier_Normal, Xavier_Uniform, Kaiming_Normal, Kaiming_Uniform
from DLL.Data.Preprocessing import OneHotEncoder, data_split
from DLL.Data.Metrics import accuracy


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed both the CPU and CUDA random number generators.
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# Number of leading samples kept from the training set and from the test set.
n = 512  # 60000
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Slice to the first n samples *before* converting, so that only the subset
# that is actually used is cast to float32 and transferred to the device.
# reshape(-1, 1, 28, 28) inserts the channel axis without hard-coding the
# dataset size.
train_images = torch.from_numpy(train_images[:n]).to(dtype=torch.float32, device=device).reshape(-1, 1, 28, 28)
train_labels = torch.from_numpy(train_labels[:n]).to(dtype=torch.float32, device=device)
test_images = torch.from_numpy(test_images[:n]).to(dtype=torch.float32, device=device).reshape(-1, 1, 28, 28)
test_labels = torch.from_numpy(test_labels[:n]).to(dtype=torch.float32, device=device)

# Scale each image set by its own maximum pixel value.
train_images = train_images / train_images.max()
test_images = test_images / test_images.max()

# 80 % / 20 % train / validation split; the last two returned values are not
# needed here because a separate test set was loaded above.
train_images, train_labels, validation_images, validation_labels, _, _ = data_split(train_images, train_labels, train_split=0.8, validation_split=0.2)

# Fit the one-hot encoder on the training labels only and reuse it for the
# validation and test labels.
label_encoder = OneHotEncoder()
train_labels = label_encoder.fit_encode(train_labels)
validation_labels = label_encoder.encode(validation_labels)
test_labels = label_encoder.encode(test_labels)
print(train_images.shape, train_labels.shape, validation_images.shape, validation_labels.shape, test_images.shape, test_labels.shape)
print(train_labels[:2])

# Convolutional classifier for single-channel 28x28 inputs. Layers are
# created and added one at a time, in order.
model = Model((1, 28, 28), device=device)

# First convolution / pooling stage, followed by layer normalisation.
model.add(
    Conv2D(
        kernel_size=3,
        output_depth=32,
        initialiser=Kaiming_Normal(),
        activation=ReLU(),
    )
)
model.add(MaxPooling2D(pool_size=2))
model.add(LayerNorm())
# Alternative normalisation for this stage:
# model.add(BatchNorm())

# Second convolution / pooling stage, followed by group normalisation.
model.add(
    Conv2D(
        kernel_size=3,
        output_depth=32,
        initialiser=Kaiming_Uniform(),
        activation=ReLU(),
    )
)
model.add(MaxPooling2D(pool_size=2))
# Alternative normalisations for this stage:
# model.add(InstanceNorm())
# model.add(LayerNorm())
model.add(GroupNorm(num_groups=16))
model.add(Dropout(p=0.2))

# Classification head: flatten, one hidden dense layer, softmax over 10 classes.
model.add(Flatten())
# Alternative to Flatten:
# model.add(Reshape(800))
model.add(Dense(200, activation=ReLU()))
model.add(Dense(10, activation=SoftMax()))

# Training configuration, built in the same order as the compile() arguments.
optimiser = ADAM(learning_rate=0.001)
loss_function = CCE()
tracked_metrics = ["loss", "val_loss", "val_accuracy", "accuracy"]
training_callbacks = (EarlyStopping(patience=3),)
model.compile(
    optimiser=optimiser,
    loss=loss_function,
    metrics=tracked_metrics,
    callbacks=training_callbacks,
)
model.summary()

# Train for at most 25 epochs; history is indexed by metric name below.
history = model.fit(
    train_images,
    train_labels,
    val_data=(validation_images, validation_labels),
    epochs=25,
    batch_size=256,
    verbose=True,
)

# Training curves: loss on the left panel, accuracy on the right panel.
curves_fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(8, 6))

loss_ax.plot(history["val_loss"], label="validation loss")
loss_ax.plot(history["loss"], label="loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Categorical cross entropy")
loss_ax.legend()

accuracy_ax.plot(history["val_accuracy"], label="validation accuracy")
accuracy_ax.plot(history["accuracy"], label="accuracy")
accuracy_ax.set_xlabel("Epoch")
accuracy_ax.set_ylabel("Accuracy")
accuracy_ax.legend()

# Evaluate on the held-out test images and print the accuracy before the
# figure is shown.
test_pred = model.predict(test_images)
print(accuracy(test_pred, test_labels))
plt.show()

# Show the first four test images in a 2x2 grid with their true and
# predicted labels.
fig, ax = plt.subplots(2, 2, figsize=(8, 8))
ax = ax.ravel()
for i, axis in enumerate(ax):
    # The tensors may live on the GPU (see `device`); Tensor.numpy() only works
    # on CPU tensors, so copy to host memory first. [0] selects the single
    # channel of the (1, 28, 28) image.
    axis.imshow(test_images[i].cpu().numpy()[0], cmap='gray', vmin=0, vmax=1)
    axis.set_title(f"True label: {test_labels[i].argmax()} | Predicted label: {test_pred[i].argmax()}")
plt.show()

Total running time of the script: (1 minutes 19.348 seconds)

Gallery generated by Sphinx-Gallery