Training a computer vision model can take a long time, which slows down the PDCA cycle. One way to speed up training is to cache the dataset before training starts.
Every time the data loader fetches the next mini-batch, it reads the images from disk and applies the desired preprocessing. This becomes time-consuming when you have a lot of high-resolution data, so instead of repeating that work on every access, you can cache your data beforehand.
Here is how you can do it.
from skimage.io import imread
# To inherit the Dataset class
from torch.utils import data
# For the progress bar
from tqdm import tqdm


class CacheCustomDataSet(data.Dataset):
    def __init__(self,
                 inputs: list,
                 targets: list,
                 transform=None,
                 use_cache=False,  # Flag for whether to cache the data
                 pre_transform=None,
                 ):
        # Input images
        self.inputs = inputs
        # Annotation images
        self.targets = targets
        # Preprocessing applied on every access
        self.transform = transform
        # Whether to serve samples from the cache
        self.use_cache = use_cache
        # Preprocessing applied once, when building the cache
        self.pre_transform = pre_transform

        if self.use_cache:
            self.cached_data = []
            progressbar = tqdm(range(len(self.inputs)), desc='Caching')
            # Cache each input/annotation image pair
            for i, img_name, tar_name in zip(progressbar, self.inputs, self.targets):
                img, tar = imread(str(img_name)), imread(str(tar_name))
                # Apply the pre-transform to the data before caching it
                if self.pre_transform is not None:
                    img, tar = self.pre_transform(img, tar)
                # Keep the processed pair in memory
                self.cached_data.append((img, tar))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index: int):
        # Serve the sample from the cache when it is enabled
        if self.use_cache:
            x, y = self.cached_data[index]
        else:
            input_ID = self.inputs[index]
            target_ID = self.targets[index]
            x, y = imread(str(input_ID)), imread(str(target_ID))
        # Per-access transforms (e.g. augmentation) still run every time
        if self.transform is not None:
            x, y = self.transform(x, y)
        return x, y
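Note that transform and pre_transform are assumed to be callables that take an image/target pair and return the processed pair. As a rough sketch of what the pre_transforms and transforms_training used below could look like (the 256x256 resize, the flip augmentation, and the normalization are illustrative assumptions, not part of the original):

import numpy as np
from skimage.transform import resize

# Hypothetical pre-transform: runs once per image while the cache is built.
# Shrinking high-resolution images here is where most of the time is saved.
def pre_transforms(img, tar):
    img = resize(img, (256, 256), preserve_range=True)
    # order=0 (nearest neighbor) keeps annotation labels intact
    tar = resize(tar, (256, 256), order=0, preserve_range=True)
    return img, tar

# Hypothetical per-access transform: cheap and random, so it must stay
# outside the cache, otherwise every epoch would see identical augmentations
def transforms_training(img, tar):
    if np.random.rand() < 0.5:
        # .copy() avoids negative-stride views that torch cannot convert
        img, tar = np.fliplr(img).copy(), np.fliplr(tar).copy()
    return img.astype(np.float32) / 255.0, tar

The key design point is the split: anything deterministic and expensive goes into pre_transform so it runs once, while anything random stays in transform so it runs per access.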
Using the dataset class defined above, you can now train with cached data:
from torch.utils.data import DataLoader

dataset_train = CacheCustomDataSet(inputs=inputs_train,
                                   targets=targets_train,
                                   transform=transforms_training,
                                   use_cache=True,
                                   pre_transform=pre_transforms)

# DataLoader for training
dataloader_training = DataLoader(dataset=dataset_train,
                                 batch_size=2,
                                 shuffle=True)
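To see the effect, you can time one pass over the loader with and without caching. A minimal sketch (the helper below is illustrative; actual timings depend on your data and disk):

import time

def time_one_epoch(loader):
    start = time.time()
    for x, y in loader:
        pass  # a real loop would run the forward/backward pass here
    return time.time() - start

# Same data, but read from disk on every access
dataset_uncached = CacheCustomDataSet(inputs=inputs_train,
                                      targets=targets_train,
                                      transform=transforms_training,
                                      use_cache=False)
dataloader_uncached = DataLoader(dataset=dataset_uncached,
                                 batch_size=2,
                                 shuffle=True)

print(f'uncached epoch: {time_one_epoch(dataloader_uncached):.2f}s')
print(f'cached epoch:   {time_one_epoch(dataloader_training):.2f}s')

Keep in mind that with use_cache=False this class skips the pre_transform step entirely, and that the cache holds every preprocessed image in RAM, so this approach only pays off when the dataset fits in memory.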