Source code for packnet_sfm.models.SelfSupModel

# Copyright 2020 Toyota Research Institute.  All rights reserved.

from packnet_sfm.models.SfmModel import SfmModel
from packnet_sfm.losses.multiview_photometric_loss import MultiViewPhotometricLoss
from packnet_sfm.models.model_utils import merge_outputs


class SelfSupModel(SfmModel):
    """
    Model that inherits a depth and pose network from SfmModel and
    includes the photometric loss for self-supervised training.

    Parameters
    ----------
    kwargs : dict
        Extra parameters
    """
    def __init__(self, **kwargs):
        # Initializes SfmModel
        super().__init__(**kwargs)
        # Initializes the photometric loss
        self._photometric_loss = MultiViewPhotometricLoss(**kwargs)

    @property
    def logs(self):
        """Return logs."""
        return {
            **super().logs,
            **self._photometric_loss.logs
        }
    def self_supervised_loss(self, image, ref_images, inv_depths, poses,
                             intrinsics, return_logs=False, progress=0.0):
        """
        Calculates the self-supervised photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        ref_images : list of torch.Tensor [B,3,H,W]
            Reference images from context
        inv_depths : torch.Tensor [B,1,H,W]
            Predicted inverse depth maps from the original image
        poses : list of Pose
            List containing predicted poses between original and context images
        intrinsics : torch.Tensor [B,3,3]
            Camera intrinsics
        return_logs : bool
            True if logs are stored
        progress : float
            Training progress percentage

        Returns
        -------
        output : dict
            Dictionary containing a "loss" scalar and a "metrics" dictionary
        """
        # The same intrinsics are passed twice: once for the target camera and
        # once for the reference cameras (a single camera is assumed here).
        return self._photometric_loss(
            image, ref_images, inv_depths, intrinsics, intrinsics, poses,
            return_logs=return_logs, progress=progress)
    def forward(self, batch, return_logs=False, progress=0.0):
        """
        Processes a batch.

        Parameters
        ----------
        batch : dict
            Input batch
        return_logs : bool
            True if logs are stored
        progress : float
            Training progress percentage

        Returns
        -------
        output : dict
            Dictionary containing a "loss" scalar and different metrics and
            predictions for logging and downstream usage.
        """
        # Calculate predicted depth and pose output
        output = super().forward(batch, return_logs=return_logs)
        if not self.training:
            # If not training, no need for self-supervised loss
            return output
        else:
            # Otherwise, calculate self-supervised loss
            self_sup_output = self.self_supervised_loss(
                batch['rgb_original'], batch['rgb_context_original'],
                output['inv_depths'], output['poses'], batch['intrinsics'],
                return_logs=return_logs, progress=progress)
            # Return loss and metrics
            return {
                'loss': self_sup_output['loss'],
                **merge_outputs(output, self_sup_output),
            }
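
As a usage sketch, the forward pass above can be exercised with a plain batch dictionary. The snippet below is illustrative only: the 'rgb' and 'rgb_context' keys are assumed to be what the parent SfmModel consumes for its depth and pose predictions, the tensors are random placeholders, and the model construction (which depends on the packnet_sfm configuration and networks) is omitted.

import torch

# Hypothetical training step; in practice the batch comes from the packnet_sfm
# datasets/config system and `model` is built from a configuration.
model = ...  # an instantiated SelfSupModel (construction omitted)
model.train()

B, H, W = 4, 192, 640
batch = {
    'rgb': torch.rand(B, 3, H, W),                # assumed key used by SfmModel
    'rgb_original': torch.rand(B, 3, H, W),       # un-augmented target image
    'rgb_context': [torch.rand(B, 3, H, W)] * 2,  # assumed key used by SfmModel
    'rgb_context_original': [torch.rand(B, 3, H, W)] * 2,
    'intrinsics': torch.eye(3).unsqueeze(0).repeat(B, 1, 1),
}

output = model(batch, progress=0.0)  # adds the self-supervised loss in training mode
output['loss'].backward()            # 'loss' is only present when model.training is True

The self-supervised signal itself comes from MultiViewPhotometricLoss, which synthesizes each reference view in the target frame from the predicted inverse depths, poses and intrinsics, and penalizes appearance differences. As a rough, hedged illustration of the kind of per-pixel photometric term such a loss is built on (this is not the packnet_sfm implementation, which is multi-scale and also handles the view synthesis, masking and reduction), an SSIM + L1 error between a target image and a reference image already warped into the target view could look like this:

import torch
import torch.nn.functional as F

def photometric_error(target, warped, alpha=0.85):
    """Per-pixel SSIM + L1 error between a target image and a reference image
    already warped into the target view (illustrative sketch only)."""
    l1 = torch.abs(target - warped).mean(1, keepdim=True)
    # Simplified SSIM computed over 3x3 neighborhoods
    mu_x = F.avg_pool2d(target, 3, 1, 1)
    mu_y = F.avg_pool2d(warped, 3, 1, 1)
    sigma_x = F.avg_pool2d(target ** 2, 3, 1, 1) - mu_x ** 2
    sigma_y = F.avg_pool2d(warped ** 2, 3, 1, 1) - mu_y ** 2
    sigma_xy = F.avg_pool2d(target * warped, 3, 1, 1) - mu_x * mu_y
    C1, C2 = 0.01 ** 2, 0.03 ** 2
    ssim = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / \
           ((mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2))
    ssim_err = torch.clamp((1 - ssim) / 2, 0, 1).mean(1, keepdim=True)
    # Blend the two terms; alpha weights the structural (SSIM) component
    return alpha * ssim_err + (1 - alpha) * l1  # [B,1,H,W] error map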