# Copyright 2020 Toyota Research Institute. All rights reserved.
import numpy as np
import torch.nn as nn
from packnet_sfm.utils.types import is_list
########################################################################################################################
class ProgressiveScaling:
    """
    Helper class to manage progressive scaling.

    The number of scales starts at ``num_scales`` and is decreased by 1 each time the
    training progress crosses another multiple of ``progressive_scaling``. It never drops
    below 1.

    Parameters
    ----------
    progressive_scaling : float
        Training progress fraction between successive decreases of the number of scales.
        Values <= 0 disable progressive scaling (the number of scales stays constant).
    num_scales : int
        Initial number of scales
    """
    def __init__(self, progressive_scaling, num_scales=4):
        self.num_scales = num_scales
        # Use it only if bigger than zero (build an array of progress thresholds)
        if progressive_scaling > 0.0:
            # Thresholds p, 2p, ..., (num_scales - 1) * p, followed by a final 1.0 sentinel.
            # Each one is clipped to 1.0 so the array stays sorted, which np.searchsorted
            # requires. Thresholds above 1.0 could never be crossed for progress <= 1 anyway.
            thresholds = [min(progressive_scaling * (i + 1), 1.0)
                          for i in range(num_scales - 1)]
            self.progressive_scaling = np.float32(thresholds + [1.0])
        # Otherwise, disable it
        else:
            self.progressive_scaling = progressive_scaling

    def __call__(self, progress):
        """
        Call for an update in the number of scales

        Parameters
        ----------
        progress : float
            Training progress percentage

        Returns
        -------
        num_scales : int
            New number of scales (``num_scales`` minus the number of thresholds strictly
            below ``progress``, and at least 1). When progressive scaling is disabled this
            is always the initial ``num_scales``.
        """
        # np.float32([...]) produces an np.ndarray, not a Python list/tuple, so test for
        # that type explicitly instead of relying on a list check.
        if not isinstance(self.progressive_scaling, np.ndarray):
            return self.num_scales
        # searchsorted (side='left') on the sorted thresholds = how many are < progress
        crossed = int(np.searchsorted(self.progressive_scaling, progress))
        # Never go below a single scale, even if progress overshoots 1.0
        return int(max(1, self.num_scales - crossed))
########################################################################################################################
class LossBase(nn.Module):
    """
    Common parent class for loss modules.

    Keeps two dictionaries, ``logs`` and ``metrics``, which start out empty and are
    exposed through read-only properties.
    """
    def __init__(self):
        """Set up the module with empty logs and metrics dictionaries."""
        super().__init__()
        self._logs, self._metrics = {}, {}

########################################################################################################################

    @property
    def logs(self):
        """Logs dictionary (the internal object itself, not a copy)."""
        return self._logs

    @property
    def metrics(self):
        """Metrics dictionary (the internal object itself, not a copy)."""
        return self._metrics

    def add_metric(self, key, val):
        """
        Store a detached copy of ``val`` in the metrics dictionary under ``key``.

        ``val`` must provide ``.detach()`` (presumably a ``torch.Tensor``); an existing
        entry with the same key is overwritten.
        """
        detached_val = val.detach()
        self._metrics[key] = detached_val
########################################################################################################################