Source code for medicaltorch.datasets

import os
import re
import collections

from medicaltorch import transforms as mt_transforms

from tqdm import tqdm
import numpy as np
import nibabel as nib

from torch.utils.data import Dataset
import torch
from torch._six import string_classes, int_classes

from PIL import Image


__numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


class SampleMetadata(object):
    def __init__(self, d=None):
        self.metadata = {} or d

    def __setitem__(self, key, value):
        self.metadata[key] = value

    def __getitem__(self, key):
        return self.metadata[key]

    def __contains__(self, key):
        return key in self.metadata

    def keys(self):
        return self.metadata.keys()


class BatchSplit(object):
    def __init__(self, batch):
        self.batch = batch

    def __iter__(self):
        batch_len = len(self.batch["input"])
        for i in range(batch_len):
            single_sample = {k: v[i] for k, v in self.batch.items()}
            single_sample['index'] = i
            yield single_sample
        raise StopIteration


[docs]class SegmentationPair2D(object): """This class is used to build 2D segmentation datasets. It represents a pair of of two data volumes (the input data and the ground truth data). :param input_filename: the input filename (supported by nibabel). :param gt_filename: the ground-truth filename. :param cache: if the data should be cached in memory or not. :param canonical: canonical reordering of the volume axes. """ def __init__(self, input_filename, gt_filename, cache=True, canonical=False): self.input_filename = input_filename self.gt_filename = gt_filename self.canonical = canonical self.cache = cache self.input_handle = nib.load(self.input_filename) # Unlabeled data (inference time) if self.gt_filename is None: self.gt_handle = None else: self.gt_handle = nib.load(self.gt_filename) if len(self.input_handle.shape) > 3: raise RuntimeError("4-dimensional volumes not supported.") # Sanity check for dimensions, should be the same input_shape, gt_shape = self.get_pair_shapes() if self.gt_handle is not None: if not np.allclose(input_shape, gt_shape): raise RuntimeError('Input and ground truth with different dimensions.') if self.canonical: self.input_handle = nib.as_closest_canonical(self.input_handle) # Unlabeled data if self.gt_handle is not None: self.gt_handle = nib.as_closest_canonical(self.gt_handle)
[docs] def get_pair_shapes(self): """Return the tuple (input, ground truth) representing both the input and ground truth shapes.""" input_shape = self.input_handle.header.get_data_shape() # Handle unlabeled data if self.gt_handle is None: gt_shape = None else: gt_shape = self.gt_handle.header.get_data_shape() return input_shape, gt_shape
[docs] def get_pair_data(self): """Return the tuble (input, ground truth) with the data content in numpy array.""" cache_mode = 'fill' if self.cache else 'unchanged' input_data = self.input_handle.get_fdata(cache_mode, dtype=np.float32) # Handle unlabeled data if self.gt_handle is None: gt_data = None else: gt_data = self.gt_handle.get_fdata(cache_mode, dtype=np.float32) return input_data, gt_data
[docs] def get_pair_slice(self, slice_index, slice_axis=2): """Return the specified slice from (input, ground truth). :param slice_index: the slice number. :param slice_axis: axis to make the slicing. """ if self.cache: input_dataobj, gt_dataobj = self.get_pair_data() else: # use dataobj to avoid caching input_dataobj = self.input_handle.dataobj if self.gt_handle is None: gt_dataobj = None else: gt_dataobj = self.gt_handle.dataobj if slice_axis not in [0, 1, 2]: raise RuntimeError("Invalid axis, must be between 0 and 2.") if slice_axis == 2: input_slice = np.asarray(input_dataobj[..., slice_index], dtype=np.float32) elif slice_axis == 1: input_slice = np.asarray(input_dataobj[:, slice_index, ...], dtype=np.float32) elif slice_axis == 0: input_slice = np.asarray(input_dataobj[slice_index, ...], dtype=np.float32) # Handle the case for unlabeled data gt_meta_dict = None if self.gt_handle is None: gt_slice = None else: if slice_axis == 2: gt_slice = np.asarray(gt_dataobj[..., slice_index], dtype=np.float32) elif slice_axis == 1: gt_slice = np.asarray(gt_dataobj[:, slice_index, ...], dtype=np.float32) elif slice_axis == 0: gt_slice = np.asarray(gt_dataobj[slice_index, ...], dtype=np.float32) gt_meta_dict = SampleMetadata({ "zooms": self.gt_handle.header.get_zooms()[:2], "data_shape": self.gt_handle.header.get_data_shape()[:2], }) input_meta_dict = SampleMetadata({ "zooms": self.input_handle.header.get_zooms()[:2], "data_shape": self.input_handle.header.get_data_shape()[:2], }) dreturn = { "input": input_slice, "gt": gt_slice, "input_metadata": input_meta_dict, "gt_metadata": gt_meta_dict, } return dreturn
[docs]class MRI2DSegmentationDataset(Dataset): """This is a generic class for 2D (slice-wise) segmentation datasets. :param filename_pairs: a list of tuples in the format (input filename, ground truth filename). :param slice_axis: axis to make the slicing (default axial). :param cache: if the data should be cached in memory or not. :param transform: transformations to apply. """ def __init__(self, filename_pairs, slice_axis=2, cache=True, transform=None, slice_filter_fn=None, canonical=False): self.filename_pairs = filename_pairs self.handlers = [] self.indexes = [] self.transform = transform self.cache = cache self.slice_axis = slice_axis self.slice_filter_fn = slice_filter_fn self.canonical = canonical self._load_filenames() self._prepare_indexes() def _load_filenames(self): for input_filename, gt_filename in self.filename_pairs: segpair = SegmentationPair2D(input_filename, gt_filename, self.cache, self.canonical) self.handlers.append(segpair) def _prepare_indexes(self): for segpair in self.handlers: input_data_shape, _ = segpair.get_pair_shapes() for segpair_slice in range(input_data_shape[2]): # Check if slice pair should be used or not if self.slice_filter_fn: slice_pair = segpair.get_pair_slice(segpair_slice, self.slice_axis) filter_fn_ret = self.slice_filter_fn(slice_pair) if not filter_fn_ret: continue item = (segpair, segpair_slice) self.indexes.append(item)
[docs] def set_transform(self, transform): """This method will replace the current transformation for the dataset. :param transform: the new transformation """ self.transform = transform
[docs] def compute_mean_std(self, verbose=False): """Compute the mean and standard deviation of the entire dataset. :param verbose: if True, it will show a progress bar. :returns: tuple (mean, std dev) """ sum_intensities = 0.0 numel = 0 with DatasetManager(self, override_transform=mt_transforms.ToTensor()) as dset: pbar = tqdm(dset, desc="Mean calculation", disable=not verbose) for sample in pbar: input_data = sample['input'] sum_intensities += input_data.sum() numel += input_data.numel() pbar.set_postfix(mean="{:.2f}".format(sum_intensities / numel), refresh=False) training_mean = sum_intensities / numel sum_var = 0.0 numel = 0 pbar = tqdm(dset, desc="Std Dev calculation", disable=not verbose) for sample in pbar: input_data = sample['input'] sum_var += (input_data - training_mean).pow(2).sum() numel += input_data.numel() pbar.set_postfix(std="{:.2f}".format(np.sqrt(sum_var / numel)), refresh=False) training_std = np.sqrt(sum_var / numel) return training_mean.item(), training_std.item()
def __len__(self): """Return the dataset size.""" return len(self.indexes) def __getitem__(self, index): """Return the specific index pair slices (input, ground truth). :param index: slice index. """ segpair, segpair_slice = self.indexes[index] pair_slice = segpair.get_pair_slice(segpair_slice, self.slice_axis) # Consistency with torchvision, returning PIL Image # Using the "Float mode" of PIL, the only mode # supporting unbounded float32 values input_img = Image.fromarray(pair_slice["input"], mode='F') # Handle unlabeled data if pair_slice["gt"] is None: gt_img = None else: gt_img = Image.fromarray(pair_slice["gt"], mode='F') data_dict = { 'input': input_img, 'gt': gt_img, 'input_metadata': pair_slice['input_metadata'], 'gt_metadata': pair_slice['gt_metadata'], } if self.transform is not None: data_dict = self.transform(data_dict) return data_dict
class DatasetManager(object): def __init__(self, dataset, override_transform=None): self.dataset = dataset self.override_transform = override_transform self._transform_state = None def __enter__(self): if self.override_transform: self._transform_state = self.dataset.transform self.dataset.transform = self.override_transform return self.dataset def __exit__(self, *args): if self._transform_state: self.dataset.transform = self._transform_state
[docs]class SCGMChallenge2DTrain(MRI2DSegmentationDataset): """This is the Spinal Cord Gray Matter Challenge dataset. :param root_dir: the directory containing the training dataset. :param site_ids: a list of site ids to filter (i.e. [1, 3]). :param subj_ids: the list of subject ids to filter. :param rater_ids: the list of the rater ids to filter. :param transform: the transformations that should be applied. :param cache: if the data should be cached in memory or not. :param slice_axis: axis to make the slicing (default axial). .. note:: This dataset assumes that you only have one class in your ground truth mask (w/ 0's and 1's). It also doesn't automatically resample the dataset. .. seealso:: Prados, F., et al (2017). Spinal cord grey matter segmentation challenge. NeuroImage, 152, 312–329. https://doi.org/10.1016/j.neuroimage.2017.03.010 Challenge Website: http://cmictig.cs.ucl.ac.uk/spinal-cord-grey-matter-segmentation-challenge """ NUM_SITES = 4 NUM_SUBJECTS = 10 NUM_RATERS = 4 def __init__(self, root_dir, slice_axis=2, site_ids=None, subj_ids=None, rater_ids=None, cache=True, transform=None, slice_filter_fn=None, canonical=False, labeled=True): self.labeled = labeled self.root_dir = root_dir self.site_ids = site_ids or range(1, SCGMChallenge2DTrain.NUM_SITES + 1) self.subj_ids = subj_ids or range(1, SCGMChallenge2DTrain.NUM_SUBJECTS + 1) self.rater_ids = rater_ids or range(1, SCGMChallenge2DTrain.NUM_RATERS + 1) self.filename_pairs = [] for site_id in self.site_ids: for subj_id in self.subj_ids: if len(self.rater_ids) > 0: for rater_id in self.rater_ids: input_filename = self._build_train_input_filename(site_id, subj_id) gt_filename = self._build_train_input_filename(site_id, subj_id, rater_id) input_filename = os.path.join(self.root_dir, input_filename) gt_filename = os.path.join(self.root_dir, gt_filename) if not self.labeled: gt_filename = None self.filename_pairs.append((input_filename, gt_filename)) else: input_filename = self._build_train_input_filename(site_id, subj_id) gt_filename = None input_filename = os.path.join(self.root_dir, input_filename) if not self.labeled: gt_filename = None self.filename_pairs.append((input_filename, gt_filename)) super().__init__(self.filename_pairs, slice_axis, cache, transform, slice_filter_fn, canonical) @staticmethod def _build_train_input_filename(site_id, subj_id, rater_id=None): if rater_id is None: return "site{:d}-sc{:02d}-image.nii.gz".format(site_id, subj_id) else: return "site{:d}-sc{:02d}-mask-r{:d}.nii.gz".format(site_id, subj_id, rater_id)
[docs]class SCGMChallenge2DTest(MRI2DSegmentationDataset): """This is the Spinal Cord Gray Matter Challenge dataset. :param root_dir: the directory containing the test dataset. :param site_ids: a list of site ids to filter (i.e. [1, 3]). :param subj_ids: the list of subject ids to filter. :param transform: the transformations that should be applied. :param cache: if the data should be cached in memory or not. :param slice_axis: axis to make the slicing (default axial). .. note:: This dataset assumes that you only have one class in your ground truth mask (w/ 0's and 1's). It also doesn't automatically resample the dataset. .. seealso:: Prados, F., et al (2017). Spinal cord grey matter segmentation challenge. NeuroImage, 152, 312–329. https://doi.org/10.1016/j.neuroimage.2017.03.010 Challenge Website: http://cmictig.cs.ucl.ac.uk/spinal-cord-grey-matter-segmentation-challenge """ NUM_SITES = 4 NUM_SUBJECTS = 10 def __init__(self, root_dir, slice_axis=2, site_ids=None, subj_ids=None, cache=True, transform=None, slice_filter_fn=None, canonical=False): self.root_dir = root_dir self.site_ids = site_ids or range(1, SCGMChallenge2DTest.NUM_SITES + 1) self.subj_ids = subj_ids or range(11, 10 + SCGMChallenge2DTest.NUM_SUBJECTS + 1) self.filename_pairs = [] for site_id in self.site_ids: for subj_id in self.subj_ids: input_filename = self._build_train_input_filename(site_id, subj_id) gt_filename = None input_filename = os.path.join(self.root_dir, input_filename) if not os.path.exists(input_filename): raise RuntimeError("Path '{}' doesn't exist !".format(input_filename)) self.filename_pairs.append((input_filename, gt_filename)) super().__init__(self.filename_pairs, slice_axis, cache, transform, slice_filter_fn, canonical) @staticmethod def _build_train_input_filename(site_id, subj_id, rater_id=None): if rater_id is None: return "site{:d}-sc{:02d}-image.nii.gz".format(site_id, subj_id) else: return "site{:d}-sc{:02d}-mask-r{:d}.nii.gz".format(site_id, subj_id, rater_id)
def mt_collate(batch): error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" elem_type = type(batch[0]) if torch.is_tensor(batch[0]): stacked = torch.stack(batch, 0) return stacked elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ and elem_type.__name__ != 'string_': elem = batch[0] if elem_type.__name__ == 'ndarray': # array of string classes and object if re.search('[SaUO]', elem.dtype.str) is not None: raise TypeError(error_msg.format(elem.dtype)) return torch.stack([torch.from_numpy(b) for b in batch], 0) if elem.shape == (): # scalars py_type = float if elem.dtype.name.startswith('float') else int return __numpy_type_map[elem.dtype.name](list(map(py_type, batch))) elif isinstance(batch[0], int_classes): return torch.LongTensor(batch) elif isinstance(batch[0], float): return torch.DoubleTensor(batch) elif isinstance(batch[0], string_classes): return batch elif isinstance(batch[0], collections.Mapping): return {key: mt_collate([d[key] for d in batch]) for key in batch[0]} elif isinstance(batch[0], collections.Sequence): transposed = zip(*batch) return [mt_collate(samples) for samples in transposed] return batch