Source code for packnet_sfm.utils.load

# Copyright 2020 Toyota Research Institute.  All rights reserved.

import importlib.util
import logging
import os
import warnings
import torch

from inspect import signature
from collections import OrderedDict

from packnet_sfm.utils.misc import make_list, same_shape
from packnet_sfm.utils.logging import pcolor
from packnet_sfm.utils.horovod import print0
from packnet_sfm.utils.types import is_str


def set_debug(debug):
    """
    Enable or disable debug terminal logging

    Parameters
    ----------
    debug : bool
        Debugging flag (True to enable)
    """
    # Disable logging if requested
    if not debug:
        os.environ['NCCL_DEBUG'] = ''
        os.environ['WANDB_SILENT'] = 'false'
        warnings.filterwarnings("ignore")
        logging.disable(logging.CRITICAL)

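# Illustrative usage (not part of the original module): a hypothetical script could
# silence NCCL, wandb, warnings and standard logging unless debugging is requested:
#   set_debug(False)
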
def filter_args(func, keys):
    """
    Filters a dictionary so it only contains keys that are arguments of a function

    Parameters
    ----------
    func : Function
        Function for which we are filtering the dictionary
    keys : dict
        Dictionary with keys we are filtering

    Returns
    -------
    filtered : dict
        Dictionary containing only keys that are arguments of func
    """
    filtered = {}
    sign = list(signature(func).parameters.keys())
    for k, v in {**keys}.items():
        if k in sign:
            filtered[k] = v
    return filtered

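# Illustrative usage (not part of the original module; `example_fn` and `opts` are
# hypothetical names): keep only the keys that match the function's signature.
#   def example_fn(a, b=1):
#       return a + b
#   opts = {'a': 2, 'b': 3, 'c': 4}   # 'c' is not an argument of example_fn
#   filter_args(example_fn, opts)     # -> {'a': 2, 'b': 3}
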
def filter_args_create(func, keys):
    """
    Filters a dictionary so it only contains keys that are arguments of a function
    and creates a function with those arguments

    Parameters
    ----------
    func : Function
        Function for which we are filtering the dictionary
    keys : dict
        Dictionary with keys we are filtering

    Returns
    -------
    func : Function
        Function with filtered keys as arguments
    """
    return func(**filter_args(func, keys))

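# Illustrative usage (not part of the original module; `Resizer` and `params` are
# hypothetical names). Note that, despite the name, this returns the result of
# calling `func` with the filtered keyword arguments (here, a `Resizer` instance):
#   class Resizer:
#       def __init__(self, height, width):
#           self.shape = (height, width)
#   params = {'height': 192, 'width': 640, 'unused': True}
#   resizer = filter_args_create(Resizer, params)   # == Resizer(height=192, width=640)
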
def load_class(filename, paths, concat=True):
    """
    Look for a file in different locations and return its method with the same name
    Optionally, you can use concat to search in path.filename instead

    Parameters
    ----------
    filename : str
        Name of the file we are searching for
    paths : str or list of str
        Folders in which the file will be searched
    concat : bool
        Flag to concatenate filename to each path during the search

    Returns
    -------
    method : Function
        Loaded method
    """
    # for each path in paths
    for path in make_list(paths):
        # Create full path
        full_path = '{}.{}'.format(path, filename) if concat else path
        if importlib.util.find_spec(full_path):
            # Return method with same name as the file
            return getattr(importlib.import_module(full_path), filename)
    raise ValueError('Unknown class {}'.format(filename))

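# Illustrative usage (not part of the original module; the package/module names are
# hypothetical): with concat=True, a file `my_package/models/MyModel.py` defining a
# class `MyModel` could be loaded as
#   MyModel = load_class('MyModel', paths=['my_package.models'])
# which imports `my_package.models.MyModel` and returns its `MyModel` attribute.
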
def load_class_args_create(filename, paths, args={}, concat=True):
    """Loads a class (filename) and returns an instance with filtered arguments (args)"""
    class_type = load_class(filename, paths, concat)
    return filter_args_create(class_type, args)

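# Illustrative usage (not part of the original module; names are hypothetical):
# look up the class and instantiate it in one call, silently dropping any keys
# that its constructor does not accept:
#   model = load_class_args_create('MyModel', paths=['my_package.models'],
#                                  args={'num_layers': 18, 'ignored_key': 0})
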
def load_network(network, path, prefixes=''):
    """
    Loads a pretrained network

    Parameters
    ----------
    network : nn.Module
        Network that will receive the pretrained weights
    path : str
        File containing a 'state_dict' key with pretrained network weights
    prefixes : str or list of str
        Layer name prefixes to consider when loading the network

    Returns
    -------
    network : nn.Module
        Updated network with pretrained weights
    """
    prefixes = make_list(prefixes)
    # If path is a string
    if is_str(path):
        saved_state_dict = torch.load(path, map_location='cpu')['state_dict']
        if path.endswith('.pth.tar'):
            saved_state_dict = backwards_state_dict(saved_state_dict)
    # If state dict is already provided
    else:
        saved_state_dict = path
    # Get network state dict
    network_state_dict = network.state_dict()

    updated_state_dict = OrderedDict()
    n, n_total = 0, len(network_state_dict.keys())
    for key, val in saved_state_dict.items():
        for prefix in prefixes:
            prefix = prefix + '.'
            if prefix in key:
                idx = key.find(prefix) + len(prefix)
                key = key[idx:]
                if key in network_state_dict.keys() and \
                        same_shape(val.shape, network_state_dict[key].shape):
                    updated_state_dict[key] = val
                    n += 1

    network.load_state_dict(updated_state_dict, strict=False)
    base_color, attrs = 'cyan', ['bold', 'dark']
    color = 'green' if n == n_total else 'yellow' if n > 0 else 'red'
    print0(pcolor('###### Pretrained {} loaded:'.format(prefixes[0]), base_color, attrs=attrs) +
           pcolor(' {}/{} '.format(n, n_total), color, attrs=attrs) +
           pcolor('tensors', base_color, attrs=attrs))
    return network

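# Illustrative usage (not part of the original module; the checkpoint path and
# attribute names are hypothetical): load depth-network weights from a checkpoint
# containing a 'state_dict' entry, stripping the 'depth_net.' prefix from keys:
#   model.depth_net = load_network(model.depth_net,
#                                  '/path/to/checkpoint.ckpt', 'depth_net')
# A state dict that was already loaded can be passed directly instead of a path:
#   state_dict = torch.load('/path/to/checkpoint.ckpt', map_location='cpu')['state_dict']
#   model.depth_net = load_network(model.depth_net, state_dict, 'depth_net')
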
def backwards_state_dict(state_dict):
    """
    Modify the state dict of older models for backwards compatibility

    Parameters
    ----------
    state_dict : dict
        Model state dict with pretrained weights

    Returns
    -------
    state_dict : dict
        Updated model state dict with modified layer names
    """
    # List of layer names to change
    changes = (('model.model', 'model'),
               ('pose_network', 'pose_net'),
               ('disp_network', 'depth_net'))
    # Iterate over all keys and values
    updated_state_dict = OrderedDict()
    for key, val in state_dict.items():
        # Ad hoc changes due to version changes
        key = '{}.{}'.format('model', key)
        if 'disp_network' in key:
            key = key.replace('conv3.0.weight', 'conv3.weight')
            key = key.replace('conv3.0.bias', 'conv3.bias')
        # Change layer names
        for change in changes:
            key = key.replace('{}.'.format(change[0]), '{}.'.format(change[1]))
        updated_state_dict[key] = val
    # Return updated state dict
    return updated_state_dict

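# Worked example of the renaming above (the key name is only illustrative): an old
# checkpoint entry 'disp_network.conv3.0.weight' is first prefixed with 'model.',
# its 'conv3.0.*' suffix is renamed to 'conv3.*', and 'disp_network.' is replaced
# by 'depth_net.', producing 'model.depth_net.conv3.weight'.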