Source code for packnet_sfm.losses.multiview_photometric_loss

# Copyright 2020 Toyota Research Institute.  All rights reserved.

import torch
import torch.nn as nn

from packnet_sfm.utils.image import match_scales
from packnet_sfm.geometry.camera import Camera
from packnet_sfm.geometry.camera_utils import view_synthesis
from packnet_sfm.utils.depth import calc_smoothness, inv2depth
from packnet_sfm.losses.loss_base import LossBase, ProgressiveScaling

########################################################################################################################

def SSIM(x, y, C1=1e-4, C2=9e-4, kernel_size=3, stride=1):
    """
    Structural SIMilarity (SSIM) between two images.

    Parameters
    ----------
    x,y : torch.Tensor [B,3,H,W]
        Input images
    C1,C2 : float
        SSIM parameters
    kernel_size,stride : int
        Convolutional parameters

    Returns
    -------
    ssim : torch.Tensor [B,3,H,W]
        Per-pixel SSIM similarity map
    """
    pool2d = nn.AvgPool2d(kernel_size, stride=stride)
    refl = nn.ReflectionPad2d(1)

    x, y = refl(x), refl(y)
    mu_x = pool2d(x)
    mu_y = pool2d(y)

    mu_x_mu_y = mu_x * mu_y
    mu_x_sq = mu_x.pow(2)
    mu_y_sq = mu_y.pow(2)

    sigma_x = pool2d(x.pow(2)) - mu_x_sq
    sigma_y = pool2d(y.pow(2)) - mu_y_sq
    sigma_xy = pool2d(x * y) - mu_x_mu_y

    v1 = 2 * sigma_xy + C2
    v2 = sigma_x + sigma_y + C2

    ssim_n = (2 * mu_x_mu_y + C1) * v1
    ssim_d = (mu_x_sq + mu_y_sq + C1) * v2
    ssim = ssim_n / ssim_d

    return ssim
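
########################################################################################################################

# Illustrative usage sketch, not part of the original module: SSIM returns a
# per-pixel similarity map rather than a scalar, and identical images score
# exactly 1 everywhere. The dummy tensors below are made up for the example.
def _ssim_usage_sketch():
    x = torch.rand(2, 3, 64, 64)
    assert SSIM(x, x.clone()).mean().item() > 0.999  # identical images -> 1.0
    y = torch.rand(2, 3, 64, 64)
    assert SSIM(x, y).shape == (2, 3, 64, 64)        # per-pixel map, reduced later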
########################################################################################################################
class MultiViewPhotometricLoss(LossBase):
    """
    Self-supervised multiview photometric loss.
    It takes two images, a depth map and a pose transformation to produce a
    reconstruction of one image from the perspective of the other, and
    calculates the difference between them.

    Parameters
    ----------
    num_scales : int
        Number of inverse depth map scales to consider
    ssim_loss_weight : float
        Weight for the SSIM loss
    occ_reg_weight : float
        Weight for the occlusion regularization loss
    smooth_loss_weight : float
        Weight for the smoothness loss
    C1,C2 : float
        SSIM parameters
    photometric_reduce_op : str
        Method to reduce the photometric loss ('mean' or 'min')
    disp_norm : bool
        True if inverse depth maps are normalized
    clip_loss : float
        Threshold for photometric loss clipping
    progressive_scaling : float
        Training percentage for progressive scaling (0.0 to disable)
    padding_mode : str
        Padding mode for view synthesis
    automask_loss : bool
        True if automasking is enabled for the photometric loss
    kwargs : dict
        Extra parameters
    """
    def __init__(self, num_scales=4, ssim_loss_weight=0.85, occ_reg_weight=0.1,
                 smooth_loss_weight=0.1, C1=1e-4, C2=9e-4,
                 photometric_reduce_op='mean', disp_norm=True, clip_loss=0.5,
                 progressive_scaling=0.0, padding_mode='zeros',
                 automask_loss=False, **kwargs):
        super().__init__()
        self.n = num_scales
        self.ssim_loss_weight = ssim_loss_weight
        self.occ_reg_weight = occ_reg_weight
        self.smooth_loss_weight = smooth_loss_weight
        self.C1 = C1
        self.C2 = C2
        self.photometric_reduce_op = photometric_reduce_op
        self.disp_norm = disp_norm
        self.clip_loss = clip_loss
        self.padding_mode = padding_mode
        self.automask_loss = automask_loss
        self.progressive_scaling = ProgressiveScaling(
            progressive_scaling, self.n)
        # Asserts
        if self.automask_loss:
            assert self.photometric_reduce_op == 'min', \
                'For automasking only the min photometric_reduce_op is supported.'

########################################################################################################################

    @property
    def logs(self):
        """Returns class logs."""
        return {
            'num_scales': self.n,
        }

########################################################################################################################
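
    # Illustrative note, not from the original module: the constructor assert
    # means automasking is only valid with the 'min' reduction, e.g.
    #
    #   >>> MultiViewPhotometricLoss(automask_loss=True, photometric_reduce_op='mean')
    #   AssertionError: For automasking only the min photometric_reduce_op is supported.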
    def warp_ref_image(self, inv_depths, ref_image, K, ref_K, pose):
        """
        Warps a reference image to produce a reconstruction of the original one.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps of the original image, for all scales
        ref_image : torch.Tensor [B,3,H,W]
            Reference RGB image
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,3,3]
            Reference camera intrinsics
        pose : Pose
            Original -> Reference camera transformation

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference images (reconstructing the original one), for all scales
        """
        B, _, H, W = ref_image.shape
        device = ref_image.get_device()
        # Generate cameras for all scales
        cams, ref_cams = [], []
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            scale_factor = DW / float(W)
            cams.append(Camera(K=K.float()).scaled(scale_factor).to(device))
            ref_cams.append(Camera(K=ref_K.float(), Tcw=pose).scaled(scale_factor).to(device))
        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
        ref_images = match_scales(ref_image, inv_depths, self.n)
        ref_warped = [view_synthesis(
            ref_images[i], depths[i], ref_cams[i], cams[i],
            padding_mode=self.padding_mode) for i in range(self.n)]
        # Return warped reference images
        return ref_warped
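
    # Illustrative sketch, not from the original module: Camera.scaled is
    # assumed to rescale the pinhole intrinsics so projection stays consistent
    # at each downsampled depth resolution (any half-pixel offset handling is
    # left to packnet_sfm.geometry). Roughly, for a scale factor s = DW / W:
    #
    #   >>> K = torch.tensor([[[100., 0., 320.],
    #   ...                    [0., 100., 240.],
    #   ...                    [0., 0., 1.]]])  # dummy intrinsics
    #   >>> K_scaled = K.clone()
    #   >>> K_scaled[:, :2] *= 0.5              # fx, fy, cx, cy halved at DW = W/2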
########################################################################################################################
    def SSIM(self, x, y, kernel_size=3):
        """
        Calculates the SSIM (Structural SIMilarity) loss

        Parameters
        ----------
        x,y : torch.Tensor [B,3,H,W]
            Input images
        kernel_size : int
            Convolutional parameter

        Returns
        -------
        ssim : torch.Tensor [B,3,H,W]
            Per-pixel SSIM distance, clamped to [0,1]
        """
        ssim_value = SSIM(x, y, C1=self.C1, C2=self.C2, kernel_size=kernel_size)
        return torch.clamp((1. - ssim_value) / 2., 0., 1.)
    def calc_photometric_loss(self, t_est, images):
        """
        Calculates the photometric loss (L1 + SSIM)

        Parameters
        ----------
        t_est : list of torch.Tensor [B,3,H,W]
            List of warped reference images in multiple scales
        images : list of torch.Tensor [B,3,H,W]
            List of original images in multiple scales

        Returns
        -------
        photometric_loss : list of torch.Tensor [B,1,H,W]
            Per-pixel photometric loss for each scale
            ([B,3,H,W] if only the L1 term is used)
        """
        # L1 loss
        l1_loss = [torch.abs(t_est[i] - images[i])
                   for i in range(self.n)]
        # SSIM loss
        if self.ssim_loss_weight > 0.0:
            ssim_loss = [self.SSIM(t_est[i], images[i], kernel_size=3)
                         for i in range(self.n)]
            # Weighted Sum: alpha * ssim + (1 - alpha) * l1
            photometric_loss = [self.ssim_loss_weight * ssim_loss[i].mean(1, True) +
                                (1 - self.ssim_loss_weight) * l1_loss[i].mean(1, True)
                                for i in range(self.n)]
        else:
            photometric_loss = l1_loss
        # Clip loss
        if self.clip_loss > 0.0:
            for i in range(self.n):
                mean, std = photometric_loss[i].mean(), photometric_loss[i].std()
                photometric_loss[i] = torch.clamp(
                    photometric_loss[i], max=float(mean + self.clip_loss * std))
        # Return per-scale photometric losses
        return photometric_loss
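
    # Illustrative sketch, not from the original module: with the default
    # ssim_loss_weight of 0.85 each scale yields the Monodepth-style blend
    # 0.85 * (1 - SSIM) / 2 + 0.15 * |I - I_est|, channel-averaged to [B,1,H,W]:
    #
    #   >>> t_est = [torch.rand(2, 3, 32, 32)]   # dummy warped image, one scale
    #   >>> images = [torch.rand(2, 3, 32, 32)]  # dummy original image
    #   >>> loss = MultiViewPhotometricLoss(num_scales=1)
    #   >>> loss.calc_photometric_loss(t_est, images)[0].shape
    #   torch.Size([2, 1, 32, 32])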
    def reduce_photometric_loss(self, photometric_losses):
        """
        Combine the photometric loss from all context images

        Parameters
        ----------
        photometric_losses : list of list of torch.Tensor [B,1,H,W]
            Pixel-wise photometric losses from the entire context, for all scales

        Returns
        -------
        photometric_loss : torch.Tensor [1]
            Reduced photometric loss
        """
        # Reduce function
        def reduce_function(losses):
            if self.photometric_reduce_op == 'mean':
                return sum([l.mean() for l in losses]) / len(losses)
            elif self.photometric_reduce_op == 'min':
                return torch.cat(losses, 1).min(1, True)[0].mean()
            else:
                raise NotImplementedError(
                    'Unknown photometric_reduce_op: {}'.format(self.photometric_reduce_op))
        # Reduce photometric loss
        photometric_loss = sum([reduce_function(photometric_losses[i])
                                for i in range(self.n)]) / self.n
        # Store and return reduced photometric loss
        self.add_metric('photometric_loss', photometric_loss)
        return photometric_loss
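
    # Illustrative sketch, not from the original module: the 'min' reduction
    # takes the per-pixel minimum over all context losses (Monodepth2-style).
    # With automasking, the unwarped-image losses appended in forward() let
    # static pixels pick their unwarped loss and thus stop contributing:
    #
    #   >>> losses = [torch.rand(2, 1, 8, 8) for _ in range(4)]  # e.g. 2 warped + 2 unwarped
    #   >>> torch.cat(losses, 1).min(1, True)[0].mean()          # scalar, as in 'min' mode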
########################################################################################################################
    def calc_smoothness_loss(self, inv_depths, images):
        """
        Calculates the smoothness loss for inverse depth maps.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for all scales
        images : list of torch.Tensor [B,3,H,W]
            Original images for all scales

        Returns
        -------
        smoothness_loss : torch.Tensor [1]
            Smoothness loss
        """
        # Calculate smoothness gradients
        smoothness_x, smoothness_y = calc_smoothness(inv_depths, images, self.n)
        # Calculate smoothness loss
        smoothness_loss = sum([(smoothness_x[i].abs().mean() +
                                smoothness_y[i].abs().mean()) / 2 ** i
                               for i in range(self.n)]) / self.n
        # Apply smoothness loss weight
        smoothness_loss = self.smooth_loss_weight * smoothness_loss
        # Store and return smoothness loss
        self.add_metric('smoothness_loss', smoothness_loss)
        return smoothness_loss
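
    # Illustrative sketch, not from the original module: calc_smoothness is
    # assumed to implement the standard edge-aware smoothness term, weighting
    # inverse-depth gradients by image gradients so depth is only encouraged
    # to be smooth away from image edges, roughly:
    #
    #   >>> inv_depth, image = torch.rand(2, 1, 8, 8), torch.rand(2, 3, 8, 8)
    #   >>> d_dx = inv_depth[:, :, :, :-1] - inv_depth[:, :, :, 1:]
    #   >>> i_dx = (image[:, :, :, :-1] - image[:, :, :, 1:]).abs().mean(1, True)
    #   >>> smooth_x = d_dx * torch.exp(-i_dx)  # and likewise along y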
########################################################################################################################
    def forward(self, image, context, inv_depths,
                K, ref_K, poses, return_logs=False, progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for the original image, in all scales
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,3,3]
            Reference camera intrinsics
        poses : list of Pose
            Camera transformations between original and context images
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)
        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            # Calculate warped images
            ref_warped = self.warp_ref_image(inv_depths, ref_image, K, ref_K, pose)
            # Calculate and store image loss
            photometric_loss = self.calc_photometric_loss(ref_warped, images)
            for i in range(self.n):
                photometric_losses[i].append(photometric_loss[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i])
        # Calculate reduced photometric loss
        loss = self.reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(inv_depths, images)
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
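
########################################################################################################################

# Illustrative end-to-end sketch, not part of the original module. It assumes
# Pose (from packnet_sfm.geometry.pose) wraps a [B,4,4] transformation matrix,
# and substitutes random CUDA tensors for real network outputs (CUDA because
# warp_ref_image relies on tensor.get_device()).
def _multiview_loss_usage_sketch():
    from packnet_sfm.geometry.pose import Pose
    B, H, W = 2, 64, 64
    image = torch.rand(B, 3, H, W).cuda()                  # target frame
    context = [torch.rand(B, 3, H, W).cuda()]              # one reference frame
    inv_depths = [torch.rand(B, 1, H // 2 ** i, W // 2 ** i).cuda()
                  for i in range(4)]                       # four predicted scales
    K = torch.eye(3).unsqueeze(0).repeat(B, 1, 1).cuda()   # dummy intrinsics
    poses = [Pose(torch.eye(4).unsqueeze(0).repeat(B, 1, 1).cuda())]
    criterion = MultiViewPhotometricLoss(num_scales=4)
    output = criterion(image, context, inv_depths, K, K.clone(), poses)
    print(output['loss'].shape)                            # torch.Size([1])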
########################################################################################################################