Source code for DLL.Data._DataReader

import torch


class DataReader:
    """
    The data reader.

    Args:
        X (torch.Tensor of shape (n_samples, ...)): The input data.
        Y (torch.Tensor of shape (n_samples, ...)): The target labels or values.
        batch_size (int, optional): The batch size. If the batch size is larger than the number of samples, the number of samples is used as the batch size. Must be a positive integer. Defaults to 64.
        shuffle (bool, optional): Determines if the input data is shuffled in the beginning. Defaults to True.
        shuffle_every_epoch (bool, optional): Determines if the input data is shuffled every time all data points have been used. Defaults to False.
    """
    def __init__(self, X, Y, batch_size=64, shuffle=True, shuffle_every_epoch=False):
        if not isinstance(X, torch.Tensor):
            raise TypeError("X must be a torch.Tensor.")
        if not isinstance(Y, torch.Tensor):
            raise TypeError("Y must be a torch.Tensor.")
        if len(X) != len(Y):
            raise ValueError("X and Y must have the same number of samples.")
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")
        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean.")
        if not isinstance(shuffle_every_epoch, bool):
            raise TypeError("shuffle_every_epoch must be a boolean.")

        self.data_length = Y.size(0)
        self.shuffle = shuffle
        self.shuffle_every_epoch = shuffle_every_epoch
        if self.shuffle:
            # Apply the same random permutation to X and Y so each sample stays aligned with its target.
            self.perm = torch.randperm(self.data_length, device=X.device)
            self.X = X.index_select(0, self.perm)
            self.Y = Y.index_select(0, self.perm)
        else:
            self.X = X
            self.Y = Y
        # Cap the batch size at the number of available samples.
        self.batch_size = batch_size if len(Y) >= batch_size else len(Y)
    def get_data(self):
        """
        A generator iterating through the entire dataset in batches.

        Yields:
            tuple[torch.Tensor, torch.Tensor]: X_batch, y_batch
        """
        iteration = 0
        while iteration * self.batch_size < self.data_length:
            start = iteration * self.batch_size
            end = (iteration + 1) * self.batch_size
            yield self.X[start:end], self.Y[start:end]
            iteration += 1
        # Reshuffle for the next pass if requested; only meaningful when the data was shuffled initially.
        if self.shuffle_every_epoch and self.shuffle:
            self.perm = torch.randperm(self.data_length, device=self.X.device)
            self.X = self.X.index_select(0, self.perm)
            self.Y = self.Y.index_select(0, self.perm)
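
As an illustration of the constructor's shuffling step, here is a minimal sketch (the tensors are made up for the example) showing that applying the same torch.randperm permutation to X and Y via index_select keeps each sample paired with its target:

import torch

X = torch.arange(6, dtype=torch.float32).reshape(3, 2)  # three samples, two features each (illustrative)
Y = torch.tensor([0, 1, 2])                              # one target per sample (illustrative)

perm = torch.randperm(Y.size(0), device=X.device)
X_shuffled = X.index_select(0, perm)
Y_shuffled = Y.index_select(0, perm)

# Row i of X_shuffled still corresponds to Y_shuffled[i], because the same
# permutation indexes both tensors along dimension 0.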
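
A minimal usage sketch, assuming the DataReader class above is in scope; the tensor shapes, batch size, and epoch count are illustrative, not taken from the original source:

import torch

X = torch.randn(100, 10)         # 100 samples with 10 features each (illustrative)
Y = torch.randint(0, 2, (100,))  # binary targets (illustrative)

reader = DataReader(X, Y, batch_size=32, shuffle=True, shuffle_every_epoch=True)

for epoch in range(3):
    for X_batch, Y_batch in reader.get_data():
        # Each batch holds at most 32 samples; the last batch of a pass is
        # smaller here (100 % 32 == 4). With shuffle_every_epoch=True the
        # data is reshuffled after every full pass.
        pass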