# Source code for packnet_sfm.geometry.camera

# Copyright 2020 Toyota Research Institute.  All rights reserved.

from functools import lru_cache
import torch
import torch.nn as nn

from packnet_sfm.geometry.pose import Pose
from packnet_sfm.geometry.camera_utils import scale_intrinsics
from packnet_sfm.utils.image import image_grid

########################################################################################################################

class Camera(nn.Module):
    """
    Differentiable camera class implementing reconstruction and projection
    functions for a pinhole model.

    The derived quantities ``Kinv`` and ``Twc`` are memoized on the instance.
    The memo is keyed on the identity of ``K`` / ``Tcw``, so it is refreshed
    whenever those attributes are rebound (e.g. by :meth:`to`). In-place
    modification of ``K`` / ``Tcw`` is NOT detected.
    """
    def __init__(self, K, Tcw=None):
        """
        Initializes the Camera class

        Parameters
        ----------
        K : torch.Tensor [B,3,3]
            Batched camera intrinsics
        Tcw : Pose
            Pose applied to world-frame points to express them in the camera
            frame (see ``project``). If None, ``Pose.identity(len(K))`` is used.
        """
        super().__init__()
        self.K = K
        # NOTE(review): Pose.identity only receives the batch size here; confirm
        # whether it matches K's device/dtype, or call .to() on the camera.
        self.Tcw = Pose.identity(len(K)) if Tcw is None else Tcw
        # Memo slots holding (source object, derived value), or None.
        self._Kinv_cache = None
        self._Twc_cache = None

    def __len__(self):
        """Batch size of the camera intrinsics"""
        return len(self.K)

    def to(self, *args, **kwargs):
        """
        Moves intrinsics and pose to a specific device / dtype (in place).

        Because the ``Kinv`` / ``Twc`` memo is keyed on the identity of
        ``K`` / ``Tcw``, rebinding them here makes the next access recompute
        the derived values from the moved tensors.

        Returns
        -------
        camera : Camera
            This camera, for chaining
        """
        self.K = self.K.to(*args, **kwargs)
        self.Tcw = self.Tcw.to(*args, **kwargs)
        return self

########################################################################################################################

    @property
    def fx(self):
        """Focal length in x"""
        return self.K[:, 0, 0]

    @property
    def fy(self):
        """Focal length in y"""
        return self.K[:, 1, 1]

    @property
    def cx(self):
        """Principal point in x"""
        return self.K[:, 0, 2]

    @property
    def cy(self):
        """Principal point in y"""
        return self.K[:, 1, 2]

    def _cached(self, slot, source, compute):
        """Returns compute(), memoized in attribute `slot` while `source` is the same object."""
        # getattr default keeps this working for instances built without __init__
        entry = getattr(self, slot, None)
        if entry is None or entry[0] is not source:
            entry = (source, compute())
            setattr(self, slot, entry)
        return entry[1]

    @property
    def Twc(self):
        """Inverse of Tcw: maps camera-frame points to the world frame (see ``reconstruct``)"""
        return self._cached('_Twc_cache', self.Tcw, self.Tcw.inverse)

    @property
    def Kinv(self):
        """Inverse intrinsics (for lifting pixels to rays)"""
        return self._cached('_Kinv_cache', self.K, self._invert_intrinsics)

    def _invert_intrinsics(self):
        """Closed-form inverse of the current K (pinhole, zero skew)."""
        # Entries other than fx, fy, cx, cy are copied from K unchanged, so the
        # result is only exact when K[:, 0, 1] (skew) is 0 and the last row is [0, 0, 1].
        Kinv = self.K.clone()
        Kinv[:, 0, 0] = 1. / self.fx
        Kinv[:, 1, 1] = 1. / self.fy
        Kinv[:, 0, 2] = -1. * self.cx / self.fx
        Kinv[:, 1, 2] = -1. * self.cy / self.fy
        return Kinv

########################################################################################################################

    def scaled(self, x_scale, y_scale=None):
        """
        Returns a scaled version of the camera (changing intrinsics)

        Parameters
        ----------
        x_scale : float
            Resize scale in x
        y_scale : float
            Resize scale in y. If None, use the same as x_scale

        Returns
        -------
        camera : Camera
            Scaled version of the current camera (``self`` if both scales are 1)
        """
        # If single value is provided, use for both dimensions
        if y_scale is None:
            y_scale = x_scale
        # If no scaling is necessary, return same camera
        if x_scale == 1. and y_scale == 1.:
            return self
        # Scale intrinsics and return new camera with same Pose
        K = scale_intrinsics(self.K.clone(), x_scale, y_scale)
        return Camera(K, Tcw=self.Tcw)

########################################################################################################################

    def reconstruct(self, depth, frame='w'):
        """
        Reconstructs pixel-wise 3D points from a depth map.

        Parameters
        ----------
        depth : torch.Tensor [B,1,H,W]
            Depth map for the camera
        frame : str
            Reference frame: 'c' for camera and 'w' for world

        Returns
        -------
        points : torch.tensor [B,3,H,W]
            Pixel-wise 3D points

        Raises
        ------
        ValueError
            If ``frame`` is neither 'c' nor 'w'
        """
        B, C, H, W = depth.shape
        assert C == 1, 'Depth map must have 1 channel, got {}'.format(C)

        # Create flat index grid
        grid = image_grid(B, H, W, depth.dtype, depth.device, normalized=False)  # [B,3,H,W]
        flat_grid = grid.view(B, 3, -1)  # [B,3,HW]

        # Estimate the outward rays in the camera frame
        xnorm = (self.Kinv.bmm(flat_grid)).view(B, 3, H, W)
        # Scale rays to metric depth
        Xc = xnorm * depth

        # If in camera frame of reference
        if frame == 'c':
            return Xc
        # If in world frame of reference
        elif frame == 'w':
            return self.Twc @ Xc
        # If none of the above
        else:
            raise ValueError('Unknown reference frame {}'.format(frame))

    def project(self, X, frame='w'):
        """
        Projects 3D points onto the image plane

        Parameters
        ----------
        X : torch.Tensor [B,3,H,W]
            3D points to be projected
        frame : str
            Reference frame: 'c' for camera and 'w' for world

        Returns
        -------
        points : torch.Tensor [B,H,W,2]
            Projected points in normalized image coordinates (x, y): pixel 0
            maps to -1 and pixel W-1 (resp. H-1) maps to +1. Points that land
            outside the image are NOT masked; their coordinates fall outside
            [-1, 1]. Depth is clamped to at least 1e-5 before the division.

        Raises
        ------
        ValueError
            If ``frame`` is neither 'c' nor 'w'
        """
        B, C, H, W = X.shape
        assert C == 3, 'Points must have 3 channels, got {}'.format(C)

        # Project 3D points onto the camera image plane
        # (reshape instead of view so non-contiguous inputs are accepted)
        if frame == 'c':
            Xc = self.K.bmm(X.reshape(B, 3, -1))
        elif frame == 'w':
            Xc = self.K.bmm((self.Tcw @ X).reshape(B, 3, -1))
        else:
            raise ValueError('Unknown reference frame {}'.format(frame))

        # Perspective division, then map pixel coordinates to [-1, 1]
        x = Xc[:, 0]
        y = Xc[:, 1]
        z = Xc[:, 2].clamp(min=1e-5)  # avoid division by zero / negative depth
        Xnorm = 2 * (x / z) / (W - 1) - 1.
        Ynorm = 2 * (y / z) / (H - 1) - 1.

        # Return pixel coordinates
        return torch.stack([Xnorm, Ynorm], dim=-1).view(B, H, W, 2)