import torch
import horovod.torch as hvd
import numpy as np
from collections import OrderedDict
from packnet_sfm.utils.logging import prepare_dataset_prefix
########################################################################################################################
def reduce_value(value):
    """
    Average a value over all Horovod workers.

    Parameters
    ----------
    value : torch.Tensor
        Local value to be reduced

    Returns
    -------
    reduced : torch.Tensor
        Result of ``hvd.allreduce`` with ``average=True`` (mean over all workers)
    """
    # The same op name is reused on every call, so calls must run in the same order on every worker
    reduced = hvd.allreduce(value, name='value', average=True)
    return reduced


def reduce_dict(dict, to_item=False):
    """
    Average every value of a dictionary over all Horovod workers, in place.

    Parameters
    ----------
    dict : dict
        Dictionary whose values are reduced with ``reduce_value``; it is modified in place
    to_item : bool
        If True, each reduced value is converted to a Python number with ``.item()``

    Returns
    -------
    dict : dict
        The same dictionary object, now holding the reduced values
    """
    # Only existing keys are re-assigned, so iterating while writing is safe
    for entry in dict:
        reduced = reduce_value(dict[entry])
        dict[entry] = reduced.item() if to_item else reduced
    return dict


def all_reduce_metrics(output_data_batch, datasets, name='depth'):
    """
    Gather per-sample metrics from all Horovod workers and average them per dataset.

    Parameters
    ----------
    output_data_batch : list of dict, or list of list of dict
        Evaluation outputs. A flat list of dicts is treated as a single dataset.
        Each dict holds an ``'idx'`` entry (indices of the samples in that batch)
        plus one entry per metric.
    datasets : list
        Datasets matching ``output_data_batch`` (paired with ``zip``); only
        ``len(dataset)`` is used.
    name : str
        Only output keys starting with this prefix are reduced.

    Returns
    -------
    all_metrics_dict : list of OrderedDict
        One dictionary per dataset, mapping each metric key to a tensor of shape
        ``(dim,)``: the per-sample metrics averaged over the whole dataset.
    """
    # If there is only one dataset, wrap in a list
    if isinstance(output_data_batch[0], dict):
        output_data_batch = [output_data_batch]
    # Metric keys and sizes come from the first batch of the first dataset
    # (assumed identical for every batch and dataset)
    first_output = output_data_batch[0][0]
    keys = [key for key in first_output.keys() if key.startswith(name)]
    dims = [first_output[key].shape[0] for key in keys]
    # List storing metrics for all datasets
    all_metrics_dict = []
    for output_batch, dataset in zip(output_data_batch, datasets):
        metrics_dict = OrderedDict()
        length = len(dataset)
        # Count how many times each sample was seen, locally and then over all workers
        seen = torch.zeros(length)
        for output in output_batch:
            for idx in output['idx']:
                seen[idx] += 1
        seen = hvd.allreduce(seen, average=False, name='idx')
        assert not np.any(seen.numpy() == 0), \
            'Not all samples were seen during evaluation'
        # Reduce all relevant metrics
        for key, dim in zip(keys, dims):
            metrics = torch.zeros(length, dim)
            for output in output_batch:
                # Match the accumulator's dtype/device once per batch
                # (the output may live on another device)
                value = torch.as_tensor(output[key]).to(metrics)
                for idx in output['idx']:
                    # Accumulate rather than overwrite, so a sample visited several
                    # times on this worker matches its count in ``seen``
                    metrics[idx] += value
            # Sum over workers, then divide by the global visit count to average
            # any duplicated samples
            metrics = hvd.allreduce(metrics, average=False, name='depth_pp_gt')
            metrics_dict[key] = (metrics / seen.view(-1, 1)).mean(0)
        # Append metrics dictionary to the list
        all_metrics_dict.append(metrics_dict)
    # Return list of metrics dictionary
    return all_metrics_dict
########################################################################################################################
def collate_metrics(output_data_batch, name='depth'):
    """
    Average epoch outputs into one metrics dictionary per dataset.

    Parameters
    ----------
    output_data_batch : list of dict, or list of list of dict
        Batch outputs; a flat list of dicts is treated as a single dataset
    name : str
        Only keys starting with this prefix are averaged

    Returns
    -------
    metrics_data : list of OrderedDict
        For each dataset, every matching key mapped to the mean (over dim 0)
        of that key's values stacked across batches
    """
    # A flat list of dicts means a single dataset
    if isinstance(output_data_batch[0], dict):
        per_dataset = [output_data_batch]
    else:
        per_dataset = output_data_batch
    metrics_data = []
    for output_batch in per_dataset:
        averaged = OrderedDict()
        # Keys are read from the first batch and assumed present in all batches
        for key in output_batch[0]:
            if not key.startswith(name):
                continue
            stacked = torch.stack([output[key] for output in output_batch], 0)
            averaged[key] = stacked.mean(0)
        metrics_data.append(averaged)
    return metrics_data
def create_dict(metrics_data, metrics_keys, metrics_modes,
                dataset, name='depth'):
    """
    Flatten collated metrics into a single dictionary of Python numbers.

    Parameters
    ----------
    metrics_data : list of dict
        One dictionary per dataset; entry ``'{name}{mode}'`` is indexed by the
        position of each key in ``metrics_keys``
    metrics_keys : list of str
        Metric names, in the order they appear in each stored value
    metrics_modes : list of str
        Suffixes appended to ``name`` (lookup) and to each metric key (output)
    dataset : object
        Forwarded to ``prepare_dataset_prefix`` together with the dataset index
    name : str
        Prefix of the keys looked up in each metrics dictionary

    Returns
    -------
    metrics_dict : dict
        Maps ``'{prefix}-{key}{mode}'`` to the corresponding ``.item()`` value
    """
    metrics_dict = {}
    for n, metrics in enumerate(metrics_data):
        # Datasets without computed metrics contribute nothing
        if not metrics:
            continue
        prefix = prepare_dataset_prefix(dataset, n)
        for i, key in enumerate(metrics_keys):
            for mode in metrics_modes:
                stored = metrics['{}{}'.format(name, mode)]
                metrics_dict['{}-{}{}'.format(prefix, key, mode)] = stored[i].item()
    return metrics_dict
########################################################################################################################
def average_key(batch_list, key):
    """
    Arithmetic mean of one key over a list of batches.

    Parameters
    ----------
    batch_list : list of dict
        Dictionaries that all contain ``key``
    key : str
        Key whose values are averaged

    Returns
    -------
    average : float
        ``sum`` of the values divided by their count (an empty list raises
        ZeroDivisionError)
    """
    collected = [entry[key] for entry in batch_list]
    count = len(collected)
    return sum(collected) / count


def average_sub_key(batch_list, key, sub_key):
    """
    Arithmetic mean of a nested entry ``batch[key][sub_key]`` over a list of batches.

    Parameters
    ----------
    batch_list : list of dict
        Dictionaries that all contain ``key``
    key : str
        Outer key, pointing to a nested dictionary
    sub_key :
        Key inside ``batch[key]`` whose values are averaged

    Returns
    -------
    average : float
        ``sum`` of the nested values divided by their count
    """
    nested = [entry[key][sub_key] for entry in batch_list]
    count = len(nested)
    return sum(nested) / count


def average_loss_and_metrics(batch_list, prefix):
    """
    Average the loss and every metric over a list of batches.

    Parameters
    ----------
    batch_list : list of dict
        Dictionaries each holding a ``'loss'`` value and a ``'metrics'`` dictionary
    prefix : str
        Prepended (as ``'{prefix}-'``) to every output key

    Returns
    -------
    values : OrderedDict
        ``'{prefix}-loss'`` first, then ``'{prefix}-{metric}'`` for each metric
        key of the first batch, each mapped to its average over all batches
    """
    values = OrderedDict()
    values['{}-{}'.format(prefix, 'loss')] = average_key(batch_list, 'loss')
    # Metric names are taken from the first batch
    for metric_name in batch_list[0]['metrics']:
        values['{}-{}'.format(prefix, metric_name)] = \
            average_sub_key(batch_list, 'metrics', metric_name)
    return values
########################################################################################################################