# Copyright 2020 Toyota Research Institute. All rights reserved.
import torch
import torch.nn.functional as funct
from functools import lru_cache
from PIL import Image
from packnet_sfm.utils.misc import same_shape
########################################################################################################################
def load_image(path):
    """
    Read an image using PIL

    Parameters
    ----------
    path : str
        Path to the image

    Returns
    -------
    image : PIL.Image
        Loaded image
    """
    return Image.open(path)
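
# Illustrative usage sketch (not part of the original module); the path below
# is a placeholder for any image file on disk.
def _example_load_image():
    image = load_image('/path/to/sample.png')  # hypothetical path
    width, height = image.size                 # PIL reports (width, height)
    return image, (height, width)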
########################################################################################################################
def flip_lr(image):
    """
    Flip image horizontally

    Parameters
    ----------
    image : torch.Tensor [B,3,H,W]
        Image to be flipped

    Returns
    -------
    image_flipped : torch.Tensor [B,3,H,W]
        Flipped image
    """
    assert image.dim() == 4, 'You need to provide a [B,C,H,W] image to flip'
    return torch.flip(image, [3])
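
# Illustrative sketch (not part of the original module): flipping along the
# width dimension is an involution, so applying it twice recovers the input.
def _example_flip_lr():
    x = torch.rand(2, 3, 4, 6)  # random [B,C,H,W] batch
    assert torch.allclose(flip_lr(flip_lr(x)), x)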

def flip_model(model, image, flip):
    """
    Flip input image and flip output inverse depth map

    Parameters
    ----------
    model : nn.Module
        Module to be used
    image : torch.Tensor [B,3,H,W]
        Input image
    flip : bool
        True if the image should be flipped before inference
        (predictions are flipped back to the original orientation)

    Returns
    -------
    inv_depths : list of torch.Tensor [B,1,H,W]
        List of predicted inverse depth maps
    """
    if flip:
        return [flip_lr(inv_depth) for inv_depth in model(flip_lr(image))]
    else:
        return model(image)
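
# Illustrative sketch (not part of the original module): a stand-in "model"
# that returns a single-scale inverse depth list, used to show that the flip
# is applied on the way in and undone on the way out.
def _example_flip_model():
    def dummy_model(img):
        return [img.mean(dim=1, keepdim=True)]  # single-scale [B,1,H,W] prediction
    image = torch.rand(1, 3, 8, 16)
    inv_depths = flip_model(dummy_model, image, flip=True)
    assert inv_depths[0].shape == (1, 1, 8, 16)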
########################################################################################################################
def gradient_x(image):
    """
    Calculates the gradient of an image in the x dimension

    Parameters
    ----------
    image : torch.Tensor [B,3,H,W]
        Input image

    Returns
    -------
    gradient_x : torch.Tensor [B,3,H,W-1]
        Gradient of image with respect to x
    """
    return image[:, :, :, :-1] - image[:, :, :, 1:]

def gradient_y(image):
    """
    Calculates the gradient of an image in the y dimension

    Parameters
    ----------
    image : torch.Tensor [B,3,H,W]
        Input image

    Returns
    -------
    gradient_y : torch.Tensor [B,3,H-1,W]
        Gradient of image with respect to y
    """
    return image[:, :, :-1, :] - image[:, :, 1:, :]
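
# Illustrative sketch (not part of the original module): the finite-difference
# gradients shrink the corresponding spatial dimension by one.
def _example_gradients():
    x = torch.rand(2, 3, 5, 7)
    assert gradient_x(x).shape == (2, 3, 5, 6)  # W -> W-1
    assert gradient_y(x).shape == (2, 3, 4, 7)  # H -> H-1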
########################################################################################################################
def interpolate_image(image, shape, mode='bilinear', align_corners=True):
    """
    Interpolate an image to a different resolution

    Parameters
    ----------
    image : torch.Tensor [B,?,h,w]
        Image to be interpolated
    shape : tuple (H, W)
        Output shape
    mode : str
        Interpolation mode
    align_corners : bool
        True if corners will be aligned after interpolation

    Returns
    -------
    image : torch.Tensor [B,?,H,W]
        Interpolated image
    """
    # Take last two dimensions as shape
    if len(shape) > 2:
        shape = shape[-2:]
    # If the shapes are the same, do nothing
    if same_shape(image.shape[-2:], shape):
        return image
    else:
        # Interpolate image to match the shape
        return funct.interpolate(image, size=shape, mode=mode,
                                 align_corners=align_corners)
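
# Illustrative sketch (not part of the original module): upsampling a
# low-resolution tensor to a full-resolution (H, W) shape.
def _example_interpolate_image():
    low = torch.rand(1, 1, 48, 80)
    full = interpolate_image(low, (192, 320))
    assert full.shape == (1, 1, 192, 320)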

def interpolate_scales(images, shape=None, mode='bilinear', align_corners=False):
    """
    Interpolate list of images to the same shape

    Parameters
    ----------
    images : list of torch.Tensor [B,?,?,?]
        Images to be interpolated, with different resolutions
    shape : tuple (H, W)
        Output shape
    mode : str
        Interpolation mode
    align_corners : bool
        True if corners will be aligned after interpolation

    Returns
    -------
    images : list of torch.Tensor [B,?,H,W]
        Interpolated images, with the same resolution
    """
    # If no shape is provided, interpolate to highest resolution
    if shape is None:
        shape = images[0].shape
    # Take last two dimensions as shape
    if len(shape) > 2:
        shape = shape[-2:]
    # Interpolate all images
    return [funct.interpolate(image, shape, mode=mode,
                              align_corners=align_corners) for image in images]
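
# Illustrative sketch (not part of the original module): a multi-scale pyramid
# is resized so every level matches the resolution of the first (largest) one.
def _example_interpolate_scales():
    pyramid = [torch.rand(1, 1, 96, 160),
               torch.rand(1, 1, 48, 80),
               torch.rand(1, 1, 24, 40)]
    resized = interpolate_scales(pyramid)
    assert all(t.shape == (1, 1, 96, 160) for t in resized)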

def match_scales(image, targets, num_scales,
                 mode='bilinear', align_corners=True):
    """
    Interpolate one image to produce a list of images with the same shape as targets

    Parameters
    ----------
    image : torch.Tensor [B,?,h,w]
        Input image
    targets : list of torch.Tensor [B,?,?,?]
        Tensors with the target resolutions
    num_scales : int
        Number of considered scales
    mode : str
        Interpolation mode
    align_corners : bool
        True if corners will be aligned after interpolation

    Returns
    -------
    images : list of torch.Tensor [B,?,?,?]
        List of images with the same resolutions as targets
    """
    # For all scales
    images = []
    image_shape = image.shape[-2:]
    for i in range(num_scales):
        target_shape = targets[i].shape
        # If image shape is equal to target shape (compare spatial dimensions only)
        if same_shape(image_shape, target_shape[-2:]):
            images.append(image)
        else:
            # Otherwise, interpolate
            images.append(interpolate_image(
                image, target_shape, mode=mode, align_corners=align_corners))
    # Return scaled images
    return images
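
# Illustrative sketch (not part of the original module): a full-resolution
# image is matched to each level of a two-scale prediction pyramid.
def _example_match_scales():
    image = torch.rand(1, 3, 96, 160)
    targets = [torch.rand(1, 1, 96, 160), torch.rand(1, 1, 48, 80)]
    matched = match_scales(image, targets, num_scales=2)
    assert matched[0].shape == (1, 3, 96, 160)
    assert matched[1].shape == (1, 3, 48, 80)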
########################################################################################################################
@lru_cache(maxsize=None)
def meshgrid(B, H, W, dtype, device, normalized=False):
    """
    Create meshgrid with a specific resolution

    Parameters
    ----------
    B : int
        Batch size
    H : int
        Height size
    W : int
        Width size
    dtype : torch.dtype
        Meshgrid type
    device : torch.device
        Meshgrid device
    normalized : bool
        True if grid is normalized between -1 and 1

    Returns
    -------
    xs : torch.Tensor [B,H,W]
        Meshgrid in dimension x
    ys : torch.Tensor [B,H,W]
        Meshgrid in dimension y
    """
    if normalized:
        xs = torch.linspace(-1, 1, W, device=device, dtype=dtype)
        ys = torch.linspace(-1, 1, H, device=device, dtype=dtype)
    else:
        xs = torch.linspace(0, W-1, W, device=device, dtype=dtype)
        ys = torch.linspace(0, H-1, H, device=device, dtype=dtype)
    ys, xs = torch.meshgrid([ys, xs])
    return xs.repeat([B, 1, 1]), ys.repeat([B, 1, 1])
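
# Illustrative sketch (not part of the original module): pixel coordinates for
# a small unnormalized grid.
def _example_meshgrid():
    xs, ys = meshgrid(1, 2, 3, torch.float32, torch.device('cpu'))
    assert xs.shape == ys.shape == (1, 2, 3)
    assert xs[0, 0].tolist() == [0.0, 1.0, 2.0]  # x varies along the width
    assert ys[0, :, 0].tolist() == [0.0, 1.0]    # y varies along the height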

@lru_cache(maxsize=None)
def image_grid(B, H, W, dtype, device, normalized=False):
    """
    Create an image grid with a specific resolution

    Parameters
    ----------
    B : int
        Batch size
    H : int
        Height size
    W : int
        Width size
    dtype : torch.dtype
        Meshgrid type
    device : torch.device
        Meshgrid device
    normalized : bool
        True if grid is normalized between -1 and 1

    Returns
    -------
    grid : torch.Tensor [B,3,H,W]
        Image grid containing a meshgrid in x, y and 1
    """
    xs, ys = meshgrid(B, H, W, dtype, device, normalized=normalized)
    ones = torch.ones_like(xs)
    grid = torch.stack([xs, ys, ones], dim=1)
    return grid
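
# Illustrative sketch (not part of the original module): the image grid stacks
# x, y and a plane of ones into homogeneous pixel coordinates.
def _example_image_grid():
    grid = image_grid(1, 2, 3, torch.float32, torch.device('cpu'))
    assert grid.shape == (1, 3, 2, 3)  # [B,3,H,W]
    assert torch.all(grid[:, 2] == 1)  # last channel is all ones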
########################################################################################################################