import os
import torch
from datetime import datetime
from yacs.config import CfgNode
from packnet_sfm.utils.logging import s3_url, prepare_dataset_prefix
from packnet_sfm.utils.horovod import on_rank_0
from packnet_sfm.utils.types import is_cfg, is_list
from packnet_sfm.utils.misc import make_list
from packnet_sfm.utils.load import load_class, backwards_state_dict


def prep_dataset(config):
    """
    Expand dataset configuration to match split length

    Parameters
    ----------
    config : CfgNode
        Dataset configuration

    Returns
    -------
    config : CfgNode
        Updated dataset configuration
    """
    # If there is no dataset, do nothing
    if len(config.path) == 0:
        return config
    # If cameras is not a double list, make it so
    if not config.cameras or not is_list(config.cameras[0]):
        config.cameras = [config.cameras]
    # Get maximum length and expand other arguments to the same length
    n = max(len(config.split), len(config.cameras), len(config.depth_type))
    config.dataset = make_list(config.dataset, n)
    config.path = make_list(config.path, n)
    config.split = make_list(config.split, n)
    config.depth_type = make_list(config.depth_type, n)
    config.cameras = make_list(config.cameras, n)
    if 'repeat' in config:
        config.repeat = make_list(config.repeat, n)
    # Return updated configuration
    return config
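

# Usage sketch for `prep_dataset` (hypothetical values, not from the repo):
# a single split is broadcast to match two camera lists, so every per-dataset
# list ends up with the same length.
#
#   cfg.dataset, cfg.path = ['KITTI'], ['/data/kitti']
#   cfg.split, cfg.depth_type = ['train'], ['velodyne']
#   cfg.cameras = [['camera_01'], ['camera_02']]
#   cfg = prep_dataset(cfg)
#   # cfg.split == ['train', 'train'], cfg.path == ['/data/kitti'] * 2, etc.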


def set_name(config):
    """
    Set run name based on available information

    Parameters
    ----------
    config : CfgNode
        Model configuration

    Returns
    -------
    name : str
        Updated run name
    """
    # If there is a name already, do nothing
    if config.name != '':
        return config.name
    else:
        # Create a name based on available information
        return '{}-{}-{}'.format(
            os.path.basename(config.default),
            os.path.splitext(os.path.basename(config.config))[0],
            datetime.now().strftime("%Y.%m.%d-%Hh%Mm%Ss"))
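

# Name format sketch (hypothetical inputs): with config.default set to
# 'configs/default_config' and config.config to 'configs/train_kitti.yaml',
# an empty config.name becomes something like
#   'default_config-train_kitti-2020.05.21-14h32m10s'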


def set_checkpoint(config):
    """
    Set checkpoint information

    Parameters
    ----------
    config : CfgNode
        Model configuration

    Returns
    -------
    config : CfgNode
        Updated model configuration
    """
    # If checkpoint is enabled
    if config.checkpoint.filepath != '':
        # Create proper monitor string
        config.checkpoint.monitor = os.path.join('{}-{}'.format(
            prepare_dataset_prefix(config.datasets.validation,
                                   config.checkpoint.monitor_index),
            config.checkpoint.monitor))
        # Join checkpoint folder with run name
        config.checkpoint.filepath = os.path.join(
            config.checkpoint.filepath, config.name,
            '{epoch:02d}_{%s:.3f}' % config.checkpoint.monitor)
        # Set s3 url
        if config.checkpoint.s3_path != '':
            config.checkpoint.s3_url = s3_url(config)
    else:
        # If not saving checkpoint, do not sync to s3
        config.checkpoint.s3_path = ''
    return config.checkpoint
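

# Resulting strings sketch (hypothetical values; the exact dataset prefix
# produced by `prepare_dataset_prefix` is an assumption): if the validation
# prefix resolves to 'KITTI-velodyne' and the monitored metric is 'abs_rel':
#   config.checkpoint.monitor  -> 'KITTI-velodyne-abs_rel'
#   config.checkpoint.filepath -> '<filepath>/<name>/{epoch:02d}_{KITTI-velodyne-abs_rel:.3f}'
# The braces are left unexpanded on purpose: the checkpoint callback fills in
# epoch and metric values when a file is actually saved.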


@on_rank_0
def prep_logger_and_checkpoint(model):
    """
    Use logger and checkpoint information to update configuration

    Parameters
    ----------
    model : nn.Module
        Module to update
    """
    # Change run name to be the wandb assigned name
    if model.logger and not model.config.wandb.dry_run:
        model.config.name = model.config.wandb.name = model.logger.run_name
        model.config.wandb.url = model.logger.run_url
        # If we are saving models we need to update the path
        if model.config.checkpoint.filepath != '':
            # Change checkpoint filepath
            filepath = model.config.checkpoint.filepath.split('/')
            filepath[-2] = model.config.name
            model.config.checkpoint.filepath = '/'.join(filepath)
            # Change callback dirpath
            dirpath = os.path.join(os.path.dirname(
                model.trainer.checkpoint.dirpath), model.config.name)
            model.trainer.checkpoint.dirpath = dirpath
            os.makedirs(dirpath, exist_ok=True)
            model.config.checkpoint.s3_url = s3_url(model.config)
        # Log updated configuration
        model.logger.log_config(model.config)


def get_default_config(cfg_default):
    """Get default configuration from file"""
    config = load_class('get_cfg_defaults',
                        paths=[cfg_default.replace('/', '.')],
                        concat=False)()
    config.merge_from_list(['default', cfg_default])
    return config
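

# A minimal sketch of what `load_class` resolves here: for
# cfg_default = 'configs/default_config' it imports the module
# `configs.default_config` and returns its `get_cfg_defaults` function,
# which is then called to build the default CfgNode. Roughly equivalent to:
#
#   from configs.default_config import get_cfg_defaults
#   config = get_cfg_defaults()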


def merge_cfg_file(config, cfg_file=None):
    """Merge configuration file"""
    if cfg_file is not None:
        config.merge_from_file(cfg_file)
        config.merge_from_list(['config', cfg_file])
    return config


def merge_cfgs(original, override):
    """
    Update a CfgNode with information from another one

    Parameters
    ----------
    original : CfgNode
        Original configuration node
    override : CfgNode
        Another configuration node used for overriding

    Returns
    -------
    updated : CfgNode
        Updated configuration node
    """
    for key, value in original.items():
        if key in override.keys():
            if is_cfg(value):  # If it's a configuration node, recurse
                original[key] = merge_cfgs(original[key], override[key])
            else:  # Otherwise, simply update the key
                original[key] = override[key]
    return original
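

# Recursive merge sketch (toy nodes, not real project configs): for keys
# present in both nodes the override wins, and nested CfgNodes are merged
# key by key instead of being replaced wholesale.
#
#   a = CfgNode({'lr': 1e-4, 'model': CfgNode({'depth': 18, 'pose': 'cnn'})})
#   b = CfgNode({'lr': 2e-4, 'model': CfgNode({'depth': 50})})
#   merged = merge_cfgs(a, b)
#   # merged.lr == 2e-4, merged.model.depth == 50, merged.model.pose == 'cnn'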


def backwards_config(config):
    """
    Add or update configuration for backwards compatibility
    (no need for it right now, pretrained models are up-to-date
    with configuration files).

    Parameters
    ----------
    config : CfgNode
        Model configuration

    Returns
    -------
    config : CfgNode
        Updated model configuration
    """
    # Return updated configuration
    return config


def parse_train_file(file):
    """
    Parse file for training

    Parameters
    ----------
    file : str
        File, can be either a
        **.yaml** for a yacs configuration file or a
        **.ckpt** for a pre-trained checkpoint file

    Returns
    -------
    config : CfgNode
        Parsed model configuration
    ckpt : dict
        Parsed checkpoint information (None for configuration files)
    """
    # If it's a .yaml configuration file
    if file.endswith('yaml'):
        cfg_default = 'configs/default_config'
        return parse_train_config(cfg_default, file), None
    # If it's a .ckpt checkpoint file
    elif file.endswith('ckpt'):
        checkpoint = torch.load(file, map_location='cpu')
        config = checkpoint.pop('config')
        checkpoint['file'] = file
        return config, checkpoint
    # We have a problem
    else:
        raise ValueError('You need to provide a .yaml or .ckpt to train')
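

# Usage sketch (hypothetical paths): both branches return a (config, checkpoint)
# pair, so the caller can treat fresh runs and resumed runs uniformly.
#
#   config, ckpt = parse_train_file('configs/train_kitti.yaml')  # ckpt is None
#   config, ckpt = parse_train_file('/checkpoints/model.ckpt')   # resume run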


def parse_train_config(cfg_default, cfg_file):
    """
    Parse model configuration for training

    Parameters
    ----------
    cfg_default : str
        Default **.py** configuration file
    cfg_file : str
        Configuration **.yaml** file to override the default parameters

    Returns
    -------
    config : CfgNode
        Parsed model configuration
    """
    # Load default configuration
    config = get_default_config(cfg_default)
    # Merge configuration file
    config = merge_cfg_file(config, cfg_file)
    # Return prepared configuration
    return prepare_train_config(config)


def prepare_train_config(config):
    """
    Prepare model configuration for training

    Parameters
    ----------
    config : CfgNode
        Model configuration

    Returns
    -------
    config : CfgNode
        Prepared model configuration
    """
    # If arguments have already been prepared, don't prepare
    if config.prepared:
        return config
    # Asserts
    assert config.wandb.dry_run or config.wandb.entity != '', \
        'You need a wandb entity'
    assert config.wandb.dry_run or config.wandb.project != '', \
        'You need a wandb project'
    assert config.checkpoint.filepath == '' or \
        (config.checkpoint.monitor_index < len(config.datasets.validation.split)), \
        'You need to monitor a valid dataset'
    # Prepare datasets
    config.datasets.train = prep_dataset(config.datasets.train)
    config.datasets.validation = prep_dataset(config.datasets.validation)
    config.datasets.test = prep_dataset(config.datasets.test)
    # Set name and checkpoint
    config.name = set_name(config)
    config.checkpoint = set_checkpoint(config)
    # Return configuration
    return config
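

# Monitor-index contract sketch (hypothetical values): with two validation
# splits, monitor_index 0 or 1 passes the assert above, while 2 would raise
# 'You need to monitor a valid dataset'.
#
#   config.datasets.validation.split = ['val_a', 'val_b']
#   config.checkpoint.monitor_index = 1   # OK: monitors the second split
#   config.checkpoint.monitor_index = 2   # AssertionError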


def parse_test_file(ckpt_file, cfg_file=None):
    """
    Parse model configuration for testing

    Parameters
    ----------
    ckpt_file : str
        Checkpoint file, with pretrained model
    cfg_file : str
        Configuration file, to update pretrained model configuration

    Returns
    -------
    config : CfgNode
        Parsed model configuration
    state_dict : dict
        Model state dict with pretrained weights
    """
    assert ckpt_file.endswith('.ckpt') or ckpt_file.endswith('.pth.tar'), \
        'You need to provide a .ckpt or .pth.tar file for checkpoint, not {}'.format(ckpt_file)
    assert cfg_file is None or cfg_file.endswith('yaml'), \
        'You need to provide a .yaml file for configuration, not {}'.format(cfg_file)
    cfg_default = 'configs/default_config'
    return parse_test_config(ckpt_file, cfg_default, cfg_file)


def parse_test_config(ckpt_file, cfg_default, cfg_file):
    """
    Parse model configuration for testing

    Parameters
    ----------
    ckpt_file : str
        Checkpoint file, with pretrained model
    cfg_default : str
        Default configuration file, with default values
    cfg_file : str
        Configuration file with updated information

    Returns
    -------
    config : CfgNode
        Parsed model configuration
    state_dict : dict
        Model state dict with pretrained weights
    """
    if ckpt_file.endswith('.ckpt'):
        # Load checkpoint
        ckpt = torch.load(ckpt_file, map_location='cpu')
        # Get base configuration
        config_default = get_default_config(cfg_default)
        # Extract configuration and model state
        config_model, state_dict = ckpt['config'], ckpt['state_dict']
        # Override default configuration with model configuration
        config = merge_cfgs(config_default, config_model)
        # Update configuration for backwards compatibility
        config = backwards_config(config)
        # If another config file is provided, use it
        config = merge_cfg_file(config, cfg_file)
    # Backwards compatibility with older models
    elif ckpt_file.endswith('.pth.tar'):
        # Load model state and update it for backwards compatibility
        state_dict = torch.load(ckpt_file, map_location='cpu')['state_dict']
        state_dict = backwards_state_dict(state_dict)
        # Get default configuration
        config = get_default_config(cfg_default)
        # If a config file is provided, update the configuration
        config = merge_cfg_file(config, cfg_file)
    else:
        raise ValueError('Unknown checkpoint {}'.format(ckpt_file))
    # Set pretrained model name
    config.save.pretrained = ckpt_file
    # Return prepared configuration and model state
    return prepare_test_config(config), state_dict
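

# Usage sketch (hypothetical paths): `parse_test_file` above is the public
# entry point and routes both current .ckpt checkpoints and legacy .pth.tar
# weights through this function.
#
#   config, state_dict = parse_test_file('/checkpoints/model.ckpt')
#   config, state_dict = parse_test_file('/checkpoints/legacy.pth.tar',
#                                        cfg_file='configs/eval_kitti.yaml')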


def prepare_test_config(config):
    """
    Prepare model configuration for testing

    Parameters
    ----------
    config : CfgNode
        Model configuration

    Returns
    -------
    config : CfgNode
        Prepared model configuration
    """
    # Remove train and validation datasets
    config.datasets.train.path = config.datasets.validation.path = []
    config.datasets.test = prep_dataset(config.datasets.test)
    # Don't save models or log to wandb
    config.wandb.dry_run = True
    config.checkpoint.filepath = ''
    # Return updated configuration
    return config