Wrapper

class packnet_sfm.models.model_wrapper.ModelWrapper(config, resume=None, logger=None, load_datasets=True)

Bases: torch.nn.modules.module.Module

Top-level torch.nn.Module wrapper around an SfmModel (pose + depth networks). Designed to be used with high-level Trainer classes (cf. trainers/).

Parameters

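  • config (CfgNode) – Model configuration (cf. configs/default_config.py)

  • resume (optional) – Previous state to resume from, passed on to prepare_model()

  • logger (optional) – Logger used for tracking

  • load_datasets (bool) – If True, prepare the training, validation and test datasets during initialization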

configure_optimizers()

Configure depth and pose optimizers and the corresponding scheduler.
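The optimizer and scheduler types are not specified on this page. As a hedged illustration, the stdlib-only sketch below assumes a step-decay schedule (the rule implemented by torch.optim.lr_scheduler.StepLR) applied to one parameter group per network. ParamGroup, the group names, base rates, step_size and gamma are hypothetical values, not packnet_sfm defaults.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ParamGroup:
    # Hypothetical stand-in for one optimizer parameter group (depth or pose).
    name: str
    base_lr: float


def step_lr(base_lr: float, epoch: int, step_size: int, gamma: float) -> float:
    """Step decay: multiply the base rate by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


def learning_rates(groups: List[ParamGroup], epoch: int,
                   step_size: int = 10, gamma: float = 0.5) -> Dict[str, float]:
    """Return {group name: learning rate} for the given epoch."""
    return {g.name: step_lr(g.base_lr, epoch, step_size, gamma) for g in groups}
```

Keeping one group per network lets the depth and pose rates start from different values while sharing the same decay rule.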

depth(*args, **kwargs)

Runs the depth network and returns the output.

property depth_net

Returns depth network.

evaluate_depth(batch)

Evaluate a batch to produce depth metrics.
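This page does not list which depth metrics are produced. The stdlib-only sketch below computes the set commonly reported for monocular depth (absolute and squared relative error, RMSE, log RMSE and the δ < 1.25^k accuracies). Whether evaluate_depth reports exactly these is an assumption, and any cropping, depth capping or median scaling the real evaluation may apply is omitted.

```python
import math
from typing import Dict, Sequence


def depth_metrics(gt: Sequence[float], pred: Sequence[float]) -> Dict[str, float]:
    """Common depth-error metrics over valid pixels (gt > 0 and pred > 0)."""
    pairs = [(g, p) for g, p in zip(gt, pred) if g > 0 and p > 0]
    if not pairs:
        raise ValueError("no valid pixels")
    n = len(pairs)
    abs_rel = sum(abs(g - p) / g for g, p in pairs) / n
    sq_rel = sum((g - p) ** 2 / g for g, p in pairs) / n
    rmse = math.sqrt(sum((g - p) ** 2 for g, p in pairs) / n)
    rmse_log = math.sqrt(sum((math.log(g) - math.log(p)) ** 2 for g, p in pairs) / n)
    # Threshold accuracy: fraction of pixels with max(g/p, p/g) below 1.25**k.
    ratios = [max(g / p, p / g) for g, p in pairs]
    a1, a2, a3 = (sum(r < 1.25 ** k for r in ratios) / n for k in (1, 2, 3))
    return {"abs_rel": abs_rel, "sq_rel": sq_rel, "rmse": rmse,
            "rmse_log": rmse_log, "a1": a1, "a2": a2, "a3": a3}
```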

forward(*args, **kwargs)

Runs the model and returns the output.

property logs

Returns various logs for tracking.

pose(*args, **kwargs)

Runs the pose network and returns the output.

property pose_net

Returns pose network.
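depth() and pose() forward their arguments to the networks exposed by the depth_net and pose_net properties. The sketch below is a stdlib-only stand-in for that delegation: WrapperSketch is hypothetical, plain callables replace the torch networks, and storing the networks directly on the wrapper is a simplification (in ModelWrapper they belong to the wrapped SfmModel).

```python
from typing import Any, Callable


class WrapperSketch:
    """Hypothetical stand-in for ModelWrapper's delegation to its sub-networks."""

    def __init__(self, depth_net: Callable[..., Any], pose_net: Callable[..., Any]):
        self._depth_net = depth_net
        self._pose_net = pose_net

    @property
    def depth_net(self) -> Callable[..., Any]:
        return self._depth_net

    @property
    def pose_net(self) -> Callable[..., Any]:
        return self._pose_net

    def depth(self, *args, **kwargs):
        # Forward all arguments unchanged to the depth network.
        return self.depth_net(*args, **kwargs)

    def pose(self, *args, **kwargs):
        # Forward all arguments unchanged to the pose network.
        return self.pose_net(*args, **kwargs)
```

For example, WrapperSketch(lambda img: ("depth", img), lambda tgt, ctx: ("pose", tgt, ctx)).depth("img") returns ("depth", "img").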

prepare_datasets(validation_requirements, test_requirements)

Prepare datasets for training, validation and test.

prepare_model(resume=None)

Prepare self.model (incl. loading previous state).

print_metrics(**kwargs)

Print the computed metrics.

property progress

Returns training progress (current epoch / max. number of epochs)
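The progress value is a plain ratio. A one-function sketch, where training_progress and its argument names are hypothetical:

```python
def training_progress(current_epoch: int, max_epochs: int) -> float:
    """Fraction of training completed: current epoch / max. number of epochs."""
    if max_epochs <= 0:
        raise ValueError("max_epochs must be positive")
    return current_epoch / max_epochs
```

A fraction in [0, 1] is convenient for anything that should depend on how far training has advanced, independently of the configured number of epochs.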

test_dataloader()

Prepare test dataloader.

test_epoch_end(output_data_batch)

Finishes a test epoch.

test_step(batch, *args)

Processes a test batch.

train_dataloader()

Prepare training dataloader.

training_epoch_end(output_batch)

Finishes a training epoch.

training_step(batch, *args)

Processes a training batch.

val_dataloader()

Prepare validation dataloader.

validation_epoch_end(output_data_batch)

Finishes a validation epoch.

validation_step(batch, *args)

Processes a validation batch.
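The dataloader, step and epoch-end hooks are meant to be driven by a Trainer (cf. trainers/). The stdlib-only sketch below shows one plausible calling order for a training epoch followed by validation. ToyWrapper and fit are hypothetical, there is no gradient or optimizer step, and grouping validation outputs per dataloader is an assumption based on setup_dataloader returning a list of dataloaders, not a description of the actual trainer code.

```python
from typing import Any, Dict, List


class ToyWrapper:
    """Hypothetical model exposing the same hook names as ModelWrapper."""

    def __init__(self):
        self.current_epoch = 0

    def train_dataloader(self):
        return [[1, 2], [3, 4]]  # one loader yielding two toy batches

    def val_dataloader(self):
        return [[[5, 6]]]  # a list of loaders; here one loader with one batch

    def training_step(self, batch, *args):
        return {"loss": float(sum(batch))}

    def training_epoch_end(self, output_batch):
        return {"avg_loss": sum(o["loss"] for o in output_batch) / len(output_batch)}

    def validation_step(self, batch, *args):
        return {"value": float(sum(batch))}

    def validation_epoch_end(self, output_data_batch):
        # One list of step outputs per validation dataloader.
        return {"val": [sum(o["value"] for o in outs) for outs in output_data_batch]}


def fit(model, max_epochs: int) -> List[Dict[str, Any]]:
    """Minimal loop: train steps, train epoch end, validation steps, validation epoch end."""
    history = []
    for epoch in range(max_epochs):
        model.current_epoch = epoch
        train_out = [model.training_step(b, i)
                     for i, b in enumerate(model.train_dataloader())]
        logs = model.training_epoch_end(train_out)
        val_out = [[model.validation_step(b, i, n) for i, b in enumerate(loader)]
                   for n, loader in enumerate(model.val_dataloader())]
        logs.update(model.validation_epoch_end(val_out))
        history.append(logs)
    return history
```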

packnet_sfm.models.model_wrapper.get_datasampler(dataset, mode)

Create a distributed data sampler for a dataset in the given mode.
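A distributed sampler gives each process a disjoint, equally sized share of the dataset. Assuming get_datasampler builds on torch.utils.data.distributed.DistributedSampler (not stated on this page), the stdlib-only sketch below reproduces that sampler's non-shuffled partitioning with drop_last=False: indices are padded by repeating from the start, then each rank takes every num_replicas-th index. Shuffling is omitted, and distributed_indices is a hypothetical name.

```python
import math
from typing import List


def distributed_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices assigned to one process under non-shuffled, padded partitioning."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every replica gets the same count.
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / max(len(indices), 1)))[:padding]
    # Each replica takes every num_replicas-th index, starting at its rank.
    return indices[rank:total_size:num_replicas]
```

With 10 samples and 4 replicas, each rank receives 3 indices, so samples 0 and 1 are seen twice per epoch.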

packnet_sfm.models.model_wrapper.set_random_seed(seed)

Set the random seed for reproducible runs.

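Which generators set_random_seed seeds is not documented here. A hedged sketch of the usual pattern: seed Python's random module, and seed NumPy and PyTorch only if they are installed. set_random_seed_sketch is a hypothetical name.

```python
import random


def set_random_seed_sketch(seed: int) -> None:
    """Seed every available RNG; numpy and torch are seeded only if installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # safe to call without a GPU
    except ImportError:
        pass
```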
packnet_sfm.models.model_wrapper.setup_dataloader(datasets, config, mode)

Create dataloaders, one for each input dataset.

Parameters
  • datasets (list of Dataset) – List of datasets from which to create dataloaders

  • config (CfgNode) – Model configuration (cf. configs/default_config.py)

  • mode (str {'train', 'validation', 'test'}) – Mode from which we want the dataloader

Returns

dataloaders – List of created dataloaders for each input dataset

Return type

list of Dataloader
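The one-loader-per-dataset pattern can be sketched without torch. Below, batches stands in for torch.utils.data.DataLoader, and the config layout {mode: {"batch_size": ...}} is hypothetical (packnet_sfm uses a CfgNode whose field names are not documented on this page). A real implementation would typically also pass a sampler (cf. get_datasampler) and a worker_init_fn.

```python
from typing import Dict, Iterator, List, Sequence


def batches(dataset: Sequence, batch_size: int) -> Iterator[list]:
    """Yield consecutive batches (a stand-in for torch.utils.data.DataLoader)."""
    for start in range(0, len(dataset), batch_size):
        yield list(dataset[start:start + batch_size])


def setup_dataloader_sketch(datasets: List[Sequence],
                            config: Dict[str, Dict[str, int]],
                            mode: str) -> List[List[list]]:
    """Build one loader per dataset, reading batch_size from the mode's section."""
    if mode not in ("train", "validation", "test"):
        raise ValueError(f"unknown mode: {mode}")
    batch_size = config[mode]["batch_size"]
    return [list(batches(ds, batch_size)) for ds in datasets]
```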

packnet_sfm.models.model_wrapper.setup_dataset(config, mode, requirements, **kwargs)

Create a dataset for a given mode.

Parameters
  • config (CfgNode) – Configuration (cf. configs/default_config.py)

  • mode (str {'train', 'validation', 'test'}) – Mode from which we want the dataset

  • requirements (dict (string -> bool)) – Requirements for dataset loading, e.g. whether ground-truth depth (gt_depth) or pose (gt_pose) is needed

  • kwargs (dict) – Extra parameters for dataset creation

Returns

dataset – Dataset for that mode

Return type

Dataset
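The requirements flags (gt_depth, gt_pose) decide which ground truth a dataset has to provide. A hedged sketch of that idea: ToyDataset is hypothetical, and its string placeholders stand in for loaded depth maps and poses.

```python
from typing import Dict, List


class ToyDataset:
    """Hypothetical dataset that only loads the ground truth it is asked for."""

    def __init__(self, frames: List[str], requirements: Dict[str, bool]):
        self.frames = frames
        self.with_depth = requirements.get("gt_depth", False)
        self.with_pose = requirements.get("gt_pose", False)

    def __len__(self) -> int:
        return len(self.frames)

    def __getitem__(self, idx: int) -> dict:
        sample = {"rgb": self.frames[idx]}
        if self.with_depth:
            sample["depth"] = f"depth_{idx}"  # placeholder for a loaded depth map
        if self.with_pose:
            sample["pose"] = f"pose_{idx}"  # placeholder for a loaded pose
        return sample
```

Skipping unneeded ground truth avoids reading files that a given mode never uses, such as poses during depth-only evaluation.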

packnet_sfm.models.model_wrapper.setup_depth_net(config, prepared, **kwargs)

Create a depth network

Parameters
  • config (CfgNode) – Network configuration

  • prepared (bool) – True if the network has been prepared before

  • kwargs (dict) – Extra parameters for the network

Returns

depth_net – Created depth network

Return type

nn.Module
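The prepared flag is documented only as "True if the network has been prepared before". One plausible reading, assumed here, is that preparation includes loading pretrained weights, so a prepared network skips that step. In the sketch below the name-based registry, the config keys (name, checkpoint_path) and ToyDepthNet are all hypothetical.

```python
from typing import Callable, Dict, Optional


class ToyDepthNet:
    """Hypothetical depth network; 'weights' records what was loaded."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.weights: Optional[str] = None


# Hypothetical registry mapping config names to network classes.
DEPTH_NETS: Dict[str, Callable[..., ToyDepthNet]] = {"ToyDepthNet": ToyDepthNet}


def setup_depth_net_sketch(config: dict, prepared: bool, **kwargs) -> ToyDepthNet:
    """Instantiate a depth network by name; load a checkpoint only if not prepared."""
    net = DEPTH_NETS[config["name"]](**kwargs)
    checkpoint = config.get("checkpoint_path", "")
    # Assumed: a prepared network already has its state, so skip the checkpoint.
    if not prepared and checkpoint:
        net.weights = checkpoint  # stand-in for loading a state dict from disk
    return net
```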

packnet_sfm.models.model_wrapper.setup_model(config, prepared, **kwargs)

Create a model

Parameters
  • config (CfgNode) – Model configuration (cf. configs/default_config.py)

  • prepared (bool) – True if the model has been prepared before

  • kwargs (dict) – Extra parameters for the model

Returns

model – Created model

Return type

nn.Module
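Since the wrapped SfmModel holds a depth and a pose network, setup_model plausibly composes the two factories documented on this page. The sketch assumes exactly that: the factories are passed in as callables, and a network whose config section is empty is simply not built (e.g. a depth-only model). ToySfmModel, setup_model_sketch and the config keys are hypothetical.

```python
from typing import Any, Callable, Dict


class ToySfmModel:
    """Hypothetical container pairing a depth network with a pose network."""

    def __init__(self, depth_net: Any = None, pose_net: Any = None):
        self.depth_net = depth_net
        self.pose_net = pose_net


def setup_model_sketch(config: Dict[str, Any], prepared: bool,
                       make_depth: Callable[[Dict[str, Any], bool], Any],
                       make_pose: Callable[[Dict[str, Any], bool], Any]) -> ToySfmModel:
    """Build the model, then attach each sub-network that has a config section."""
    model = ToySfmModel()
    if config.get("depth_net"):
        model.depth_net = make_depth(config["depth_net"], prepared)
    if config.get("pose_net"):
        model.pose_net = make_pose(config["pose_net"], prepared)
    return model
```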

packnet_sfm.models.model_wrapper.setup_pose_net(config, prepared, **kwargs)

Create a pose network

Parameters
  • config (CfgNode) – Network configuration

  • prepared (bool) – True if the network has been prepared before

  • kwargs (dict) – Extra parameters for the network

Returns

pose_net – Created pose network

Return type

nn.Module

packnet_sfm.models.model_wrapper.worker_init_fn(worker_id)

Initialization function called for each dataloader worker.
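torch.utils.data.DataLoader calls its worker_init_fn once in each worker process with that worker's id. A common use is to give every worker its own random stream, since workers started from the same parent can otherwise share RNG state (NumPy's global generator is the usual motivation). This hedged sketch applies the idea to Python's random module. The base_seed parameter and its wall-clock default are hypothetical, not taken from packnet_sfm.

```python
import random
import time
from typing import Optional


def worker_init_fn_sketch(worker_id: int, base_seed: Optional[int] = None) -> int:
    """Give each DataLoader worker its own RNG stream by offsetting a base seed."""
    if base_seed is None:
        base_seed = int(time.time())  # hypothetical default: wall-clock seconds
    seed = (base_seed + worker_id) % 2**32
    random.seed(seed)
    return seed
```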