Source code for packnet_sfm.datasets.image_dataset


import re
from collections import defaultdict
import os

from torch.utils.data import Dataset
import numpy as np
from packnet_sfm.utils.image import load_image

########################################################################################################################
#### FUNCTIONS
########################################################################################################################

def dummy_calibration(image):
    """
    Build a placeholder 3x3 camera intrinsics matrix for an image.

    Parameters
    ----------
    image : object with a ``size`` attribute
        ``image.size`` is unpacked as (width, height).

    Returns
    -------
    intrinsics : np.array [3,3]
        Matrix with a fixed focal length of 1000 on both axes and the
        principal point at (width / 2 - 0.5, height / 2 - 0.5).
    """
    width, height = (float(dim) for dim in image.size)
    # Principal point placed at the image centre, offset by half a pixel
    cx = width / 2. - 0.5
    cy = height / 2. - 0.5
    return np.array([
        [1000., 0., cx],
        [0., 1000., cy],
        [0., 0., 1.],
    ])
def get_idx(filename):
    """
    Extract the first run of decimal digits in a filename as an integer.

    Parameters
    ----------
    filename : str
        Name to search; it must contain at least one digit, otherwise the
        failed match raises an AttributeError.

    Returns
    -------
    idx : int
        Value of the first digit run (leading zeros are dropped by ``int``).
    """
    first_digits = re.search(r'\d+', filename)
    return int(first_digits.group())
def read_files(directory, ext=('.png', '.jpg', '.jpeg'), skip_empty=True):
    """
    Collect the files of a directory and of its immediate subdirectories.

    Parameters
    ----------
    directory : str
        Directory to scan.
    ext : iterable of str or None
        Accepted file suffixes, compared against the lower-cased path
        (so matching is case-insensitive). None accepts every file.
    skip_empty : bool
        If True, subdirectories whose own scan returns nothing are left out.

    Returns
    -------
    files : defaultdict(list)
        ``files[directory]`` lists the matching files found directly in
        ``directory``. Each immediate subdirectory is keyed by its path relative
        to ``directory`` and lists the matching files found directly inside it.
        Deeper levels are visited by the recursion, but their results are not
        carried into the returned mapping.
    """
    found = defaultdict(list)
    for item in os.scandir(directory):
        rel = os.path.relpath(item.path, directory)
        if item.is_dir():
            nested = read_files(item.path, ext=ext, skip_empty=skip_empty)
            if nested or not skip_empty:
                # Only the subdirectory's own files are kept (keyed by its path there)
                found[rel] = nested[item.path]
            continue
        if not item.is_file():
            continue
        if ext is None or item.path.lower().endswith(tuple(ext)):
            found[directory].append(rel)
    return found
######################################################################################################################## #### DATASET ########################################################################################################################
class ImageDataset(Dataset):
    """
    Dataset of RGB images read from a directory tree, with optional temporal context.

    Parameters
    ----------
    root_dir : str
        Directory scanned with ``read_files`` to discover the images.
    split : str
        Format string applied as ``split.format(idx)`` to turn a frame index
        into a file stem (the extension of the reference file is appended).
    data_transform : callable, optional
        If given, called on each sample dictionary; its result is returned.
    forward_context : int
        Number of following frames loaded as context.
    back_context : int
        Number of preceding frames loaded as context.
    strides : tuple
        Only ``(1,)`` is supported.
    depth_type : str or None
        Must be None or an empty string (depth is not supported).
    kwargs : dict
        Accepted and ignored.
    """
    def __init__(self, root_dir, split, data_transform=None,
                 forward_context=0, back_context=0, strides=(1,),
                 depth_type=None, **kwargs):
        super().__init__()
        # Asserts
        assert depth_type is None or depth_type == '', \
            'ImageDataset currently does not support depth types'
        assert len(strides) == 1 and strides[0] == 1, \
            'ImageDataset currently only supports stride of 1.'

        self.root_dir = root_dir
        self.split = split

        self.backward_context = back_context
        self.forward_context = forward_context
        self.has_context = self.backward_context + self.forward_context > 0
        self.strides = 1

        # Keep only the frames whose context neighbours all exist in the same folder
        self.files = []
        file_tree = read_files(root_dir)
        for session, names in file_tree.items():
            available = set(names)
            self.files.extend(
                [session, fname] for fname in sorted(names)
                if self._has_context(fname, available))

        self.data_transform = data_transform

    def __len__(self):
        """Return the number of usable reference frames."""
        return len(self.files)

    def _image_path(self, session, filename):
        """Return the on-disk path of ``filename`` inside ``session``."""
        # read_files keys images that sit directly in the root by the root
        # directory itself; joining root_dir with that key again would
        # duplicate a relative root (e.g. 'data/data/0.png').
        if session == self.root_dir:
            return os.path.join(self.root_dir, filename)
        return os.path.join(self.root_dir, session, filename)

    def _change_idx(self, idx, filename):
        """Return the file name for frame ``idx``, reusing the extension of ``filename``."""
        _, ext = os.path.splitext(os.path.basename(filename))
        return self.split.format(idx) + ext

    def _has_context(self, filename, file_set):
        """Return True if every context file of ``filename`` is in ``file_set``."""
        return all(name in file_set
                   for name in self._get_context_file_paths(filename))

    def _get_context_file_paths(self, filename):
        """Return the names of the backward then forward context files of ``filename``."""
        fidx = get_idx(filename)
        stride = self.strides
        # Offsets -back*stride .. -stride, followed by +stride .. +forward*stride
        backward = range(-self.backward_context * stride, 0, stride)
        forward = range(stride, (self.forward_context + 1) * stride, stride)
        return [self._change_idx(fidx + offset, filename)
                for offset in (*backward, *forward)]

    def _read_rgb_context_files(self, session, filename):
        """Load the context images of ``filename`` from ``session``."""
        return [load_image(self._image_path(session, context_name))
                for context_name in self._get_context_file_paths(filename)]

    def _read_rgb_file(self, session, filename):
        """Load the image ``filename`` from ``session``."""
        return load_image(self._image_path(session, filename))

    def __getitem__(self, idx):
        """
        Build the sample for the reference frame at position ``idx``.

        Returns
        -------
        sample : dict
            Contains 'idx', 'filename' ("<session>_<file stem>"), 'rgb',
            'intrinsics' (from ``dummy_calibration``) and, when context is
            enabled, 'rgb_context'. If a data transform was provided, its
            output is returned instead.
        """
        session, filename = self.files[idx]
        image = self._read_rgb_file(session, filename)

        sample = {
            'idx': idx,
            'filename': '%s_%s' % (session, os.path.splitext(filename)[0]),
            'rgb': image,
            'intrinsics': dummy_calibration(image),
        }

        if self.has_context:
            sample['rgb_context'] = \
                self._read_rgb_context_files(session, filename)

        if self.data_transform:
            sample = self.data_transform(sample)

        return sample
########################################################################################################################