Source code for DLL.DeepLearning.Layers.Regularisation._GroupNormalisation

import torch

from ._BaseRegularisation import BaseRegularisation


class GroupNorm(BaseRegularisation):
    """
    The group normalisation layer for neural networks. Computes the group norm of a batch along axis=1.

    Args:
        num_groups (int, optional): The number of groups used in the normalisation. Must be a positive integer. Defaults to 32. The number of channels must be evenly divisible by num_groups. If num_groups is set to 1, the layer is identical to layer normalisation, and if num_groups equals the number of channels, it is identical to instance normalisation.
    """
    def __init__(self, num_groups=32, **kwargs):
        if num_groups is not None and (not isinstance(num_groups, int) or num_groups <= 0):
            raise ValueError("num_groups must be a positive integer.")
        super().__init__(**kwargs)
        self.num_groups = num_groups
        self.epsilon = 1e-6
        self.name = "Group normalisation"

    def initialise_layer(self, **kwargs):
        """
        :meta private:
        """
        super().initialise_layer(**kwargs)
        if self.output_shape[0] % self.num_groups != 0:
            raise ValueError("The number of channels must be evenly divisible by num_groups.")
        self.gamma = torch.ones(self.output_shape, device=self.device, dtype=self.data_type)
        self.beta = torch.zeros(self.output_shape, device=self.device, dtype=self.data_type)
        self.nparams = 2 * self.output_shape[0]
    def forward(self, input, **kwargs):
        """
        Normalises the input to have zero mean and unit variance across self.num_groups groups along the channel dimension with the following equation:

        .. math::
            y = \\gamma\\frac{x - \\mathbb{E}[x]}{\\sqrt{\\text{var}(x) + \\epsilon}} + \\beta,

        where :math:`x` is the input, :math:`\\mathbb{E}[x]` is the expected value or the mean across each group, :math:`\\text{var}(x)` is the variance across each group, :math:`\\epsilon` is a small constant and :math:`\\gamma` and :math:`\\beta` are trainable parameters.

        Args:
            input (torch.Tensor of shape (batch_size, channels, ...)): The input to the layer. Must be a torch.Tensor of the specified shape given by layer.input_shape.
            training (bool, optional): The boolean flag deciding if the model is in training mode. Defaults to False.
        Returns:
            torch.Tensor: The output tensor after the normalisation with the same shape as the input.
        """
        if not isinstance(input, torch.Tensor):
            raise TypeError("input must be a torch.Tensor.")
        if input.shape[1:] != self.input_shape:
            raise ValueError(f"input is not the same shape as the specified input_shape ({input.shape[1:], self.input_shape}).")

        elements_per_group = input.shape[1] // self.num_groups
        batch_size = input.shape[0]
        self.input = input
        # Split the channel dimension into (num_groups, elements_per_group) and
        # normalise over the elements of each group.
        self.input_reshaped = self.input.view(batch_size, self.num_groups, elements_per_group, *input.shape[2:])
        mean = 1.0 / elements_per_group * self.input_reshaped.sum(2, keepdim=True)
        self.x_centered = self.input_reshaped - mean
        self.x_centered_squared = self.x_centered ** 2
        self.var = 1.0 / elements_per_group * self.x_centered_squared.sum(2, keepdim=True)  # biased variance
        self.inv_std = (self.var + self.epsilon) ** -0.5
        self.x_norm = self.x_centered * self.inv_std
        self.x_reshaped = self.x_norm.view(self.input.shape)
        self.output = self.x_reshaped * self.gamma + self.beta
        return self.output
    def backward(self, dCdy, **kwargs):
        """
        Calculates the gradient of the loss function with respect to the input of the layer. Also calculates the gradients of the loss function with respect to the model parameters.

        Args:
            dCdy (torch.Tensor of the same shape as returned from the forward method): The gradient given by the next layer.
        Returns:
            torch.Tensor of shape (batch_size, channels, ...): The new gradient after backpropagation through the layer.
        """
        if not isinstance(dCdy, torch.Tensor):
            raise TypeError("dCdy must be a torch.Tensor.")
        if dCdy.shape[1:] != self.output.shape[1:]:
            raise ValueError(f"dCdy is not the same shape as the specified output_shape ({dCdy.shape[1:], self.output.shape[1:]}).")

        batch_size = dCdy.shape[0]
        elements_per_group = self.output.shape[1] // self.num_groups

        # Gradients with respect to the trainable parameters.
        dCdx_reshaped = dCdy * self.gamma
        dCdgamma = (dCdy * self.x_reshaped).mean(axis=0)
        dCdbeta = dCdy.mean(axis=0)
        self.gamma.grad += dCdgamma
        self.beta.grad += dCdbeta

        # Backpropagate through the normalisation: redo the grouping view and
        # accumulate the contributions through the variance and the mean.
        dCdx_norm = dCdx_reshaped.view(batch_size, self.num_groups, elements_per_group, *self.output.shape[2:])
        dCdx_centered = dCdx_norm * self.inv_std
        dCdinv_std = (dCdx_norm * self.x_centered).sum(2, keepdim=True)
        dCdvar = -0.5 * ((self.var + self.epsilon) ** -1.5) * dCdinv_std
        dCdx_centered_squared = 1.0 / elements_per_group * torch.ones_like(self.x_centered_squared, device=self.device, dtype=self.data_type) * dCdvar
        dCdx_centered += 2 * self.x_centered * dCdx_centered_squared
        dCdinput_reshaped = dCdx_centered.clone()
        dCdmean = -(dCdx_centered).sum(2, keepdim=True)
        dCdinput_reshaped += 1.0 / elements_per_group * torch.ones_like(self.input_reshaped, device=self.device, dtype=self.data_type) * dCdmean
        dCdx = dCdinput_reshaped.view(self.output.shape)
        return dCdx
    def get_parameters(self):
        """
        :meta private:
        """
        return (self.gamma, self.beta)
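
For reference, a minimal standalone sketch (not part of the module) of the same reshape-and-normalise arithmetic used in GroupNorm.forward, written for an input with no spatial dimensions; the variable names (x_grouped, y, reference) and the shapes are illustrative assumptions, and the final check against torch.nn.functional.group_norm only applies to this purely channel-wise case.

import torch
import torch.nn.functional as F

batch_size, channels, num_groups, epsilon = 4, 8, 2, 1e-6
x = torch.randn(batch_size, channels)
gamma = torch.ones(channels)
beta = torch.zeros(channels)

# Split the channels into groups, then normalise over the elements of each
# group using the biased variance, mirroring the forward pass above.
x_grouped = x.view(batch_size, num_groups, channels // num_groups)
mean = x_grouped.mean(dim=2, keepdim=True)
var = ((x_grouped - mean) ** 2).mean(dim=2, keepdim=True)  # biased variance
x_norm = ((x_grouped - mean) / torch.sqrt(var + epsilon)).view(batch_size, channels)
y = x_norm * gamma + beta

# With no spatial dimensions this coincides with PyTorch's built-in group norm.
reference = F.group_norm(x, num_groups, weight=gamma, bias=beta, eps=epsilon)
print(torch.allclose(y, reference, atol=1e-5))  # True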