Source code for packnet_sfm.trainers.base_trainer

# Copyright 2020 Toyota Research Institute.  All rights reserved.

import torch
from tqdm import tqdm
from packnet_sfm.utils.logging import prepare_dataset_prefix


def sample_to_cuda(data, dtype=None):
    if isinstance(data, str):
        return data
    elif isinstance(data, dict):
        return {key: sample_to_cuda(data[key], dtype) for key in data.keys()}
    elif isinstance(data, list):
        return [sample_to_cuda(val, dtype) for val in data]
    else:
        # only convert floats (e.g., to half), otherwise preserve (e.g., ints)
        dtype = dtype if torch.is_floating_point(data) else None
        return data.to('cuda', dtype=dtype)
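

# Usage sketch (illustrative only, not part of the original module): the batch layout
# below is hypothetical and a CUDA device is assumed to be available. Nested dicts and
# lists are traversed recursively, strings pass through unchanged, and only
# floating-point tensors are cast to the requested dtype.
example_batch = {
    'rgb': torch.rand(2, 3, 192, 640),   # float tensor -> moved to GPU and cast to half
    'idx': torch.tensor([0, 1]),         # int tensor   -> moved to GPU, dtype preserved
    'filename': 'sample_000000.png',     # string       -> returned unchanged
}
example_batch = sample_to_cuda(example_batch, dtype=torch.float16)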


class BaseTrainer:
    def __init__(self, min_epochs=0, max_epochs=50,
                 checkpoint=None, **kwargs):
        self.min_epochs = min_epochs
        self.max_epochs = max_epochs
        self.checkpoint = checkpoint
        self.module = None

    @property
    def proc_rank(self):
        raise NotImplementedError('Not implemented for BaseTrainer')

    @property
    def world_size(self):
        raise NotImplementedError('Not implemented for BaseTrainer')

    @property
    def is_rank_0(self):
        return self.proc_rank == 0

    def check_and_save(self, module, output):
        if self.checkpoint:
            self.checkpoint.check_and_save(module, output)

    def train_progress_bar(self, dataloader, config, ncols=120):
        return tqdm(enumerate(dataloader, 0),
                    unit=' images', unit_scale=self.world_size * config.batch_size,
                    total=len(dataloader), smoothing=0,
                    disable=not self.is_rank_0, ncols=ncols,
                    )

    def val_progress_bar(self, dataloader, config, n=0, ncols=120):
        return tqdm(enumerate(dataloader, 0),
                    unit=' images', unit_scale=self.world_size * config.batch_size,
                    total=len(dataloader), smoothing=0,
                    disable=not self.is_rank_0, ncols=ncols,
                    desc=prepare_dataset_prefix(config, n)
                    )

    def test_progress_bar(self, dataloader, config, n=0, ncols=120):
        return tqdm(enumerate(dataloader, 0),
                    unit=' images', unit_scale=self.world_size * config.batch_size,
                    total=len(dataloader), smoothing=0,
                    disable=not self.is_rank_0, ncols=ncols,
                    desc=prepare_dataset_prefix(config, n)
                    )
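

# Illustrative subclass sketch (hypothetical, not part of the original file): the
# concrete trainers in packnet_sfm implement proc_rank / world_size for their
# distributed backend, while this single-process variant simply pins them so that
# is_rank_0 and the progress-bar helpers work without any distributed setup.
class SingleProcessTrainer(BaseTrainer):

    @property
    def proc_rank(self):
        # Single process, so it is always rank 0 (progress bars stay enabled).
        return 0

    @property
    def world_size(self):
        # unit_scale in the progress bars reduces to 1 * config.batch_size.
        return 1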