141. Caching Dataset For Faster Training

Training a computer vision model can take a long time, which slows down the PDCA (plan-do-check-act) cycle of experimentation. One way to speed up training is to cache the dataset before training starts.

When you feed your data through a data loader, every time it fetches the next mini-batch it reads the images from disk and applies the desired preprocessing. This becomes time-consuming when you have many high-resolution images. So instead of repeating this work on every access, you can cache your data beforehand.
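For contrast, a typical uncached dataset looks like the minimal sketch below; the class name and arguments are illustrative, but the pattern is standard: the disk read and the transform run inside __getitem__, so they are repeated for every sample in every epoch.

from skimage.io import imread
from torch.utils import data

class UncachedDataSet(data.Dataset):
    """Illustrative baseline: reads and preprocesses on every access."""
    def __init__(self, inputs, targets, transform=None):
        self.inputs = inputs
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        # Disk read + preprocessing happen here on every call,
        # i.e. once per sample per epoch
        x = imread(str(self.inputs[index]))
        y = imread(str(self.targets[index]))
        if self.transform is not None:
            x, y = self.transform(x, y)
        return x, y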

Here is how you can do it.


import torch
from skimage.io import imread

# To inherit the Dataset class
from torch.utils import data

# For the progress bar
from tqdm import tqdm


class CacheCustomDataSet(data.Dataset):
    def __init__(self,
                 inputs: list,
                 targets: list,
                 transform=None,
                 use_cache=False,  # Flag for whether to cache the data
                 pre_transform=None,
                 ):
        # Input images
        self.inputs = inputs
        # Annotation images
        self.targets = targets
        # Preprocessing applied on every access
        self.transform = transform
        # Whether to serve samples from the cache
        self.use_cache = use_cache
        # Preprocessing applied once before caching
        self.pre_transform = pre_transform

        if self.use_cache:
            self.cached_data = []
            progressbar = tqdm(range(len(self.inputs)), desc='Caching')
            # Zip with the tqdm range so the progress bar advances
            # once per cached input/annotation image pair
            for _, img_name, tar_name in zip(progressbar, self.inputs, self.targets):
                img, tar = imread(str(img_name)), imread(str(tar_name))
                # Apply preprocessing to the data before caching
                if self.pre_transform is not None:
                    img, tar = self.pre_transform(img, tar)
                # Cache the preprocessed pair
                self.cached_data.append((img, tar))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index: int):
        # Serve from the cache when enabled
        if self.use_cache:
            x, y = self.cached_data[index]
        else:
            input_ID = self.inputs[index]
            target_ID = self.targets[index]
            x, y = imread(str(input_ID)), imread(str(target_ID))

        if self.transform is not None:
            x, y = self.transform(x, y)

        return x, y
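Note the split between pre_transform and transform: deterministic, expensive steps such as resizing or normalization belong in pre_transform so they are computed once at caching time, while random augmentations such as flips or crops should stay in transform so each epoch still sees different variants. Caching also trades memory for speed, since the entire preprocessed dataset must fit in RAM.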

Using the dataset class defined above, you can now train with cached data:

from torch.utils.data import DataLoader

dataset_train = CacheCustomDataSet(inputs=inputs_train,
                                   targets=targets_train,
                                   transform=transforms_training,
                                   use_cache=True,
                                   pre_transform=pre_transforms)

# DataLoader for training
dataloader_training = DataLoader(dataset=dataset_train,
                                 batch_size=2,
                                 shuffle=True)
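To check the speedup, you can time a pass over the loader. A rough sketch, reusing the dataloader_training defined above and assuming pre_transforms produces uniformly sized arrays (otherwise the default collate function will fail to batch them): the disk reads were already paid in __init__, so every epoch iterates over in-memory data.

import time

start = time.perf_counter()
for x, y in dataloader_training:
    pass  # the forward/backward pass would go here
print(f"One epoch over cached data: {time.perf_counter() - start:.2f}s")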