For the sake of readability and ease of use, the best approach to applying transforms to Torchvision datasets is to pass all transforms to the transform parameter of the dataset's constructor when loading the dataset.
For the CIFAR10 dataset, that would look like this.
cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
However, this will not work yet, because we have not imported torch and torchvision, nor have we defined train_transform, the single object being passed to the transform parameter.
The torchvision.transforms class that lets us create this object is transforms.Compose.
In order to use transforms.Compose, we first import torch, torchvision, torchvision.datasets as datasets, and torchvision.transforms as transforms.
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
We can also print the installed versions of torch and torchvision to confirm that they are current.
print(torch.__version__)
print(torchvision.__version__)
For use in this example, I will redefine the normalize transform.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
Now, I define train_transform as a transforms.Compose over a list of the desired transforms: a RandomCrop, followed by the ToTensor transform, followed by our custom normalize transform.
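train_transform = transforms.Compose(
    [transforms.RandomCrop(32, padding=4),  # pad each side by 4 pixels, then take a random 32x32 crop
     transforms.ToTensor(),                 # PIL image -> float tensor with values in [0, 1]
     normalize])                            # (x - 0.5) / 0.5 per channel -> values in [-1, 1]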
transforms.Compose takes a list of transform objects as an argument and returns a single object that applies all of the listed transforms in order, feeding the output of each one into the next.
In this case, train_transform pads each image by 4 pixels on every side and takes a random 32x32 crop, converts the result to a tensor, and then normalizes it.
As I mentioned, the transforms are applied in order.
This is important because transforms that operate on PIL images must be listed before the ToTensor transform, while transforms that only accept tensors, e.g., Normalize, must be listed after it. Here, RandomCrop is placed before ToTensor so that it crops the PIL image (recent torchvision releases also accept tensors for RandomCrop).
Once the transforms have been composed into a single transform object, we can pass that object to the transform parameter of the dataset constructor, as shown earlier.
cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
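Now, every image will be transformed in the desired way each time it is retrieved from the dataset, so RandomCrop produces a new random crop every time a sample is accessed.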