# Copyright 2020 Toyota Research Institute. All rights reserved.
from collections import OrderedDict
import os
import time
import random
import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader
from packnet_sfm.datasets import KITTIDataset, DGPDataset, ImageDataset
from packnet_sfm.datasets.transforms import get_transforms
from packnet_sfm.utils.depth import inv2depth, post_process_inv_depth, compute_depth_metrics
from packnet_sfm.utils.horovod import print0, world_size, rank, on_rank_0
from packnet_sfm.utils.image import flip_lr
from packnet_sfm.utils.load import load_class, load_class_args_create, \
load_network, filter_args
from packnet_sfm.utils.logging import pcolor
from packnet_sfm.utils.reduce import all_reduce_metrics, reduce_dict, \
create_dict, average_loss_and_metrics
from packnet_sfm.utils.save import save_depth
from packnet_sfm.models.model_utils import stack_batch
class ModelWrapper(torch.nn.Module):
"""
Top-level torch.nn.Module wrapper around a SfmModel (pose+depth networks).
Designed to use models with high-level Trainer classes (cf. trainers/).
Parameters
----------
config : CfgNode
Model configuration (cf. configs/default_config.py)
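    Example
    -------
    Illustrative sketch only (``cfg`` stands for a configuration parsed from
    configs/default_config.py; it is not created in this module)::

        wrapper = ModelWrapper(cfg)
        loader = wrapper.train_dataloader()
        for batch in loader:
            output = wrapper.training_step(batch)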
"""
def __init__(self, config, resume=None, logger=None, load_datasets=True):
super().__init__()
# Store configuration, checkpoint and logger
self.config = config
self.logger = logger
self.resume = resume
# Set random seed
set_random_seed(config.arch.seed)
# Task metrics
self.metrics_name = 'depth'
self.metrics_keys = ('abs_rel', 'sqr_rel', 'rmse', 'rmse_log', 'a1', 'a2', 'a3')
self.metrics_modes = ('', '_pp', '_gt', '_pp_gt')
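        # Metric variants (cf. evaluate_depth): '' = raw prediction, '_pp' = flip
        # post-processed, '_gt' = ground-truth-scaled, '_pp_gt' = both combined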
# Model, optimizers, schedulers and datasets are None for now
self.model = self.optimizer = self.scheduler = None
self.train_dataset = self.validation_dataset = self.test_dataset = None
self.current_epoch = 0
# Prepare model
self.prepare_model(resume)
# Prepare datasets
if load_datasets:
# Requirements for validation (we only evaluate depth for now)
validation_requirements = {'gt_depth': True, 'gt_pose': False}
test_requirements = validation_requirements
self.prepare_datasets(validation_requirements, test_requirements)
# Preparations done
self.config.prepared = True
    def prepare_model(self, resume=None):
"""Prepare self.model (incl. loading previous state)"""
print0(pcolor('### Preparing Model', 'green'))
self.model = setup_model(self.config.model, self.config.prepared)
# Resume model if available
if resume:
print0(pcolor('### Resuming from {}'.format(
resume['file']), 'magenta', attrs=['bold']))
self.model = load_network(
self.model, resume['state_dict'], 'model')
if 'epoch' in resume:
self.current_epoch = resume['epoch']
    def prepare_datasets(self, validation_requirements, test_requirements):
"""Prepare datasets for training, validation and test."""
# Prepare datasets
print0(pcolor('### Preparing Datasets', 'green'))
augmentation = self.config.datasets.augmentation
# Setup train dataset (requirements are given by the model itself)
self.train_dataset = setup_dataset(
self.config.datasets.train, 'train',
self.model.train_requirements, **augmentation)
# Setup validation dataset
self.validation_dataset = setup_dataset(
self.config.datasets.validation, 'validation',
validation_requirements, **augmentation)
# Setup test dataset
self.test_dataset = setup_dataset(
self.config.datasets.test, 'test',
test_requirements, **augmentation)
@property
def depth_net(self):
"""Returns depth network."""
return self.model.depth_net
@property
def pose_net(self):
"""Returns pose network."""
return self.model.pose_net
@property
def logs(self):
"""Returns various logs for tracking."""
params = OrderedDict()
for param in self.optimizer.param_groups:
params['{}_learning_rate'.format(param['name'].lower())] = param['lr']
params['progress'] = self.progress
return {
**params,
**self.model.logs,
}
@property
def progress(self):
"""Returns training progress (current epoch / max. number of epochs)"""
return self.current_epoch / self.config.arch.max_epochs
    def train_dataloader(self):
"""Prepare training dataloader."""
return setup_dataloader(self.train_dataset,
self.config.datasets.train, 'train')[0]
    def val_dataloader(self):
"""Prepare validation dataloader."""
return setup_dataloader(self.validation_dataset,
self.config.datasets.validation, 'validation')
    def test_dataloader(self):
"""Prepare test dataloader."""
return setup_dataloader(self.test_dataset,
self.config.datasets.test, 'test')
    def training_step(self, batch, *args):
"""Processes a training batch."""
batch = stack_batch(batch)
output = self.model(batch, progress=self.progress)
return {
'loss': output['loss'],
'metrics': output['metrics']
}
    def validation_step(self, batch, *args):
"""Processes a validation batch."""
output = self.evaluate_depth(batch)
if self.logger:
self.logger.log_depth('val', batch, output, args,
self.validation_dataset, world_size(),
self.config.datasets.validation)
return {
'idx': batch['idx'],
**output['metrics'],
}
    def test_step(self, batch, *args):
"""Processes a test batch."""
output = self.evaluate_depth(batch)
save_depth(batch, output, args,
self.config.datasets.test,
self.config.save)
return {
'idx': batch['idx'],
**output['metrics'],
}
    def training_epoch_end(self, output_batch):
"""Finishes a training epoch."""
# Calculate and reduce average loss and metrics per GPU
loss_and_metrics = average_loss_and_metrics(output_batch, 'avg_train')
loss_and_metrics = reduce_dict(loss_and_metrics, to_item=True)
# Log to wandb
if self.logger:
self.logger.log_metrics({
**self.logs, **loss_and_metrics,
})
return {
**loss_and_metrics
}
    def validation_epoch_end(self, output_data_batch):
"""Finishes a validation epoch."""
# Reduce depth metrics
metrics_data = all_reduce_metrics(
output_data_batch, self.validation_dataset, self.metrics_name)
# Create depth dictionary
metrics_dict = create_dict(
metrics_data, self.metrics_keys, self.metrics_modes,
self.config.datasets.validation)
# Print stuff
self.print_metrics(metrics_data, self.config.datasets.validation)
# Log to wandb
if self.logger:
self.logger.log_metrics({
**metrics_dict, 'global_step': self.current_epoch + 1,
})
return {
**metrics_dict
}
    def test_epoch_end(self, output_data_batch):
"""Finishes a test epoch."""
# Reduce depth metrics
metrics_data = all_reduce_metrics(
output_data_batch, self.test_dataset, self.metrics_name)
# Create depth dictionary
metrics_dict = create_dict(
metrics_data, self.metrics_keys, self.metrics_modes,
self.config.datasets.test)
# Print stuff
self.print_metrics(metrics_data, self.config.datasets.test)
return {
**metrics_dict
}
    def forward(self, *args, **kwargs):
"""Runs the model and returns the output."""
assert self.model is not None, 'Model not defined'
return self.model(*args, **kwargs)
    def depth(self, *args, **kwargs):
        """Runs the depth network and returns the output."""
assert self.depth_net is not None, 'Depth network not defined'
return self.depth_net(*args, **kwargs)
    def pose(self, *args, **kwargs):
        """Runs the pose network and returns the output."""
assert self.pose_net is not None, 'Pose network not defined'
return self.pose_net(*args, **kwargs)
    def evaluate_depth(self, batch):
"""Evaluate batch to produce depth metrics."""
# Get predicted depth
inv_depths = self.model(batch)['inv_depths']
depth = inv2depth(inv_depths[0])
# Post-process predicted depth
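        # (run the model again on a horizontally flipped input and fuse both
        # predictions; the input is flipped back afterwards, leaving the batch unchanged)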
batch['rgb'] = flip_lr(batch['rgb'])
inv_depths_flipped = self.model(batch)['inv_depths']
inv_depth_pp = post_process_inv_depth(
inv_depths[0], inv_depths_flipped[0], method='mean')
depth_pp = inv2depth(inv_depth_pp)
batch['rgb'] = flip_lr(batch['rgb'])
# Calculate predicted metrics
metrics = OrderedDict()
if 'depth' in batch:
for mode in self.metrics_modes:
metrics[self.metrics_name + mode] = compute_depth_metrics(
self.config.model.params, gt=batch['depth'],
pred=depth_pp if 'pp' in mode else depth,
use_gt_scale='gt' in mode)
# Return metrics and extra information
return {
'metrics': metrics,
'inv_depth': inv_depth_pp
}
@on_rank_0
def print_metrics(self, metrics_data, dataset):
"""Print depth metrics on rank 0 if available"""
if not metrics_data[0]:
return
hor_line = '|{:<}|'.format('*' * 93)
met_line = '| {:^14} | {:^8} | {:^8} | {:^8} | {:^8} | {:^8} | {:^8} | {:^8} |'
num_line = '{:<14} | {:^8.3f} | {:^8.3f} | {:^8.3f} | {:^8.3f} | {:^8.3f} | {:^8.3f} | {:^8.3f}'
def wrap(string):
return '| {} |'.format(string)
print()
print()
print()
print(hor_line)
if self.optimizer is not None:
bs = 'E: {} BS: {}'.format(self.current_epoch + 1,
self.config.datasets.train.batch_size)
if self.model is not None:
bs += ' - {}'.format(self.config.model.name)
lr = 'LR ({}):'.format(self.config.model.optimizer.name)
for param in self.optimizer.param_groups:
lr += ' {} {:.2e}'.format(param['name'], param['lr'])
par_line = wrap(pcolor('{:<40}{:>51}'.format(bs, lr),
'green', attrs=['bold', 'dark']))
print(par_line)
print(hor_line)
print(met_line.format(*(('METRIC',) + self.metrics_keys)))
for n, metrics in enumerate(metrics_data):
print(hor_line)
path_line = '{}'.format(
os.path.join(dataset.path[n], dataset.split[n]))
if len(dataset.cameras[n]) == 1: # only allows single cameras
path_line += ' ({})'.format(dataset.cameras[n][0])
print(wrap(pcolor('*** {:<87}'.format(path_line), 'magenta', attrs=['bold'])))
print(hor_line)
for key, metric in metrics.items():
if self.metrics_name in key:
print(wrap(pcolor(num_line.format(
*((key.upper(),) + tuple(metric.tolist()))), 'cyan')))
print(hor_line)
if self.logger:
run_line = wrap(pcolor('{:<60}{:>31}'.format(
self.config.wandb.url, self.config.wandb.name), 'yellow', attrs=['dark']))
print(run_line)
print(hor_line)
print()
def set_random_seed(seed):
    """Seed python, numpy and torch RNGs for reproducibility (no-op if seed < 0)."""
if seed >= 0:
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def setup_depth_net(config, prepared, **kwargs):
"""
Create a depth network
Parameters
----------
config : CfgNode
Network configuration
prepared : bool
True if the network has been prepared before
kwargs : dict
Extra parameters for the network
Returns
-------
depth_net : nn.Module
        Created depth network
"""
print0(pcolor('DepthNet: %s' % config.name, 'yellow'))
depth_net = load_class_args_create(config.name,
paths=['packnet_sfm.networks.depth',],
args={**config, **kwargs},
)
    if not prepared and config.checkpoint_path != '':
depth_net = load_network(depth_net, config.checkpoint_path,
['depth_net', 'disp_network'])
return depth_net
def setup_pose_net(config, prepared, **kwargs):
"""
Create a pose network
Parameters
----------
config : CfgNode
Network configuration
prepared : bool
True if the network has been prepared before
kwargs : dict
Extra parameters for the network
Returns
-------
pose_net : nn.Module
Created pose network
"""
print0(pcolor('PoseNet: %s' % config.name, 'yellow'))
pose_net = load_class_args_create(config.name,
paths=['packnet_sfm.networks.pose',],
args={**config, **kwargs},
)
    if not prepared and config.checkpoint_path != '':
pose_net = load_network(pose_net, config.checkpoint_path,
['pose_net', 'pose_network'])
return pose_net
def setup_model(config, prepared, **kwargs):
"""
Create a model
Parameters
----------
config : CfgNode
Model configuration (cf. configs/default_config.py)
prepared : bool
True if the model has been prepared before
kwargs : dict
Extra parameters for the model
Returns
-------
model : nn.Module
Created model
"""
print0(pcolor('Model: %s' % config.name, 'yellow'))
model = load_class(config.name, paths=['packnet_sfm.models',])(
**{**config.loss, **kwargs})
# Add depth network if required
if model.network_requirements['depth_net']:
model.add_depth_net(setup_depth_net(config.depth_net, prepared))
# Add pose network if required
if model.network_requirements['pose_net']:
model.add_pose_net(setup_pose_net(config.pose_net, prepared))
# If a checkpoint is provided, load pretrained model
    if not prepared and config.checkpoint_path != '':
model = load_network(model, config.checkpoint_path, 'model')
# Return model
return model
def setup_dataset(config, mode, requirements, **kwargs):
"""
Create a dataset class
Parameters
----------
config : CfgNode
Configuration (cf. configs/default_config.py)
mode : str {'train', 'validation', 'test'}
Mode from which we want the dataset
requirements : dict (string -> bool)
Different requirements for dataset loading (gt_depth, gt_pose, etc)
kwargs : dict
Extra parameters for dataset creation
Returns
-------
dataset : Dataset
Dataset class for that mode
"""
# If no dataset is given, return None
if len(config.path) == 0:
return None
print0(pcolor('###### Setup %s datasets' % mode, 'red'))
# Global shared dataset arguments
dataset_args = {
'back_context': config.back_context,
'forward_context': config.forward_context,
'data_transform': get_transforms(mode, **kwargs)
}
# Loop over all datasets
datasets = []
for i in range(len(config.split)):
path_split = os.path.join(config.path[i], config.split[i])
# Individual shared dataset arguments
dataset_args_i = {
'depth_type': config.depth_type[i] if requirements['gt_depth'] else None,
'with_pose': requirements['gt_pose'],
}
# KITTI dataset
if config.dataset[i] == 'KITTI':
dataset = KITTIDataset(
config.path[i], path_split,
**dataset_args, **dataset_args_i,
)
# DGP dataset
elif config.dataset[i] == 'DGP':
dataset = DGPDataset(
config.path[i], config.split[i],
**dataset_args, **dataset_args_i,
cameras=config.cameras[i],
)
# Image dataset
elif config.dataset[i] == 'Image':
dataset = ImageDataset(
config.path[i], config.split[i],
**dataset_args, **dataset_args_i,
)
else:
            raise ValueError('Unknown dataset %s' % config.dataset[i])
# Repeat if needed
if 'repeat' in config and config.repeat[i] > 1:
dataset = ConcatDataset([dataset for _ in range(config.repeat[i])])
datasets.append(dataset)
# Display dataset information
bar = '######### {:>7}'.format(len(dataset))
if 'repeat' in config:
bar += ' (x{})'.format(config.repeat[i])
bar += ': {:<}'.format(path_split)
print0(pcolor(bar, 'yellow'))
# If training, concatenate all datasets into a single one
if mode == 'train':
datasets = [ConcatDataset(datasets)]
return datasets
def worker_init_fn(worker_id):
    """Seed each dataloader worker's numpy RNG so random augmentations differ across workers."""
time_seed = np.array(time.time(), dtype=np.int32)
np.random.seed(time_seed + worker_id)
def get_datasampler(dataset, mode):
"""Distributed data sampler"""
return torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=(mode=='train'),
num_replicas=world_size(), rank=rank())
def setup_dataloader(datasets, config, mode):
"""
Create a dataloader class
Parameters
----------
datasets : list of Dataset
List of datasets from which to create dataloaders
config : CfgNode
Model configuration (cf. configs/default_config.py)
mode : str {'train', 'validation', 'test'}
Mode from which we want the dataloader
Returns
-------
dataloaders : list of Dataloader
List of created dataloaders for each input dataset
"""
return [(DataLoader(dataset,
batch_size=config.batch_size, shuffle=False,
pin_memory=True, num_workers=config.num_workers,
worker_init_fn=worker_init_fn,
sampler=get_datasampler(dataset, mode))
) for dataset in datasets]
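
# Illustrative sketch only: how the helpers above are typically chained together.
# `cfg` stands for a configuration parsed from configs/default_config.py and is
# not created in this module:
#
#   model = setup_model(cfg.model, cfg.prepared)
#   datasets = setup_dataset(cfg.datasets.validation, 'validation',
#                            {'gt_depth': True, 'gt_pose': False})
#   loaders = setup_dataloader(datasets, cfg.datasets.validation, 'validation')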