MNIST Image classification

This script implements a model to classify the MNIST dataset. The model consists mainly of convolutional and pooling layers, followed by two dense layers at the end. Because the script is for demonstration purposes only, just the first 512 datapoints are used to keep training fast. For a full example, change the parameter n to 60000. If n is increased, more epochs may be needed and other hyperparameters may need tuning.

  • Figure: ImageClassification
  • Figure: four test images titled "True label: 7 | Predicted label: 7", "True label: 2 | Predicted label: 2", "True label: 1 | Predicted label: 1", and "True label: 0 | Predicted label: 0"

torch.Size([409, 1, 28, 28]) torch.Size([409, 10]) torch.Size([103, 1, 28, 28]) torch.Size([103, 10]) torch.Size([512, 1, 28, 28]) torch.Size([512, 10])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])
Model summary:
Input - Output: ((1, 28, 28))
Conv2D - (Input, Output): ((1, 28, 28), (32, 28, 28)) - Parameters: 320
    ReLU - Output: ((32, 28, 28))
Add - (Input, Output): ((32, 28, 28), (32, 28, 28))
    Conv2D - (Input, Output): ((32, 28, 28), (32, 28, 28)) - Parameters: 9248
        ReLU - Output: ((32, 28, 28))
    Identity - Output: ((32, 28, 28))
Add - (Input, Output): ((32, 28, 28), (32, 28, 28))
    Conv2D - (Input, Output): ((32, 28, 28), (32, 28, 28)) - Parameters: 9248
        ReLU - Output: ((32, 28, 28))
    Identity - Output: ((32, 28, 28))
MaxPooling2D - Output: ((32, 14, 14))
Batch normalisation - Output: ((32, 14, 14)) - Parameters: 12544
Add - (Input, Output): ((32, 14, 14), (32, 14, 14))
    Conv2D - (Input, Output): ((32, 14, 14), (32, 14, 14)) - Parameters: 9248
        ReLU - Output: ((32, 14, 14))
    Identity - Output: ((32, 14, 14))
Add - (Input, Output): ((32, 14, 14), (32, 14, 14))
    Conv2D - (Input, Output): ((32, 14, 14), (32, 14, 14)) - Parameters: 9248
        ReLU - Output: ((32, 14, 14))
    Identity - Output: ((32, 14, 14))
Add - (Input, Output): ((32, 14, 14), (32, 14, 14))
    Conv2D - (Input, Output): ((32, 14, 14), (32, 14, 14)) - Parameters: 9248
        ReLU - Output: ((32, 14, 14))
    Identity - Output: ((32, 14, 14))
MaxPooling2D - Output: ((32, 7, 7))
Batch normalisation - Output: ((32, 7, 7)) - Parameters: 3136
Add - (Input, Output): ((32, 7, 7), (32, 7, 7))
    Conv2D - (Input, Output): ((32, 7, 7), (32, 7, 7)) - Parameters: 9248
        ReLU - Output: ((32, 7, 7))
    Identity - Output: ((32, 7, 7))
Add - (Input, Output): ((32, 7, 7), (32, 7, 7))
    Conv2D - (Input, Output): ((32, 7, 7), (32, 7, 7)) - Parameters: 9248
        ReLU - Output: ((32, 7, 7))
    Identity - Output: ((32, 7, 7))
Add - (Input, Output): ((32, 7, 7), (32, 7, 7))
    Conv2D - (Input, Output): ((32, 7, 7), (32, 7, 7)) - Parameters: 9248
        ReLU - Output: ((32, 7, 7))
    Identity - Output: ((32, 7, 7))
Conv2D - (Input, Output): ((32, 7, 7), (8, 7, 7)) - Parameters: 2312
    ReLU - Output: ((8, 7, 7))
Batch normalisation - Output: ((8, 7, 7)) - Parameters: 784
Dropout - Output: ((8, 7, 7)) - Keep probability: 0.7
Flatten - (Input, Output): ((8, 7, 7), 392)
Dense - (Input, Output): (392, 200) - Parameters: 78600
    ReLU - Output: (200)
Dense - (Input, Output): (200, 10) - Parameters: 2010
    Softmax - Output: (10)
Total number of parameters: 173690
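
The per-layer counts in the summary can be checked by hand. The following is a minimal sketch, not DLL's own code: the three helper functions are hypothetical, and they assume that each Conv2D layer has one bias per output channel, each Dense layer has one bias per output unit, and batch normalisation keeps one scale and one shift value for every element of its input. Under those assumptions the numbers match the summary.

```python
# Hypothetical helpers that reproduce the parameter counts in the summary.
# Assumptions: biases as described in the text, and batch normalisation
# storing a scale and a shift per input element (2 * depth * height * width).

def conv2d_params(kernel_size, in_depth, out_depth):
    # One kernel_size x kernel_size filter per (input, output) channel pair,
    # plus one bias per output channel.
    return kernel_size * kernel_size * in_depth * out_depth + out_depth


def dense_params(in_features, out_features):
    # Weight matrix plus one bias per output unit.
    return in_features * out_features + out_features


def batch_norm_params(depth, height, width):
    # One scale and one shift value per input element.
    return 2 * depth * height * width


counts = [
    conv2d_params(3, 1, 32),          # first convolution: 320
    *[conv2d_params(3, 32, 32)] * 8,  # eight residual convolutions: 9248 each
    batch_norm_params(32, 14, 14),    # 12544
    batch_norm_params(32, 7, 7),      # 3136
    conv2d_params(3, 32, 8),          # 2312
    batch_norm_params(8, 7, 7),       # 784
    dense_params(8 * 7 * 7, 200),     # 78600
    dense_params(200, 10),            # 2010
]
total = sum(counts)
print(total)  # 173690
```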
Epoch: 1 - Metrics: {'loss': '9.1475', 'accuracy': '0.0978', 'val_loss': '9.0377', 'val_accuracy': '0.0971'}
Epoch: 2 - Metrics: {'loss': '7.4317', 'accuracy': '0.1027', 'val_loss': '7.1494', 'val_accuracy': '0.0971'}
Epoch: 3 - Metrics: {'loss': '5.0784', 'accuracy': '0.1198', 'val_loss': '4.8808', 'val_accuracy': '0.1262'}
Epoch: 4 - Metrics: {'loss': '3.5678', 'accuracy': '0.1247', 'val_loss': '3.4583', 'val_accuracy': '0.1553'}
Epoch: 5 - Metrics: {'loss': '2.4567', 'accuracy': '0.1785', 'val_loss': '2.4486', 'val_accuracy': '0.2330'}
Epoch: 6 - Metrics: {'loss': '1.6430', 'accuracy': '0.4205', 'val_loss': '1.7033', 'val_accuracy': '0.4466'}
Epoch: 7 - Metrics: {'loss': '1.1709', 'accuracy': '0.6137', 'val_loss': '1.2493', 'val_accuracy': '0.6311'}
Epoch: 8 - Metrics: {'loss': '0.8938', 'accuracy': '0.7359', 'val_loss': '0.9796', 'val_accuracy': '0.7184'}
Epoch: 9 - Metrics: {'loss': '0.6822', 'accuracy': '0.8435', 'val_loss': '0.7699', 'val_accuracy': '0.8058'}
Epoch: 10 - Metrics: {'loss': '0.5457', 'accuracy': '0.8851', 'val_loss': '0.6452', 'val_accuracy': '0.8447'}
Epoch: 11 - Metrics: {'loss': '0.4554', 'accuracy': '0.8973', 'val_loss': '0.5666', 'val_accuracy': '0.8835'}
Epoch: 12 - Metrics: {'loss': '0.3836', 'accuracy': '0.9193', 'val_loss': '0.5113', 'val_accuracy': '0.9029'}
Epoch: 13 - Metrics: {'loss': '0.3222', 'accuracy': '0.9413', 'val_loss': '0.4724', 'val_accuracy': '0.9029'}
Epoch: 14 - Metrics: {'loss': '0.2651', 'accuracy': '0.9511', 'val_loss': '0.4337', 'val_accuracy': '0.9126'}
Epoch: 15 - Metrics: {'loss': '0.2150', 'accuracy': '0.9633', 'val_loss': '0.4004', 'val_accuracy': '0.9126'}
Epoch: 16 - Metrics: {'loss': '0.1715', 'accuracy': '0.9780', 'val_loss': '0.3673', 'val_accuracy': '0.9126'}
Epoch: 17 - Metrics: {'loss': '0.1420', 'accuracy': '0.9853', 'val_loss': '0.3408', 'val_accuracy': '0.9126'}
Epoch: 18 - Metrics: {'loss': '0.1209', 'accuracy': '0.9878', 'val_loss': '0.3217', 'val_accuracy': '0.9126'}
Epoch: 19 - Metrics: {'loss': '0.1019', 'accuracy': '0.9878', 'val_loss': '0.3123', 'val_accuracy': '0.9126'}
Epoch: 20 - Metrics: {'loss': '0.0866', 'accuracy': '0.9878', 'val_loss': '0.3068', 'val_accuracy': '0.9126'}
Epoch: 21 - Metrics: {'loss': '0.0733', 'accuracy': '0.9927', 'val_loss': '0.3090', 'val_accuracy': '0.9223'}
Epoch: 22 - Metrics: {'loss': '0.0631', 'accuracy': '0.9927', 'val_loss': '0.3116', 'val_accuracy': '0.9223'}
Epoch: 23 - Metrics: {'loss': '0.0539', 'accuracy': '0.9951', 'val_loss': '0.3160', 'val_accuracy': '0.9223'}
0.86328125
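
Training was configured for 25 epochs but the log ends at epoch 23, which is consistent with the `EarlyStopping(patience=3)` callback. The sketch below illustrates a common patience rule using the validation losses from the log. It is an assumption that the callback monitors `val_loss` and counts epochs without a new best value; `stopping_epoch` is a hypothetical function, not part of DLL.

```python
# Validation losses for epochs 1-23, copied from the training log.
val_losses = [9.0377, 7.1494, 4.8808, 3.4583, 2.4486, 1.7033, 1.2493, 0.9796,
              0.7699, 0.6452, 0.5666, 0.5113, 0.4724, 0.4337, 0.4004, 0.3673,
              0.3408, 0.3217, 0.3123, 0.3068, 0.3090, 0.3116, 0.3160]


def stopping_epoch(losses, patience):
    """Return the 1-based epoch at which training stops, or None if it never does."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            # New best value: remember it and reset the counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


print(stopping_epoch(val_losses, patience=3))  # 23
```

The validation loss reaches its minimum (0.3068) at epoch 20 and fails to improve for the next three epochs, so this rule stops at epoch 23.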

import torch
from torchvision import datasets
import matplotlib.pyplot as plt

from DLL.DeepLearning.Model import Model
from DLL.DeepLearning.Layers import Dense, Conv2D, Flatten, MaxPooling2D, Reshape, Add, Identity
from DLL.DeepLearning.Layers.Regularisation import Dropout, BatchNorm, GroupNorm, InstanceNorm, LayerNorm
from DLL.DeepLearning.Layers.Activations import ReLU, SoftMax
from DLL.DeepLearning.Losses import CCE
from DLL.DeepLearning.Optimisers import SGD, ADAM
from DLL.DeepLearning.Callbacks import EarlyStopping
from DLL.DeepLearning.Initialisers import Xavier_Normal, Xavier_Uniform, Kaiming_Normal, Kaiming_Uniform
from DLL.Data.Preprocessing import OneHotEncoder, data_split
from DLL.Data.Metrics import accuracy


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

torch.manual_seed(0)
torch.cuda.manual_seed(0)

n = 512  # use 60000 for the full training set
train_set = datasets.MNIST(root="./data", train=True, download=True)
test_set = datasets.MNIST(root="./data", train=False, download=True)
train_images = train_set.data[:n]
train_labels = train_set.targets[:n]
test_images = test_set.data[:n]
test_labels = test_set.targets[:n]
# Add a channel dimension: (N, 28, 28) -> (N, 1, 28, 28).
train_images = train_images.unsqueeze(1).to(dtype=torch.float32, device=device)
test_images = test_images.unsqueeze(1).to(dtype=torch.float32, device=device)
# Scale pixel values to the range [0, 1].
train_images = train_images / train_images.max()
test_images = test_images / test_images.max()
train_labels = train_labels.to(dtype=torch.float32, device=device)
test_labels = test_labels.to(dtype=torch.float32, device=device)

train_images, train_labels, validation_images, validation_labels, _, _ = data_split(train_images, train_labels, train_split=0.8, validation_split=0.2)

label_encoder = OneHotEncoder()
train_labels = label_encoder.fit_encode(train_labels)
validation_labels = label_encoder.encode(validation_labels)
test_labels = label_encoder.encode(test_labels)
print(train_images.shape, train_labels.shape, validation_images.shape, validation_labels.shape, test_images.shape, test_labels.shape)
print(train_labels[:2])

model = Model((1, 28, 28), device=device)
model.train()

model.add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Normal(), activation=ReLU()))
model.add(Add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Normal(), activation=ReLU()), Identity()))
model.add(Add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Normal(), activation=ReLU()), Identity()))
model.add(MaxPooling2D(pool_size=2))
model.add(BatchNorm())
model.add(Add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Normal(), activation=ReLU()), Identity()))
model.add(Add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Normal(), activation=ReLU()), Identity()))
model.add(Add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Normal(), activation=ReLU()), Identity()))
model.add(MaxPooling2D(pool_size=2))
model.add(BatchNorm())
model.add(Add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Normal(), activation=ReLU()), Identity()))
model.add(Add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Normal(), activation=ReLU()), Identity()))
model.add(Add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Normal(), activation=ReLU()), Identity()))
model.add(Conv2D(kernel_size=3, output_depth=8, initialiser=Kaiming_Normal(), activation=ReLU()))
model.add(BatchNorm())
model.add(Dropout(p=0.3))
model.add(Flatten())
model.add(Dense(200, activation=ReLU()))
model.add(Dense(10, activation=SoftMax()))

model.compile(optimiser=ADAM(learning_rate=0.001), loss=CCE(), metrics=["loss", "val_loss", "val_accuracy", "accuracy"], callbacks=(EarlyStopping(patience=3),))
model.summary()


history = model.fit(train_images, train_labels, val_data=(validation_images, validation_labels), epochs=25, batch_size=256, verbose=True)

model.eval()
test_pred = model.predict(test_images)
print(accuracy(test_pred, test_labels))

plt.figure(figsize=(8, 6))
plt.subplot(1, 2, 1)
plt.plot(history["val_loss"], label="validation loss")
plt.plot(history["loss"], label="loss")
plt.xlabel("Epoch")
plt.ylabel("Categorical cross entropy")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history["val_accuracy"], label="validation accuracy")
plt.plot(history["accuracy"], label="accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

fig, ax = plt.subplots(2, 2, figsize=(8, 8))
ax = ax.ravel()
for i in range(len(ax)):
    # Move the tensor to the CPU first; .numpy() fails on CUDA tensors.
    ax[i].imshow(test_images[i].cpu().numpy()[0], cmap='gray', vmin=0, vmax=1)
    ax[i].set_title(f"True label: {test_labels[i].argmax()} | Predicted label: {test_pred[i].argmax()}")
plt.show()
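
The printed shapes (409 training and 103 validation samples out of 512) and the one-hot rows shown in the output can be reproduced with plain Python. This is only an illustrative sketch: `split_sizes` and `one_hot` are hypothetical helpers, not the DLL implementations of `data_split` and `OneHotEncoder`. It assumes that the training share is rounded down, and that the first two training labels are 9 and 8, as the printed rows suggest.

```python
def split_sizes(n, train_split):
    # Assumption: the training share is rounded down and the remainder
    # goes to the validation set (the script uses no separate test share).
    n_train = int(n * train_split)
    return n_train, n - n_train


def one_hot(labels, num_classes):
    # Each label becomes a row with a single 1 at the label's index.
    return [[1 if c == label else 0 for c in range(num_classes)]
            for label in labels]


print(split_sizes(512, 0.8))  # (409, 103)
print(one_hot([9, 8], 10))
```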

Total running time of the script: (0 minutes 43.705 seconds)

Gallery generated by Sphinx-Gallery