Source code for packnet_sfm.models.SfmModel

# Copyright 2020 Toyota Research Institute.  All rights reserved.

import random
import torch.nn as nn
from packnet_sfm.utils.image import flip_model, interpolate_scales
from packnet_sfm.geometry.pose import Pose
from packnet_sfm.utils.misc import make_list


class SfmModel(nn.Module):
    """
    Model class encapsulating a depth and a pose network.

    Parameters
    ----------
    depth_net : nn.Module
        Depth network to be used
    pose_net : nn.Module
        Pose network to be used
    rotation_mode : str
        Rotation mode for the pose network
    flip_lr_prob : float
        Probability of flipping when using the depth network
    upsample_depth_maps : bool
        True if depth map scales are upsampled to highest resolution
    kwargs : dict
        Extra parameters
    """
    def __init__(self, depth_net=None, pose_net=None,
                 rotation_mode='euler', flip_lr_prob=0.0,
                 upsample_depth_maps=False, **kwargs):
        super().__init__()
        self.depth_net = depth_net
        self.pose_net = pose_net
        self.rotation_mode = rotation_mode
        self.flip_lr_prob = flip_lr_prob
        self.upsample_depth_maps = upsample_depth_maps
        self._logs = {}
        self._losses = {}

        self._network_requirements = {
            'depth_net': True,   # Depth network required
            'pose_net': True,    # Pose network required
        }
        self._train_requirements = {
            'gt_depth': False,   # No ground-truth depth required
            'gt_pose': False,    # No ground-truth pose required
        }

    @property
    def logs(self):
        """Return logs."""
        return self._logs

    @property
    def losses(self):
        """Return losses."""
        return self._losses

    def add_loss(self, key, val):
        """Add a new loss to the dictionary, detaching it first."""
        self._losses[key] = val.detach()

    @property
    def network_requirements(self):
        """
        Networks required to run the model

        Returns
        -------
        requirements : dict
            depth_net : bool
                Whether a depth network is required by the model
            pose_net : bool
                Whether a pose network is required by the model
        """
        return self._network_requirements

    @property
    def train_requirements(self):
        """
        Information required by the model at training stage

        Returns
        -------
        requirements : dict
            gt_depth : bool
                Whether ground-truth depth is required by the model at training time
            gt_pose : bool
                Whether ground-truth pose is required by the model at training time
        """
        return self._train_requirements

    def add_depth_net(self, depth_net):
        """Add a depth network to the model"""
        self.depth_net = depth_net

    def add_pose_net(self, pose_net):
        """Add a pose network to the model"""
        self.pose_net = pose_net

    def compute_inv_depths(self, image):
        """Compute inverse depth maps from single images"""
        # Randomly flip the input (training only) and estimate inverse depth maps
        flip_lr = random.random() < self.flip_lr_prob if self.training else False
        inv_depths = make_list(flip_model(self.depth_net, image, flip_lr))
        # Optionally upsample all depth map scales to the highest resolution
        if self.upsample_depth_maps:
            inv_depths = interpolate_scales(
                inv_depths, mode='nearest', align_corners=None)
        # Return inverse depth maps
        return inv_depths

    def compute_poses(self, image, contexts):
        """Compute poses from an image and a sequence of context images"""
        pose_vec = self.pose_net(image, contexts)
        return [Pose.from_vec(pose_vec[:, i], self.rotation_mode)
                for i in range(pose_vec.shape[1])]

    def forward(self, batch, return_logs=False):
        """
        Processes a batch.

        Parameters
        ----------
        batch : dict
            Input batch
        return_logs : bool
            True if logs are stored

        Returns
        -------
        output : dict
            Dictionary containing predicted inverse depth maps and poses
        """
        # Generate inverse depth predictions
        inv_depths = self.compute_inv_depths(batch['rgb'])
        # Generate pose predictions if available
        pose = None
        if 'rgb_context' in batch and self.pose_net is not None:
            pose = self.compute_poses(batch['rgb'], batch['rgb_context'])
        # Return output dictionary
        return {
            'inv_depths': inv_depths,
            'poses': pose,
        }

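# A minimal usage sketch under stated assumptions: the stand-in networks below
# are hypothetical and only mimic the interfaces SfmModel expects -- depth_net(image)
# returning an inverse depth map, and pose_net(image, contexts) returning a
# [B, num_context, 6] translation/rotation vector -- so a forward pass can be
# exercised without the real depth and pose networks. Batch shapes are illustrative.
if __name__ == '__main__':
    import torch

    class _DummyDepthNet(nn.Module):
        def forward(self, image):
            # Single inverse depth map at input resolution
            return torch.ones(image.shape[0], 1, *image.shape[-2:])

    class _DummyPoseNet(nn.Module):
        def forward(self, image, contexts):
            # One 6-DoF vector (translation + rotation) per context image
            return torch.zeros(image.shape[0], len(contexts), 6)

    model = SfmModel(depth_net=_DummyDepthNet(), pose_net=_DummyPoseNet(),
                     rotation_mode='euler', flip_lr_prob=0.0,
                     upsample_depth_maps=False)
    model.eval()

    batch = {
        'rgb': torch.rand(2, 3, 192, 640),                # target images
        'rgb_context': [torch.rand(2, 3, 192, 640)] * 2,  # context images
    }
    output = model(batch)
    print(len(output['inv_depths']))  # 1 scale from the dummy depth network
    print(len(output['poses']))       # one Pose per context image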