I'm writing CIFAR-10 data augmentations using torchvision transforms (torchvision 0.15.2+cu117, torch 2.0.1+cu117):
import numpy as np
import torch
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as TF

strength = 0.2
color_jitter = transforms.ColorJitter(
brightness = 0.8 * strength, contrast = 0.8 * strength,
saturation = 0.8 * strength, hue = 0.2 * strength
)
rand_color_jitter = transforms.RandomApply([color_jitter], p = 0.8)
rand_gray = transforms.RandomGrayscale(p = 0.2)
color_distortion = transforms.Compose(
[
rand_color_jitter,
rand_gray
]
)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
train_data_augmentation = transforms.Compose(
[
transforms.RandomResizedCrop(size = 32, scale = (0.14, 1.0)),
transforms.GaussianBlur(kernel_size = (3, 3), sigma = (0.1, 2.0)),
color_distortion,
# v2.ToImage(),
# v2.ToDtype(torch.float32, scale = True),
# v2.ToDtype(torch.float32),
# v2.Normalize(),
transforms.ToTensor(),
transforms.Normalize(mean = mean, std = std)
]
)
I use these augmentations inside a custom Dataset class:
class Cifar10Dataset(torchvision.datasets.CIFAR10):
def __init__(
self, root = "~/data/cifar10",
train = True, download = True,
transform = None
):
super().__init__(
root = root, train = train,
download = download, transform = transform
)
def __getitem__(self, index):
image, label = self.data[index], self.targets[index]
if self.transform is not None:
image = self.transform(image)
# image = transformed["image"]
# Randomly select 0, 1, 2, or 3 for image rotation
ang = np.random.randint(low = 0, high = 4, size = None)
image = TF.rotate(img = image, angle = ang * 90)
return image, ang
train_dataset = Cifar10Dataset(
root = some_path, train = True,
download = True, transform = train_data_augmentation
)
train_loader = torch.utils.data.DataLoader(
dataset = train_dataset, batch_size = 128,
shuffle = True
)
x, y = next(iter(train_loader))
This throws the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[124], line 1
----> 1 x, y = next(iter(train_loader))
File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torch/utils/data/dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
630 if self._sampler_iter is None:
631 # TODO(https://github.com/pytorch/pytorch/issues/76750)
632 self._reset() # type: ignore[call-arg]
--> 633 data = self._next_data()
634 self._num_yielded += 1
635 if self._dataset_kind == _DatasetKind.Iterable and \
636 self._IterableDataset_len_called is not None and \
637 self._num_yielded > self._IterableDataset_len_called:
File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torch/utils/data/dataloader.py:677, in _SingleProcessDataLoaderIter._next_data(self)
675 def _next_data(self):
676 index = self._next_index() # may raise StopIteration
--> 677 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
678 if self._pin_memory:
679 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
---> 51 data = [self.dataset[idx] for idx in possibly_batched_index]
52 else:
53 data = self.dataset[possibly_batched_index]
File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51, in <listcomp>(.0)
49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
---> 51 data = [self.dataset[idx] for idx in possibly_batched_index]
52 else:
53 data = self.dataset[possibly_batched_index]
Cell In[121], line 16, in Cifar10Dataset.__getitem__(self, index)
13 image, label = self.data[index], self.targets[index]
15 if self.transform is not None:
---> 16 image = self.transform(image)
17 # image = transformed["image"]
18
19 # Randomly select 0, 1, 2 or 3 for image rotation-
20 ang = np.random.randint(low = 0, high = 4, size = None)
File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torchvision/transforms/transforms.py:95, in Compose.__call__(self, img)
93 def __call__(self, img):
94 for t in self.transforms:
---> 95 img = t(img)
96 return img
File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torchvision/transforms/transforms.py:979, in RandomResizedCrop.forward(self, img)
971 def forward(self, img):
972 """
973 Args:
974 img (PIL Image or Tensor): Image to be cropped and resized.
(...)
977 PIL Image or Tensor: Randomly cropped and resized image.
978 """
--> 979 i, j, h, w = self.get_params(img, self.scale, self.ratio)
980 return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)
File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torchvision/transforms/transforms.py:940, in RandomResizedCrop.get_params(img, scale, ratio)
927 @staticmethod
928 def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
929 """Get parameters for ``crop`` for a random sized crop.
930
931 Args:
(...)
938 sized crop.
939 """
--> 940 _, height, width = F.get_dimensions(img)
941 area = height * width
943 log_ratio = torch.log(torch.tensor(ratio))
File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torchvision/transforms/functional.py:78, in get_dimensions(img)
75 if isinstance(img, torch.Tensor):
76 return F_t.get_dimensions(img)
---> 78 return F_pil.get_dimensions(img)
File ~/anaconda3/envs/lightning_cuda/lib/python3.11/site-packages/torchvision/transforms/_functional_pil.py:31, in get_dimensions(img)
29 width, height = img.size
30 return [channels, height, width]
---> 31 raise TypeError(f"Unexpected type {type(img)}")
TypeError: Unexpected type <class 'numpy.ndarray'>
Shouldn't RandomResizedCrop() be able to convert an np.ndarray to a torch.Tensor? What am I missing?
No, torch.Tensor and numpy.ndarray are not fully interchangeable, even though they can be used as such in many cases. (As far as I know, this has something to do with the fact that torch needs to handle ownership across many devices.)
The documentation for RandomResizedCrop states that the only accepted input types are PIL.Image and torch.Tensor, so you need to convert your images to tensors first.
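For example, the pipeline from your question starts working if the ndarray is converted to a PIL image up front. A minimal sketch, keeping the classic v1 transforms (color_distortion, mean, and std as defined in your question):

train_data_augmentation = transforms.Compose(
    [
        transforms.ToPILImage(),  # [H,W,C] uint8 ndarray -> PIL.Image
        transforms.RandomResizedCrop(size = 32, scale = (0.14, 1.0)),
        transforms.GaussianBlur(kernel_size = (3, 3), sigma = (0.1, 2.0)),
        color_distortion,
        transforms.ToTensor(),  # PIL.Image -> [C,H,W] FloatTensor in [0, 1]
        transforms.Normalize(mean = mean, std = std)
    ]
)

ToPILImage accepts an [H,W,C] uint8 ndarray, and every subsequent transform in the chain supports PIL input.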
What you usually want to do, though, is run as many of the augmentations as possible on ByteTensors (uint8 tensors) and do the scaling and normalization at the end. Using the v2 transforms, you're probably looking for something like this (constructor arguments filled in with the values from your question):
from torchvision.transforms import v2

v2.Compose([
    v2.ToImageTensor(),  # [H,W,C] NDArray[uint8] -> [C,H,W] ByteTensor
    v2.RandomResizedCrop(size = 32, scale = (0.14, 1.0)),
    v2.GaussianBlur(kernel_size = (3, 3), sigma = (0.1, 2.0)),
    v2.ColorJitter(brightness = 0.8 * strength, contrast = 0.8 * strength,
                   saturation = 0.8 * strength, hue = 0.2 * strength),
    v2.ConvertDtype(torch.float32),  # ByteTensor (0, 255) -> FloatTensor (0, 1)
    v2.Normalize(mean = mean, std = std),
])
or, for torchvision versions >= 0.16:
v2.Compose([
    v2.ToImage(),  # [H,W,C] NDArray[uint8] -> [C,H,W] ByteTensor
    v2.RandomResizedCrop(size = 32, scale = (0.14, 1.0)),
    v2.GaussianBlur(kernel_size = (3, 3), sigma = (0.1, 2.0)),
    v2.ColorJitter(brightness = 0.8 * strength, contrast = 0.8 * strength,
                   saturation = 0.8 * strength, hue = 0.2 * strength),
    v2.ToDtype(torch.float32, scale = True),  # ByteTensor (0, 255) -> FloatTensor (0, 1)
    v2.Normalize(mean = mean, std = std),
])
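With either pipeline bound to train_data_augmentation, the rest of the code from your question should run as posted, since TF.rotate also accepts tensors. A quick smoke test, a sketch reusing the Cifar10Dataset class and loader settings from your question:

train_dataset = Cifar10Dataset(
    root = "~/data/cifar10", train = True,
    download = True, transform = train_data_augmentation
)
train_loader = torch.utils.data.DataLoader(
    dataset = train_dataset, batch_size = 128, shuffle = True
)
x, y = next(iter(train_loader))
print(x.shape, x.dtype)  # torch.Size([128, 3, 32, 32]) torch.float32
print(y.shape, y.dtype)  # torch.Size([128]) torch.int64 -- rotation labels 0..3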