Source code for packnet_sfm.models.model_utils

# Copyright 2020 Toyota Research Institute.  All rights reserved.

from packnet_sfm.utils.types import is_tensor, is_list, is_numpy

[docs]def merge_outputs(*outputs): """ Merges model outputs for logging Parameters ---------- outputs : tuple of dict Outputs to be merged Returns ------- output : dict Dictionary with a "metrics" key containing a dictionary with various metrics and all other keys that are not "loss" (it is handled differently). """ ignore = ['loss'] # Keys to ignore combine = ['metrics'] # Keys to combine merge = {key: {} for key in combine} for output in outputs: # Iterate over all keys for key, val in output.items(): # Combine these keys if key in combine: for sub_key, sub_val in output[key].items(): assert sub_key not in merge[key].keys(), \ 'Combining duplicated key {} to {}'.format(sub_key, key) merge[key][sub_key] = sub_val # Ignore these keys elif key not in ignore: assert key not in merge.keys(), \ 'Adding duplicated key {}'.format(key) merge[key] = val return merge
[docs]def stack_batch(batch): """ Stack multi-camera batches (B,N,C,H,W becomes BN,C,H,W) Parameters ---------- batch : dict Batch Returns ------- batch : dict Stacked batch """ # If there is multi-camera information if len(batch['rgb'].shape) == 5: assert batch['rgb'].shape[0] == 1, 'Only batch size 1 is supported for multi-cameras' # Loop over all keys for key in batch.keys(): # If list, stack every item if is_list(batch[key]): if is_tensor(batch[key][0]) or is_numpy(batch[key][0]): batch[key] = [sample[0] for sample in batch[key]] # Else, stack single item else: batch[key] = batch[key][0] return batch