# Copyright 2020 Toyota Research Institute. All rights reserved.
from functools import lru_cache
import torch
import torch.nn as nn
from packnet_sfm.geometry.pose import Pose
from packnet_sfm.geometry.camera_utils import scale_intrinsics
from packnet_sfm.utils.image import image_grid
########################################################################################################################
class Camera(nn.Module):
    """
    Differentiable camera class implementing reconstruction and projection
    functions for a pinhole model.
    """
    def __init__(self, K, Tcw=None):
        """
        Initializes the Camera class

        Parameters
        ----------
        K : torch.Tensor [B,3,3]
            Batched camera intrinsics
        Tcw : Pose
            Pose applied to world-frame points to express them in the camera
            frame (see `project` with frame='w'). If None, an identity pose
            with batch size len(K) is used.
        """
        super().__init__()
        self.K = K
        self.Tcw = Pose.identity(len(K)) if Tcw is None else Tcw

    def __len__(self):
        """Batch size of the camera intrinsics"""
        return len(self.K)

    def to(self, *args, **kwargs):
        """Moves intrinsics and pose to a specific device/dtype (in place) and returns self"""
        self.K = self.K.to(*args, **kwargs)
        self.Tcw = self.Tcw.to(*args, **kwargs)
        return self

########################################################################################################################

    @property
    def fx(self):
        """Focal length in x"""
        return self.K[:, 0, 0]

    @property
    def fy(self):
        """Focal length in y"""
        return self.K[:, 1, 1]

    @property
    def cx(self):
        """Principal point in x"""
        return self.K[:, 0, 2]

    @property
    def cy(self):
        """Principal point in y"""
        return self.K[:, 1, 2]

    @property
    def Twc(self):
        """
        Inverse of Tcw: maps camera-frame points to the world frame
        (used by `reconstruct` with frame='w').

        Recomputed on every access. A method-level ``lru_cache`` keyed on
        ``self`` would return stale results after ``to()`` or after
        reassigning ``Tcw``, and would keep every camera alive.
        """
        return self.Tcw.inverse()

    @property
    def Kinv(self):
        """
        Inverse intrinsics (for lifting pixels to rays), computed in closed
        form from fx, fy, cx, cy.

        Only the focal/principal-point entries are rewritten; all other
        entries (skew K[:, 0, 1] and the last row) are copied from K, so this
        assumes a zero-skew pinhole K with last row [0, 0, 1].

        Recomputed on every access so it always matches the current ``K``
        (values, device and dtype) instead of a stale cached copy.
        """
        Kinv = self.K.clone()
        Kinv[:, 0, 0] = 1. / self.fx
        Kinv[:, 1, 1] = 1. / self.fy
        Kinv[:, 0, 2] = -1. * self.cx / self.fx
        Kinv[:, 1, 2] = -1. * self.cy / self.fy
        return Kinv

########################################################################################################################

    def scaled(self, x_scale, y_scale=None):
        """
        Returns a scaled version of the camera (changing intrinsics)

        Parameters
        ----------
        x_scale : float
            Resize scale in x
        y_scale : float
            Resize scale in y. If None, use the same as x_scale

        Returns
        -------
        camera : Camera
            Scaled version of the current camera. If both scales are 1, the
            same object is returned (not a copy).
        """
        # If single value is provided, use for both dimensions
        if y_scale is None:
            y_scale = x_scale
        # If no scaling is necessary, return same camera
        if x_scale == 1. and y_scale == 1.:
            return self
        # Scale intrinsics and return new camera sharing the same pose object
        K = scale_intrinsics(self.K.clone(), x_scale, y_scale)
        return Camera(K, Tcw=self.Tcw)

########################################################################################################################

    def reconstruct(self, depth, frame='w'):
        """
        Reconstructs pixel-wise 3D points from a depth map.

        Parameters
        ----------
        depth : torch.Tensor [B,1,H,W]
            Depth map for the camera
        frame : str
            Reference frame: 'c' for camera and 'w' for world

        Returns
        -------
        points : torch.Tensor [B,3,H,W]
            Pixel-wise 3D points

        Raises
        ------
        ValueError
            If frame is neither 'c' nor 'w'
        """
        B, C, H, W = depth.shape
        assert C == 1
        # Create flat index grid of unnormalized pixel coordinates
        grid = image_grid(B, H, W, depth.dtype, depth.device, normalized=False)  # [B,3,H,W]
        flat_grid = grid.view(B, 3, -1)  # [B,3,HW]
        # Estimate the outward rays in the camera frame
        xnorm = (self.Kinv.bmm(flat_grid)).view(B, 3, H, W)
        # Scale rays to metric depth (depth broadcasts over the 3 channels)
        Xc = xnorm * depth
        # If in camera frame of reference
        if frame == 'c':
            return Xc
        # If in world frame of reference
        elif frame == 'w':
            return self.Twc @ Xc
        # If none of the above
        else:
            raise ValueError('Unknown reference frame {}'.format(frame))

    def project(self, X, frame='w'):
        """
        Projects 3D points onto the image plane

        Parameters
        ----------
        X : torch.Tensor [B,3,H,W]
            3D points to be projected (need not be contiguous)
        frame : str
            Reference frame: 'c' for camera and 'w' for world

        Returns
        -------
        points : torch.Tensor [B,H,W,2]
            Projected (x, y) pixel coordinates normalized so that the image
            spans [-1, 1] (pixel 0 -> -1, pixel W-1 / H-1 -> 1). Points
            outside the image are NOT masked and fall outside [-1, 1].
            Presumably intended for grid_sample with align_corners=True.

        Raises
        ------
        ValueError
            If frame is neither 'c' nor 'w'
        """
        B, C, H, W = X.shape
        assert C == 3
        # Project 3D points onto the camera image plane.
        # reshape (not view) so non-contiguous inputs are accepted.
        if frame == 'c':
            Xc = self.K.bmm(X.reshape(B, 3, -1))
        elif frame == 'w':
            Xc = self.K.bmm((self.Tcw @ X).reshape(B, 3, -1))
        else:
            raise ValueError('Unknown reference frame {}'.format(frame))
        # Normalize points; clamp depth to a small positive value to avoid
        # division by zero
        X = Xc[:, 0]
        Y = Xc[:, 1]
        Z = Xc[:, 2].clamp(min=1e-5)
        # NOTE: W == 1 or H == 1 divides by zero below (yields inf/nan)
        Xnorm = 2 * (X / Z) / (W - 1) - 1.
        Ynorm = 2 * (Y / Z) / (H - 1) - 1.
        # Return normalized pixel coordinates
        return torch.stack([Xnorm, Ynorm], dim=-1).view(B, H, W, 2)