# Source code for packnet_sfm.networks.pose.PoseNet

# Copyright 2020 Toyota Research Institute.  All rights reserved.

# Adapted from SfmLearner
# https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/models/PoseExpNet.py

import torch
import torch.nn as nn

########################################################################################################################

[docs]def conv_gn(in_planes, out_planes, kernel_size=3): """ Convolutional block with GroupNorm Parameters ---------- in_planes : int Number of input channels out_planes : int Number of output channels kernel_size : int Convolutional kernel size Returns ------- layers : nn.Sequential Sequence of Conv2D + GroupNorm + ReLU """ return nn.Sequential( nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, stride=2), nn.GroupNorm(16, out_planes), nn.ReLU(inplace=True) )
########################################################################################################################
[docs]class PoseNet(nn.Module): """Pose network """ def __init__(self, nb_ref_imgs=2, rotation_mode='euler', **kwargs): super().__init__() self.nb_ref_imgs = nb_ref_imgs self.rotation_mode = rotation_mode conv_channels = [16, 32, 64, 128, 256, 256, 256] self.conv1 = conv_gn(3 * (1 + self.nb_ref_imgs), conv_channels[0], kernel_size=7) self.conv2 = conv_gn(conv_channels[0], conv_channels[1], kernel_size=5) self.conv3 = conv_gn(conv_channels[1], conv_channels[2]) self.conv4 = conv_gn(conv_channels[2], conv_channels[3]) self.conv5 = conv_gn(conv_channels[3], conv_channels[4]) self.conv6 = conv_gn(conv_channels[4], conv_channels[5]) self.conv7 = conv_gn(conv_channels[5], conv_channels[6]) self.pose_pred = nn.Conv2d(conv_channels[6], 6 * self.nb_ref_imgs, kernel_size=1, padding=0) self.init_weights()
[docs] def init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): nn.init.xavier_uniform_(m.weight.data) if m.bias is not None: m.bias.data.zero_()
[docs] def forward(self, image, context): assert (len(context) == self.nb_ref_imgs) input = [image] input.extend(context) input = torch.cat(input, 1) out_conv1 = self.conv1(input) out_conv2 = self.conv2(out_conv1) out_conv3 = self.conv3(out_conv2) out_conv4 = self.conv4(out_conv3) out_conv5 = self.conv5(out_conv4) out_conv6 = self.conv6(out_conv5) out_conv7 = self.conv7(out_conv6) pose = self.pose_pred(out_conv7) pose = pose.mean(3).mean(2) pose = 0.01 * pose.view(pose.size(0), self.nb_ref_imgs, 6) return pose
########################################################################################################################