Source code for packnet_sfm.utils.reduce


import torch
import horovod.torch as hvd
import numpy as np
from collections import OrderedDict
from packnet_sfm.utils.logging import prepare_dataset_prefix

########################################################################################################################

def reduce_value(value):
    """Reduce the mean value of a tensor from all GPUs"""
    return hvd.allreduce(value, average=True, name='value')


def reduce_dict(dict, to_item=False):
    """Reduce the mean values of a dictionary from all GPUs"""
    for key, val in dict.items():
        dict[key] = reduce_value(dict[key])
        if to_item:
            dict[key] = dict[key].item()
    return dict
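
# Example (not part of the original module): a minimal sketch of how these
# reducers might be used at the end of a distributed training step, assuming
# hvd.init() has already been called and each worker holds its own scalar
# loss tensors:
#
#   losses = {'loss': loss.detach(), 'photometric_loss': photo_loss.detach()}
#   losses = reduce_dict(losses, to_item=True)  # mean over all GPUs, as floats
#
# The variable names above (`loss`, `photo_loss`) are illustrative only.
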
def all_reduce_metrics(output_data_batch, datasets, name='depth'):
    """Reduce metrics for all batches and all datasets using Horovod"""
    # If there is only one dataset, wrap in a list
    if isinstance(output_data_batch[0], dict):
        output_data_batch = [output_data_batch]
    # Get metrics keys and dimensions
    names = [key for key in list(output_data_batch[0][0].keys()) if key.startswith(name)]
    dims = [output_data_batch[0][0][name].shape[0] for name in names]
    # List storing metrics for all datasets
    all_metrics_dict = []
    # Loop over all datasets and all batches
    for output_batch, dataset in zip(output_data_batch, datasets):
        metrics_dict = OrderedDict()
        length = len(dataset)
        # Count how many times each sample was seen
        seen = torch.zeros(length)
        for output in output_batch:
            for i, idx in enumerate(output['idx']):
                seen[idx] += 1
        seen = hvd.allreduce(seen, average=False, name='idx')
        assert not np.any(seen.numpy() == 0), \
            'Not all samples were seen during evaluation'
        # Reduce all relevant metrics
        for name, dim in zip(names, dims):
            metrics = torch.zeros(length, dim)
            for output in output_batch:
                for i, idx in enumerate(output['idx']):
                    metrics[idx] = output[name]
            metrics = hvd.allreduce(metrics, average=False, name='depth_pp_gt')
            metrics_dict[name] = (metrics / seen.view(-1, 1)).mean(0)
        # Append metrics dictionary to the list
        all_metrics_dict.append(metrics_dict)
    # Return list of metrics dictionaries
    return all_metrics_dict
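
# Example (illustrative, not from the original source): `output_data_batch` is
# expected to be a list (one entry per dataset) of lists (one entry per batch)
# of dicts holding sample indices and metric tensors, while `datasets` provides
# the dataset lengths, e.g.
#
#   output_data_batch = [[
#       {'idx': torch.tensor([0, 1]), 'depth_gt': torch.zeros(7)},
#       {'idx': torch.tensor([2, 3]), 'depth_gt': torch.zeros(7)},
#   ]]
#   metrics = all_reduce_metrics(output_data_batch, datasets, name='depth')
#
# Here 'depth_gt' stands for any metric key starting with `name`, and the
# 7-dimensional tensors are placeholder metric vectors.
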
########################################################################################################################
def collate_metrics(output_data_batch, name='depth'):
    """Collate epoch output to produce average metrics."""
    # If there is only one dataset, wrap in a list
    if isinstance(output_data_batch[0], dict):
        output_data_batch = [output_data_batch]
    # Calculate the mean of all metrics
    metrics_data = []
    # For all datasets
    for i, output_batch in enumerate(output_data_batch):
        metrics = OrderedDict()
        # For all keys (assume they are the same for all batches)
        for key, val in output_batch[0].items():
            if key.startswith(name):
                metrics[key] = torch.stack([output[key] for output in output_batch], 0)
                metrics[key] = torch.mean(metrics[key], 0)
        metrics_data.append(metrics)
    # Return metrics data
    return metrics_data
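
# Example (illustrative): collate_metrics stacks the per-batch metric tensors
# for each key starting with `name` and averages them over batches, assuming
# every batch reports the same metric keys:
#
#   output_data_batch = [
#       {'depth_gt': torch.tensor([0.1, 0.2])},
#       {'depth_gt': torch.tensor([0.3, 0.4])},
#   ]
#   collate_metrics(output_data_batch)  # -> [{'depth_gt': tensor([0.2, 0.3])}]
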
def create_dict(metrics_data, metrics_keys, metrics_modes, dataset, name='depth'):
    """Creates a dictionary from collated metrics."""
    # Create metrics dictionary
    metrics_dict = {}
    # For all datasets
    for n, metrics in enumerate(metrics_data):
        if metrics:  # If there are calculated metrics
            prefix = prepare_dataset_prefix(dataset, n)
            # For all keys
            for i, key in enumerate(metrics_keys):
                for mode in metrics_modes:
                    metrics_dict['{}-{}{}'.format(prefix, key, mode)] = \
                        metrics['{}{}'.format(name, mode)][i].item()
    # Return metrics dictionary
    return metrics_dict
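
# Example (illustrative): create_dict flattens the collated metrics into scalar
# entries keyed as '<dataset_prefix>-<metric_key><mode>'. For instance, with
# metrics_keys=('abs_rel', 'rmse') and metrics_modes=('', '_pp'), the result
# would contain entries such as
#
#   {'kitti-abs_rel': 0.12, 'kitti-abs_rel_pp': 0.11,
#    'kitti-rmse': 4.5, 'kitti-rmse_pp': 4.3}
#
# where 'kitti' stands for whatever prefix prepare_dataset_prefix() returns and
# the numbers are placeholders.
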
########################################################################################################################
def average_key(batch_list, key):
    """
    Average key in a list of batches

    Parameters
    ----------
    batch_list : list of dict
        List containing dictionaries with the same keys
    key : str
        Key to be averaged

    Returns
    -------
    average : float
        Average of the value contained in key for all batches
    """
    values = [batch[key] for batch in batch_list]
    return sum(values) / len(values)
def average_sub_key(batch_list, key, sub_key):
    """
    Average subkey in a dictionary in a list of batches

    Parameters
    ----------
    batch_list : list of dict
        List containing dictionaries with the same keys
    key : str
        Key to be averaged
    sub_key : str
        Sub key to be averaged (belonging to key)

    Returns
    -------
    average : float
        Average of the value contained in the sub_key of key for all batches
    """
    values = [batch[key][sub_key] for batch in batch_list]
    return sum(values) / len(values)
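
# Example (illustrative): both helpers compute simple arithmetic means over a
# list of per-batch dictionaries, e.g.
#
#   batch_list = [{'loss': 0.5, 'metrics': {'abs_rel': 0.10}},
#                 {'loss': 0.3, 'metrics': {'abs_rel': 0.12}}]
#   average_key(batch_list, 'loss')                     # -> 0.4
#   average_sub_key(batch_list, 'metrics', 'abs_rel')   # -> 0.11
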
def average_loss_and_metrics(batch_list, prefix):
    """
    Average loss and metrics values in a list of batches

    Parameters
    ----------
    batch_list : list of dict
        List containing dictionaries with the same keys
    prefix : str
        Prefix string for metrics logging

    Returns
    -------
    values : dict
        Dictionary containing a 'loss' float entry and a 'metrics' dict entry
    """
    values = OrderedDict()
    key = 'loss'
    values['{}-{}'.format(prefix, key)] = \
        average_key(batch_list, key)
    key = 'metrics'
    for sub_key in batch_list[0][key].keys():
        values['{}-{}'.format(prefix, sub_key)] = \
            average_sub_key(batch_list, key, sub_key)
    return values
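
# Example (illustrative): with the batch_list from the previous sketch and
# prefix='avg_train', average_loss_and_metrics would return
#
#   OrderedDict([('avg_train-loss', 0.4), ('avg_train-abs_rel', 0.11)])
#
# The prefix 'avg_train' is an arbitrary choice for this example.
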
########################################################################################################################