# Copyright 2020 Toyota Research Institute. All rights reserved.
import torch
from packnet_sfm.models.SelfSupModel import SfmModel, SelfSupModel
from packnet_sfm.losses.supervised_loss import SupervisedLoss
from packnet_sfm.models.model_utils import merge_outputs
from packnet_sfm.utils.depth import depth2inv
class SemiSupModel(SelfSupModel):
    """
    Model that inherits a depth and pose networks, plus the self-supervised loss from
    SelfSupModel and includes a supervised loss for semi-supervision.

    Parameters
    ----------
    supervised_loss_weight : float
        Weight for the supervised loss, must be in the interval (0, 1].
        The self-supervised loss is weighted by (1 - supervised_loss_weight).
    kwargs : dict
        Extra parameters (forwarded to both SelfSupModel and SupervisedLoss)
    """
    def __init__(self, supervised_loss_weight=0.9, **kwargs):
        # Initializes SelfSupModel
        super().__init__(**kwargs)
        # If supervision weight is 0.0, use SelfSupModel directly
        assert 0. < supervised_loss_weight <= 1., "Model requires (0, 1] supervision"
        # Store weight and initializes supervised loss
        self.supervised_loss_weight = supervised_loss_weight
        self._supervised_loss = SupervisedLoss(**kwargs)

        # Pose network is only required if there is self-supervision
        self._network_requirements['pose_net'] = self.supervised_loss_weight < 1
        # GT depth is only required if there is supervision
        # (always the case here, given the assertion above)
        self._train_requirements['gt_depth'] = self.supervised_loss_weight > 0

    @property
    def logs(self):
        """Return logs (parent class logs merged with the supervised loss logs)."""
        return {
            **super().logs,
            **self._supervised_loss.logs
        }

    def supervised_loss(self, inv_depths, gt_inv_depths,
                        return_logs=False, progress=0.0):
        """
        Calculates the supervised loss.

        Parameters
        ----------
        inv_depths : torch.Tensor [B,1,H,W]
            Predicted inverse depth maps from the original image
        gt_inv_depths : torch.Tensor [B,1,H,W]
            Ground-truth inverse depth maps from the original image
        return_logs : bool
            True if logs are stored
        progress : float
            Training progress percentage

        Returns
        -------
        output : dict
            Dictionary containing a "loss" scalar and a "metrics" dictionary
        """
        return self._supervised_loss(
            inv_depths, gt_inv_depths,
            return_logs=return_logs, progress=progress)

    def forward(self, batch, return_logs=False, progress=0.0):
        """
        Processes a batch.

        Parameters
        ----------
        batch : dict
            Input batch (the training path reads the 'rgb' and 'depth' entries)
        return_logs : bool
            True if logs are stored
        progress : float
            Training progress percentage

        Returns
        -------
        output : dict
            Dictionary containing a "loss" scalar and different metrics and predictions
            for logging and downstream usage. Outside of training mode, the output of
            SfmModel.forward is returned as is.
        """
        if not self.training:
            # If not training, no need for self-supervised loss
            return SfmModel.forward(self, batch)
        else:
            if self.supervised_loss_weight == 1.:
                # If no self-supervision, no need to calculate loss
                self_sup_output = SfmModel.forward(self, batch)
                # Zero placeholder matching the type of the input images
                loss = torch.tensor([0.]).type_as(batch['rgb'])
            else:
                # Otherwise, calculate and weight self-supervised loss
                self_sup_output = SelfSupModel.forward(self, batch)
                loss = (1.0 - self.supervised_loss_weight) * self_sup_output['loss']
            # Calculate and weight supervised loss
            sup_output = self.supervised_loss(
                self_sup_output['inv_depths'], depth2inv(batch['depth']),
                return_logs=return_logs, progress=progress)
            loss += self.supervised_loss_weight * sup_output['loss']
            # Merge and return outputs
            return {
                'loss': loss,
                **merge_outputs(self_sup_output, sup_output),
            }