Source code for DLL.DeepLearning.Layers.Regularisation._BatchNormalisation

import torch
import numpy as np

from ._BaseRegularisation import BaseRegularisation


class BatchNorm(BaseRegularisation):
    """
    The batch normalisation layer for neural networks.

    Args:
        patience (float, optional): The weight given to the previous running mean and variance when they are updated during training. Values closer to 1 make the running statistics change more slowly. Must be strictly between 0 and 1. Defaults to 0.9.

    Raises:
        ValueError: If patience is not strictly between 0 and 1.
    """
    def __init__(self, patience=0.9, **kwargs):
        if patience <= 0 or 1 <= patience:
            raise ValueError("patience must be strictly between 0 and 1.")

        super().__init__(**kwargs)
        self.patience = patience
        # Small constant added to the variance before the square root to avoid division by zero.
        self.epsilon = 1e-6
        self.name = "Batch normalisation"

    def initialise_layer(self, **kwargs):
        """
        :meta private:
        """
        super().initialise_layer(**kwargs)

        # Scale and shift parameters; one value per element of the output shape.
        self.gamma = torch.ones(self.output_shape, dtype=self.data_type, device=self.device)
        self.beta = torch.zeros(self.output_shape, dtype=self.data_type, device=self.device)
        # Running statistics used in place of the batch statistics when not training.
        self.running_var = torch.ones(self.output_shape, dtype=self.data_type, device=self.device)
        self.running_mean = torch.zeros(self.output_shape, dtype=self.data_type, device=self.device)
        # Only gamma and beta are counted as parameters.
        self.nparams = 2 * np.prod(self.output_shape)

    def forward(self, input, training=False, **kwargs):
        """
        Normalises the input to have zero mean and one variance with the following equation:

        .. math::
            y = \\gamma\\frac{x - \\mathbb{E}[x]}{\\sqrt{\\text{var}(x) + \\epsilon}} + \\beta,

        where :math:`x` is the input, :math:`\\mathbb{E}[x]` is the expected value or the mean across the batch dimension, :math:`\\text{var}(x)` is the variance across the batch dimension, :math:`\\epsilon` is a small constant and :math:`\\gamma` and :math:`\\beta` are trainable parameters.

        In training mode the mean and the unbiased variance of the current batch are used and the running statistics are updated. Otherwise the stored running mean and variance are used, which is why a single sample is accepted outside of training.

        Args:
            input (torch.Tensor of shape (batch_size, channels, ...)): The input to the layer. Must be a torch.Tensor of the specified shape given by layer.input_shape.
            training (bool, optional): The boolean flag deciding if the model is in training mode. Defaults to False.

        Returns:
            torch.Tensor: The output tensor after the normalisation with the same shape as the input.

        Raises:
            TypeError: If input is not a torch.Tensor or training is not a boolean.
            ValueError: If the non-batch dimensions of input differ from layer.input_shape or if training is True and the batch contains fewer than 2 samples.
        """
        if not isinstance(input, torch.Tensor):
            raise TypeError("input must be a torch.Tensor.")
        if input.shape[1:] != self.input_shape:
            raise ValueError(f"input is not the same shape as the spesified input_shape ({input.shape[1:], self.input_shape}).")
        if not isinstance(training, bool):
            raise TypeError("training must be a boolean.")
        # The unbiased batch variance needs at least two samples, but it is only
        # computed in training mode. Inference relies on the running statistics,
        # so a single sample is valid there.
        if training and input.shape[0] <= 1:
            raise ValueError("The batch size must be atleast 2.")

        self.input = input
        if training:
            mean = torch.mean(input, dim=0)
            variance = torch.var(input, dim=0, unbiased=True)
            # The intermediate values below are cached for the backward pass.
            self.std = torch.sqrt(variance + self.epsilon)
            # Exponential moving averages of the batch statistics.
            self.running_mean = self.patience * self.running_mean + (1 - self.patience) * mean
            self.running_var = self.patience * self.running_var + (1 - self.patience) * variance
            self.x_centered = self.input - mean
            self.x_norm = self.x_centered / self.std
            self.output = self.gamma * self.x_norm + self.beta
        else:
            running_std = torch.sqrt(self.running_var + self.epsilon)
            self.output = self.gamma * ((self.input - self.running_mean) / running_std) + self.beta
        return self.output

    def backward(self, dCdy, **kwargs):
        """
        Calculates the gradient of the loss function with respect to the input of the layer. Also calculates the gradients of the loss function with respect to the model parameters.

        Uses the values cached by the most recent call to the forward method with training=True. The parameter gradients are averaged over the batch dimension and accumulated into gamma.grad and beta.grad; this method assumes those grad tensors have already been initialised elsewhere.

        Args:
            dCdy (torch.Tensor of the same shape as returned from the forward method): The gradient given by the next layer.

        Returns:
            torch.Tensor of shape (batch_size, channels, ...): The new gradient after backpropagation through the layer.

        Raises:
            TypeError: If dCdy is not a torch.Tensor.
            ValueError: If the non-batch dimensions of dCdy differ from those of the output of the forward method.
        """
        if not isinstance(dCdy, torch.Tensor):
            raise TypeError("dCdy must be a torch.Tensor.")
        if dCdy.shape[1:] != self.output.shape[1:]:
            raise ValueError(f"dCdy is not the same shape as the spesified output_shape ({dCdy.shape[1:], self.output.shape[1:]}).")

        batch_size = self.output.shape[0]

        dCdx_norm = dCdy * self.gamma
        dCdgamma = (dCdy * self.x_norm).mean(dim=0)
        dCdbeta = dCdy.mean(dim=0)

        # x_norm = x_centered * (var + eps)^(-1/2), hence d x_norm / d var = -x_centered * std^(-3) / 2.
        dCdvar = (dCdx_norm * self.x_centered * -self.std ** (-3) / 2).sum(dim=0)
        # The forward pass uses the unbiased variance (divisor batch_size - 1), so the
        # derivative of the variance with respect to the mean carries the same divisor.
        # The centred values sum to (numerically almost) zero, so this term is tiny.
        dCdmean = -((dCdx_norm / self.std).sum(dim=0) + dCdvar * (2 / (batch_size - 1)) * self.x_centered.sum(dim=0))
        # Sum of the three paths from x to the loss: directly through x_norm, through the variance and through the mean.
        dCdx = dCdx_norm / self.std + dCdvar * 2 * self.x_centered / (batch_size - 1) + dCdmean / batch_size

        self.gamma.grad += dCdgamma
        self.beta.grad += dCdbeta
        return dCdx

    def get_parameters(self):
        """
        :meta private:
        """
        return (self.gamma, self.beta)