# Source code for packnet_sfm.datasets.transforms
# Copyright 2020 Toyota Research Institute. All rights reserved.
from functools import partial
from packnet_sfm.datasets.augmentations import resize_image, resize_sample, \
duplicate_sample, colorjitter_sample, to_tensor_sample
########################################################################################################################
def train_transforms(sample, image_shape, jittering):
    """
    Training data augmentation transformations

    Parameters
    ----------
    sample : dict
        Sample to be augmented
    image_shape : tuple (height, width) or None
        Image dimension to reshape. An empty tuple or None skips resizing.
    jittering : tuple (brightness, contrast, saturation, hue) or None
        Color jittering parameters. An empty tuple or None skips jittering.

    Returns
    -------
    sample : dict
        Augmented sample
    """
    # Resize only when a target shape was requested. The explicit None check
    # lets callers disable resizing with None as well as with an empty tuple
    # (len(None) would raise TypeError).
    if image_shape is not None and len(image_shape) > 0:
        sample = resize_sample(sample, image_shape)
    # duplicate_sample always runs, and it runs before color jittering.
    # Presumably this preserves an un-jittered copy of the images.
    # NOTE(review): confirm against augmentations.duplicate_sample.
    sample = duplicate_sample(sample)
    if jittering is not None and len(jittering) > 0:
        sample = colorjitter_sample(sample, jittering)
    sample = to_tensor_sample(sample)
    return sample
def validation_transforms(sample, image_shape):
    """
    Validation data augmentation transformations

    Only ``sample['rgb']`` is resized by this function. The input dict is
    modified in place and then converted to tensors.

    Parameters
    ----------
    sample : dict
        Sample to be augmented (must contain an 'rgb' entry when resizing)
    image_shape : tuple (height, width) or None
        Image dimension to reshape. An empty tuple or None skips resizing.

    Returns
    -------
    sample : dict
        Augmented sample
    """
    # The explicit None check lets callers disable resizing with None as well
    # as with an empty tuple (len(None) would raise TypeError).
    if image_shape is not None and len(image_shape) > 0:
        # Unlike train_transforms, this uses resize_image on the 'rgb' entry
        # only, not resize_sample, so no other entries are resized here.
        sample['rgb'] = resize_image(sample['rgb'], image_shape)
    sample = to_tensor_sample(sample)
    return sample
def test_transforms(sample, image_shape):
    """
    Test data augmentation transformations

    Only ``sample['rgb']`` is resized by this function. The input dict is
    modified in place and then converted to tensors.

    Parameters
    ----------
    sample : dict
        Sample to be augmented (must contain an 'rgb' entry when resizing)
    image_shape : tuple (height, width) or None
        Image dimension to reshape. An empty tuple or None skips resizing.

    Returns
    -------
    sample : dict
        Augmented sample
    """
    # The explicit None check lets callers disable resizing with None as well
    # as with an empty tuple (len(None) would raise TypeError).
    if image_shape is not None and len(image_shape) > 0:
        # Only the 'rgb' entry is resized (resize_image, not resize_sample).
        sample['rgb'] = resize_image(sample['rgb'], image_shape)
    sample = to_tensor_sample(sample)
    return sample
def get_transforms(mode, image_shape, jittering, **kwargs):
    """
    Build the data augmentation transformation for a dataset split

    Parameters
    ----------
    mode : str {'train', 'validation', 'test'}
        Split whose transformation is requested
    image_shape : tuple (height, width)
        Image dimension to reshape
    jittering : tuple (brightness, contrast, saturation, hue)
        Color jittering parameters (only used in 'train' mode)
    kwargs : dict
        Extra keyword arguments; accepted and ignored

    Returns
    -------
    transform : functools.partial
        Callable that takes a single ``sample`` dict and returns the
        transformed sample

    Raises
    ------
    ValueError
        If ``mode`` is not 'train', 'validation' or 'test'
    """
    # Reject unknown splits before touching any transformation function.
    if mode not in ('train', 'validation', 'test'):
        raise ValueError('Unknown mode {}'.format(mode))

    if mode == 'train':
        return partial(train_transforms,
                       image_shape=image_shape, jittering=jittering)

    # Validation and test transformations only resize, so they take no
    # jittering argument.
    split_fn = validation_transforms if mode == 'validation' else test_transforms
    return partial(split_fn, image_shape=image_shape)
########################################################################################################################