import re
from collections import defaultdict
import os
from torch.utils.data import Dataset
import numpy as np
from packnet_sfm.utils.image import load_image
########################################################################################################################
#### FUNCTIONS
########################################################################################################################
[docs]def dummy_calibration(image):
w, h = [float(d) for d in image.size]
return np.array([[1000. , 0. , w / 2. - 0.5],
[0. , 1000. , h / 2. - 0.5],
[0. , 0. , 1. ]])
[docs]def get_idx(filename):
return int(re.search(r'\d+', filename).group())
[docs]def read_files(directory, ext=('.png', '.jpg', '.jpeg'), skip_empty=True):
files = defaultdict(list)
for entry in os.scandir(directory):
relpath = os.path.relpath(entry.path, directory)
if entry.is_dir():
d_files = read_files(entry.path, ext=ext, skip_empty=skip_empty)
if skip_empty and not len(d_files):
continue
files[relpath] = d_files[entry.path]
elif entry.is_file():
if ext is None or entry.path.lower().endswith(tuple(ext)):
files[directory].append(relpath)
return files
########################################################################################################################
#### DATASET
########################################################################################################################
[docs]class ImageDataset(Dataset):
def __init__(self, root_dir, split, data_transform=None,
forward_context=0, back_context=0, strides=(1,),
depth_type=None, **kwargs):
super().__init__()
# Asserts
assert depth_type is None or depth_type == '', \
'ImageDataset currently does not support depth types'
assert len(strides) == 1 and strides[0] == 1, \
'ImageDataset currently only supports stride of 1.'
self.root_dir = root_dir
self.split = split
self.backward_context = back_context
self.forward_context = forward_context
self.has_context = self.backward_context + self.forward_context > 0
self.strides = 1
self.files = []
file_tree = read_files(root_dir)
for k, v in file_tree.items():
file_set = set(file_tree[k])
files = [fname for fname in sorted(v) if self._has_context(fname, file_set)]
self.files.extend([[k, fname] for fname in files])
self.data_transform = data_transform
def __len__(self):
return len(self.files)
def _change_idx(self, idx, filename):
_, ext = os.path.splitext(os.path.basename(filename))
return self.split.format(idx) + ext
def _has_context(self, filename, file_set):
context_paths = self._get_context_file_paths(filename)
return all([f in file_set for f in context_paths])
def _get_context_file_paths(self, filename):
fidx = get_idx(filename)
idxs = list(np.arange(-self.backward_context * self.strides, 0, self.strides)) + \
list(np.arange(0, self.forward_context * self.strides, self.strides) + self.strides)
return [self._change_idx(fidx + i, filename) for i in idxs]
def _read_rgb_context_files(self, session, filename):
context_paths = self._get_context_file_paths(filename)
return [load_image(os.path.join(self.root_dir, session, filename))
for filename in context_paths]
def _read_rgb_file(self, session, filename):
return load_image(os.path.join(self.root_dir, session, filename))
def __getitem__(self, idx):
session, filename = self.files[idx]
image = self._read_rgb_file(session, filename)
sample = {
'idx': idx,
'filename': '%s_%s' % (session, os.path.splitext(filename)[0]),
#
'rgb': image,
'intrinsics': dummy_calibration(image)
}
if self.has_context:
sample['rgb_context'] = \
self._read_rgb_context_files(session, filename)
if self.data_transform:
sample = self.data_transform(sample)
return sample
########################################################################################################################