# Source code for packnet_sfm.networks.depth.DepthResNet

# Copyright 2020 Toyota Research Institute.  All rights reserved.

import torch.nn as nn
from functools import partial

from packnet_sfm.networks.layers.resnet.resnet_encoder import ResnetEncoder
from packnet_sfm.networks.layers.resnet.depth_decoder import DepthDecoder
from packnet_sfm.networks.layers.resnet.layers import disp_to_depth

########################################################################################################################

class DepthResNet(nn.Module):
    """
    Inverse depth network based on the ResNet architecture.

    Parameters
    ----------
    version : str
        Has a XY format, where:
        X is the number of residual layers [18, 34, 50] and
        Y is an optional ImageNet pretrained flag added by the "pt" suffix
        Example: "18pt" initializes a pretrained ResNet18, and "34" initializes a ResNet34 from scratch
    kwargs : dict
        Extra parameters (accepted for config compatibility; not used here)

    Raises
    ------
    AssertionError
        If ``version`` is None or names an unsupported number of layers.
    """
    def __init__(self, version=None, **kwargs):
        super().__init__()
        # Explicit raise instead of ``assert`` so validation survives ``python -O``;
        # AssertionError is kept so existing callers see the same exception type.
        if version is None:
            raise AssertionError("DepthResNet needs a version")

        num_layers = int(version[:2])     # First two characters are the number of layers
        pretrained = version[2:] == 'pt'  # If the last characters are "pt", use ImageNet pretraining
        if num_layers not in (18, 34, 50):
            raise AssertionError('ResNet version {} not available'.format(num_layers))

        self.encoder = ResnetEncoder(num_layers=num_layers, pretrained=pretrained)
        self.decoder = DepthDecoder(num_ch_enc=self.encoder.num_ch_enc)
        # disp_to_depth with the depth range fixed to [0.1, 100.0].
        self.scale_inv_depth = partial(disp_to_depth, min_depth=0.1, max_depth=100.0)

    def forward(self, x):
        """
        Runs the network and returns inverse depth maps
        (4 scales if training and 1 if not).

        Parameters
        ----------
        x : input batch passed straight to the encoder
            (presumably a (B, 3, H, W) image tensor -- TODO confirm against ResnetEncoder)

        Returns
        -------
        list or single map
            In training mode, a list with one entry per decoder scale 0..3.
            In eval mode, only the scale-0 entry.
        """
        outputs = self.decoder(self.encoder(x))
        # NOTE(review): [0] takes the first element of disp_to_depth's result,
        # presumably the rescaled inverse depth -- confirm in layers.disp_to_depth.
        if self.training:
            return [self.scale_inv_depth(outputs[('disp', i)])[0] for i in range(4)]
        # Eval mode returns only scale 0, so only that entry is read.
        return self.scale_inv_depth(outputs[('disp', 0)])[0]
########################################################################################################################