Source code for DLL.DeepLearning.Layers.Regularisation._InstanceNormalisation

from ._GroupNormalisation import GroupNorm


[docs] class InstanceNorm(GroupNorm): """ The instance normalisation layer for neural networks. Computes the group norm of a batch along axis=1 with the same number of groups as channels. """ def __init__(self, **kwargs): super().__init__(None, **kwargs) self.name = "Instance normalisation" def initialise_layer(self, **kwargs): """ :meta private: """ if "input_shape" not in kwargs or not isinstance(kwargs.get("input_shape"), tuple | list): raise ValueError(f"input_shape must be passed as a key word argument to layer.initialise_layer as a tuple or a list. Currently {kwargs.get('input_shape', None)}.") self.num_groups = kwargs.get("input_shape")[0] super().initialise_layer(**kwargs)
[docs] def forward(self, input, **kwargs): """ Normalises the input to have zero mean and one variance accross the channel dimension with the following equation: .. math:: y = \\gamma\\frac{x - \\mathbb{E}[x]}{\\sqrt{\\text{var}(x) + \\epsilon}} + \\beta, where :math:`x` is the input, :math:`\\mathbb{E}[x]` is the expected value or the mean accross the channel dimension, :math:`\\text{var}(x)` is the variance accross the variance accross the channel dimension, :math:`\\epsilon` is a small constant and :math:`\\gamma` and :math:`\\beta` are trainable parameters. Args: input (torch.Tensor of shape (batch_size, channels, ...)): The input to the layer. Must be a torch.Tensor of the spesified shape given by layer.input_shape. Returns: torch.Tensor: The output tensor after the normalisation with the same shape as the input. """ return super().forward(input, **kwargs)
[docs] def backward(self, dCdy, **kwargs): """ Calculates the gradient of the loss function with respect to the input of the layer. Also calculates the gradients of the loss function with respect to the model parameters. Args: dCdy (torch.Tensor of the same shape as returned from the forward method): The gradient given by the next layer. Returns: torch.Tensor of shape (batch_size, channels, ...): The new gradient after backpropagation through the layer. """ return super().backward(dCdy, **kwargs)