Source code for packnet_sfm.networks.layers.resnet.depth_decoder

# Copyright 2020 Toyota Research Institute.  All rights reserved.

# Adapted from monodepth2
# https://github.com/nianticlabs/monodepth2/blob/master/networks/depth_decoder.py

from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict
from .layers import ConvBlock, Conv3x3, upsample


class DepthDecoder(nn.Module):
    """Multi-scale decoder producing sigmoid-activated maps from encoder features.

    The decoder has five levels, processed from level 4 down to level 0.
    Each level applies a first ``ConvBlock``, upsamples the result with the
    ``upsample`` helper, optionally concatenates a skip feature from the
    encoder, and applies a second ``ConvBlock``. Levels listed in ``scales``
    additionally get a ``Conv3x3`` head followed by a sigmoid.

    Parameters
    ----------
    num_ch_enc : sequence of int
        Channel counts of the encoder features; the last entry feeds level 4
        and entry ``i - 1`` is the skip connection for level ``i``.
        (Presumably ordered shallow-to-deep -- confirm against the encoder.)
    scales : iterable of int
        Decoder levels for which an output head is built and evaluated.
    num_output_channels : int
        Number of channels produced by each output head.
    use_skips : bool
        Whether encoder features are concatenated in at levels 1 to 4.
    """

    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
        super(DepthDecoder, self).__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        # Stored for reference only; nothing in this class reads it.
        self.upsample_mode = 'nearest'
        self.scales = scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # Layers are kept in an insertion-ordered dict keyed by tuples.
        # The insertion order matters: it becomes the index of each layer
        # in the ModuleList built below.
        self.convs = OrderedDict()
        for level in reversed(range(5)):
            out_ch = self.num_ch_dec[level]

            # First block: takes the deepest encoder feature at level 4,
            # otherwise the output of the decoder level above.
            if level == 4:
                in_ch = self.num_ch_enc[-1]
            else:
                in_ch = self.num_ch_dec[level + 1]
            self.convs[("upconv", level, 0)] = ConvBlock(in_ch, out_ch)

            # Second block: its input is widened by the skip feature's
            # channels when skips are enabled (there is no skip at level 0).
            in_ch = out_ch
            if self.use_skips and level > 0:
                in_ch = in_ch + self.num_ch_enc[level - 1]
            self.convs[("upconv", level, 1)] = ConvBlock(in_ch, out_ch)

        # One output head per requested scale, added after all upconvs.
        for scale in self.scales:
            head = Conv3x3(self.num_ch_dec[scale], self.num_output_channels)
            self.convs[("dispconv", scale)] = head

        # A plain dict does not register sub-modules; the ModuleList does.
        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        """Run the decoder and return the per-scale outputs.

        Parameters
        ----------
        input_features : sequence of torch.Tensor
            Encoder features; the last one seeds the decoder and entry
            ``i - 1`` is used as the skip connection at level ``i``.

        Returns
        -------
        dict
            Maps ``("disp", i)`` to the sigmoid of the level-``i`` output head,
            for every level ``i`` contained in ``scales``. The same dict is
            also stored on the instance as ``self.outputs``.
        """
        self.outputs = {}

        x = input_features[-1]
        for level in reversed(range(5)):
            x = self.convs[("upconv", level, 0)](x)

            # Gather the upsampled tensor and, if enabled, the skip feature,
            # then join them along dimension 1.
            pieces = [upsample(x)]
            if self.use_skips and level > 0:
                pieces.append(input_features[level - 1])
            x = self.convs[("upconv", level, 1)](torch.cat(pieces, 1))

            if level in self.scales:
                raw = self.convs[("dispconv", level)](x)
                self.outputs[("disp", level)] = self.sigmoid(raw)

        return self.outputs