MNIST Image classification
This script implements a model to classify the MNIST dataset. The model consists mainly of convolutional and pooling layers, with a few dense layers at the end. As the script is only for demonstration purposes, only the first 512 data points are used to keep training fast. For a full example, change the parameter n to 60000. If n is increased, more epochs may be needed and other hyperparameters may have to be tuned.
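For instance, training on the full dataset only requires changing the size parameter near the top of the script shown further below:

n = 60000  # full MNIST training set; training will take considerably longer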
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
torch.Size([409, 1, 28, 28]) torch.Size([409, 10]) torch.Size([103, 1, 28, 28]) torch.Size([103, 10]) torch.Size([512, 1, 28, 28]) torch.Size([512, 10])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])
Model summary:
Input - Output: ((1, 28, 28))
Conv2D - (Input, Output): ((1, 28, 28), (32, 26, 26)) - Parameters: 320
ReLU - Output: ((32, 26, 26))
MaxPooling2D - Output: ((32, 13, 13))
Layer normalisation - Output: ((32, 13, 13)) - Parameters: 64
Conv2D - (Input, Output): ((32, 13, 13), (32, 11, 11)) - Parameters: 9248
ReLU - Output: ((32, 11, 11))
MaxPooling2D - Output: ((32, 5, 5))
Group normalisation - Output: ((32, 5, 5)) - Parameters: 64
Dropout - Output: ((32, 5, 5)) - Keep probability: 0.8
Flatten - (Input, Output): ((32, 5, 5), 800)
Dense - (Input, Output): (800, 200) - Parameters: 160200
ReLU - Output: (200)
Dense - (Input, Output): (200, 10) - Parameters: 2010
Softmax - Output: (10)
Total number of parameters: 171906
Epoch: 1 - Metrics: {'loss': '1.8260', 'accuracy': '0.4841', 'val_loss': '2.0189', 'val_accuracy': '0.2718'}
Epoch: 2 - Metrics: {'loss': '1.5050', 'accuracy': '0.6284', 'val_loss': '1.6021', 'val_accuracy': '0.5437'}
Epoch: 3 - Metrics: {'loss': '1.1816', 'accuracy': '0.7311', 'val_loss': '1.2676', 'val_accuracy': '0.6699'}
Epoch: 4 - Metrics: {'loss': '0.9141', 'accuracy': '0.8191', 'val_loss': '0.9703', 'val_accuracy': '0.7573'}
Epoch: 5 - Metrics: {'loss': '0.7508', 'accuracy': '0.8289', 'val_loss': '0.7857', 'val_accuracy': '0.8350'}
Epoch: 6 - Metrics: {'loss': '0.5937', 'accuracy': '0.8557', 'val_loss': '0.6401', 'val_accuracy': '0.8932'}
Epoch: 7 - Metrics: {'loss': '0.5111', 'accuracy': '0.8753', 'val_loss': '0.5740', 'val_accuracy': '0.8544'}
Epoch: 8 - Metrics: {'loss': '0.4598', 'accuracy': '0.8778', 'val_loss': '0.5058', 'val_accuracy': '0.8835'}
Epoch: 9 - Metrics: {'loss': '0.4097', 'accuracy': '0.8924', 'val_loss': '0.4158', 'val_accuracy': '0.8835'}
Epoch: 10 - Metrics: {'loss': '0.3876', 'accuracy': '0.8802', 'val_loss': '0.4120', 'val_accuracy': '0.9029'}
Epoch: 11 - Metrics: {'loss': '0.3642', 'accuracy': '0.8875', 'val_loss': '0.3654', 'val_accuracy': '0.8835'}
Epoch: 12 - Metrics: {'loss': '0.3309', 'accuracy': '0.9022', 'val_loss': '0.3424', 'val_accuracy': '0.8932'}
Epoch: 13 - Metrics: {'loss': '0.2965', 'accuracy': '0.9193', 'val_loss': '0.3313', 'val_accuracy': '0.8932'}
Epoch: 14 - Metrics: {'loss': '0.2704', 'accuracy': '0.9242', 'val_loss': '0.3272', 'val_accuracy': '0.9029'}
Epoch: 15 - Metrics: {'loss': '0.2581', 'accuracy': '0.9364', 'val_loss': '0.3633', 'val_accuracy': '0.9029'}
Epoch: 16 - Metrics: {'loss': '0.2316', 'accuracy': '0.9389', 'val_loss': '0.3745', 'val_accuracy': '0.8932'}
Epoch: 17 - Metrics: {'loss': '0.2222', 'accuracy': '0.9364', 'val_loss': '0.3456', 'val_accuracy': '0.9029'}
0.787109375
import torch
import matplotlib.pyplot as plt
import tensorflow as tf
from DLL.DeepLearning.Model import Model
from DLL.DeepLearning.Layers import Dense, Conv2D, Flatten, MaxPooling2D, Reshape
from DLL.DeepLearning.Layers.Regularisation import Dropout, BatchNorm, GroupNorm, InstanceNorm, LayerNorm
from DLL.DeepLearning.Layers.Activations import ReLU, SoftMax
from DLL.DeepLearning.Losses import CCE
from DLL.DeepLearning.Optimisers import SGD, ADAM
from DLL.DeepLearning.Callbacks import EarlyStopping
from DLL.DeepLearning.Initialisers import Xavier_Normal, Xavier_Uniform, Kaiming_Normal, Kaiming_Uniform
from DLL.Data.Preprocessing import OneHotEncoder, data_split
from DLL.Data.Metrics import accuracy
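# Use a GPU if one is available and fix the random seeds for reproducibility.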
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.manual_seed(0)
torch.cuda.manual_seed(0)
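# Only the first n images are used to keep the demonstration fast; set n = 60000 for the full dataset.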
n = 512 # 60000
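# Load MNIST through Keras, move it to the chosen device, add a channel dimension and scale the pixel values to [0, 1].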
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images = torch.from_numpy(train_images).to(dtype=torch.float32, device=device).reshape(60000, 1, 28, 28)[:n]
train_labels = torch.from_numpy(train_labels).to(dtype=torch.float32, device=device)[:n]
test_images = torch.from_numpy(test_images).to(dtype=torch.float32, device=device).reshape(10000, 1, 28, 28)[:n]
test_labels = torch.from_numpy(test_labels).to(dtype=torch.float32, device=device)[:n]
train_images = train_images / train_images.max()
test_images = test_images / test_images.max()
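# Hold out 20 % of the training images for validation and one-hot encode the labels.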
train_images, train_labels, validation_images, validation_labels, _, _ = data_split(train_images, train_labels, train_split=0.8, validation_split=0.2)
label_encoder = OneHotEncoder()
train_labels = label_encoder.fit_encode(train_labels)
validation_labels = label_encoder.encode(validation_labels)
test_labels = label_encoder.encode(test_labels)
print(train_images.shape, train_labels.shape, validation_images.shape, validation_labels.shape, test_images.shape, test_labels.shape)
print(train_labels[:2])
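# Build the network: two convolution + max-pooling blocks with normalisation, followed by dropout, flattening and two dense layers.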
model = Model((1, 28, 28), device=device)
model.add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Normal(), activation=ReLU()))
model.add(MaxPooling2D(pool_size=2))
model.add(LayerNorm())
# model.add(BatchNorm())
model.add(Conv2D(kernel_size=3, output_depth=32, initialiser=Kaiming_Uniform(), activation=ReLU()))
model.add(MaxPooling2D(pool_size=2))
# model.add(InstanceNorm())
# model.add(LayerNorm())
model.add(GroupNorm(num_groups=16))
model.add(Dropout(p=0.2))
model.add(Flatten())
# model.add(Reshape(800))
model.add(Dense(200, activation=ReLU()))
model.add(Dense(10, activation=SoftMax()))
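# Train with the Adam optimiser and categorical cross-entropy; early stopping ends training once the validation loss has stopped improving for three epochs.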
model.compile(optimiser=ADAM(learning_rate=0.001), loss=CCE(), metrics=["loss", "val_loss", "val_accuracy", "accuracy"], callbacks=(EarlyStopping(patience=3),))
model.summary()
history = model.fit(train_images, train_labels, val_data=(validation_images, validation_labels), epochs=25, batch_size=256, verbose=True)
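# Plot the training and validation loss and accuracy curves.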
plt.figure(figsize=(8, 6))
plt.subplot(1, 2, 1)
plt.plot(history["val_loss"], label="validation loss")
plt.plot(history["loss"], label="loss")
plt.xlabel("Epoch")
plt.ylabel("Categorical cross entropy")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history["val_accuracy"], label="validation accuracy")
plt.plot(history["accuracy"], label="accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
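# Evaluate classification accuracy on the held-out test images.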
test_pred = model.predict(test_images)
print(accuracy(test_pred, test_labels))
plt.show()
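# Show the first four test images together with their true and predicted labels.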
fig, ax = plt.subplots(2, 2, figsize=(8, 8))
ax = ax.ravel()
for i in range(len(ax)):
ax[i].imshow(test_images[i].cpu().numpy()[0], cmap='gray', vmin=0, vmax=1)
ax[i].set_title(f"True label: {test_labels[i].argmax().item()} | Predicted label: {test_pred[i].argmax().item()}")
plt.show()
Total running time of the script: (1 minutes 19.348 seconds)