Source code for medicaltorch.transforms

import numpy as np
import numbers
import torchvision.transforms.functional as F
from torchvision import transforms
from PIL import Image

from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter


class MTTransform(object):

    def __call__(self, sample):
        raise NotImplementedError("You need to implement the transform() method.")

    def undo_transform(self, sample):
        raise NotImplementedError("You need to implement the undo_transform() method.")


class UndoCompose(object):
    def __init__(self, compose):
        self.transforms = compose.transforms

    def __call__(self):
        for t in self.transforms:
            img = t.undo_transform(img)
        return img


class UndoTransform(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        return self.transform.undo_transform(sample)


[docs]class ToTensor(MTTransform): """Convert a PIL image or numpy array to a PyTorch tensor.""" def __init__(self, labeled=True): self.labeled = labeled def __call__(self, sample): rdict = {} input_data = sample['input'] if isinstance(input_data, list): ret_input = [F.to_tensor(item) for item in input_data] else: ret_input = F.to_tensor(input_data) rdict['input'] = ret_input if self.labeled: gt_data = sample['gt'] if gt_data is not None: if isinstance(gt_data, list): ret_gt = [F.to_tensor(item) for item in gt_data] else: ret_gt = F.to_tensor(gt_data) rdict['gt'] = ret_gt sample.update(rdict) return sample
class ToPIL(MTTransform): def __init__(self, labeled=True): self.labeled = labeled def sample_transform(self, sample_data): # Numpy array if not isinstance(sample_data, np.ndarray): input_data_npy = sample_data.numpy() else: input_data_npy = sample_data input_data_npy = np.transpose(input_data_npy, (1, 2, 0)) input_data_npy = np.squeeze(input_data_npy, axis=2) input_data = Image.fromarray(input_data_npy, mode='F') return input_data def __call__(self, sample): rdict = {} input_data = sample['input'] if isinstance(input_data, list): ret_input = [self.sample_transform(item) for item in input_data] else: ret_input = self.sample_transform(input_data) rdict['input'] = ret_input if self.labeled: gt_data = sample['gt'] if isinstance(gt_data, list): ret_gt = [self.sample_transform(item) for item in gt_data] else: ret_gt = self.sample_transform(gt_data) rdict['gt'] = ret_gt sample.update(rdict) return sample class UnCenterCrop2D(MTTransform): def __init__(self, size, segmentation=True): self.size = size self.segmentation = segmentation def __call__(self, sample): input_data, gt_data = sample['input'], sample['gt'] input_metadata, gt_metadata = sample['input_metadata'], sample['gt_metadata'] (fh, fw, w, h) = input_metadata["__centercrop"] (fh, fw, w, h) = gt_metadata["__centercrop"] return sample
[docs]class CenterCrop2D(MTTransform): """Make a center crop of a specified size. :param segmentation: if it is a segmentation task. When this is True (default), the crop will also be applied to the ground truth. """ def __init__(self, size, labeled=True): self.size = size self.labeled = labeled @staticmethod def propagate_params(sample, params): input_metadata = sample['input_metadata'] input_metadata["__centercrop"] = params return input_metadata @staticmethod def get_params(sample): input_metadata = sample['input_metadata'] return input_metadata["__centercrop"] def __call__(self, sample): rdict = {} input_data = sample['input'] w, h = input_data.size th, tw = self.size fh = int(round((h - th) / 2.)) fw = int(round((w - tw) / 2.)) params = (fh, fw, w, h) self.propagate_params(sample, params) input_data = F.center_crop(input_data, self.size) rdict['input'] = input_data if self.labeled: gt_data = sample['gt'] gt_metadata = sample['gt_metadata'] gt_data = F.center_crop(gt_data, self.size) gt_metadata["__centercrop"] = (fh, fw, w, h) rdict['gt'] = gt_data sample.update(rdict) return sample def undo_transform(self, sample): rdict = {} input_data = sample['input'] fh, fw, w, h = self.get_params(sample) th, tw = self.size pad_left = fw pad_right = w - pad_left - tw pad_top = fh pad_bottom = h - pad_top - th padding = (pad_left, pad_top, pad_right, pad_bottom) input_data = F.pad(input_data, padding) rdict['input'] = input_data sample.update(rdict) return sample
[docs]class Normalize(MTTransform): """Normalize a tensor image with mean and standard deviation. :param mean: mean value. :param std: standard deviation value. """ def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, sample): input_data = sample['input'] input_data = F.normalize(input_data, self.mean, self.std) rdict = { 'input': input_data, } sample.update(rdict) return sample
[docs]class NormalizeInstance(MTTransform): """Normalize a tensor image with mean and standard deviation estimated from the sample itself. :param mean: mean value. :param std: standard deviation value. """ def __call__(self, sample): input_data = sample['input'] mean, std = input_data.mean(), input_data.std() input_data = F.normalize(input_data, [mean], [std]) rdict = { 'input': input_data, } sample.update(rdict) return sample
class RandomRotation(MTTransform): def __init__(self, degrees, resample=False, expand=False, center=None, labeled=True): if isinstance(degrees, numbers.Number): if degrees < 0: raise ValueError("If degrees is a single number, it must be positive.") self.degrees = (-degrees, degrees) else: if len(degrees) != 2: raise ValueError("If degrees is a sequence, it must be of len 2.") self.degrees = degrees self.resample = resample self.expand = expand self.center = center self.labeled = labeled @staticmethod def get_params(degrees): angle = np.random.uniform(degrees[0], degrees[1]) return angle def __call__(self, sample): rdict = {} input_data = sample['input'] angle = self.get_params(self.degrees) input_data = F.rotate(input_data, angle, self.resample, self.expand, self.center) rdict['input'] = input_data if self.labeled: gt_data = sample['gt'] gt_data = F.rotate(gt_data, angle, self.resample, self.expand, self.center) rdict['gt'] = gt_data sample.update(rdict) return sample class RandomAffine(MTTransform): def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0, labeled=True): if isinstance(degrees, numbers.Number): if degrees < 0: raise ValueError("If degrees is a single number, it must be positive.") self.degrees = (-degrees, degrees) else: assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ "degrees should be a list or tuple and it must be of length 2." self.degrees = degrees if translate is not None: assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ "translate should be a list or tuple and it must be of length 2." for t in translate: if not (0.0 <= t <= 1.0): raise ValueError("translation values should be between 0 and 1") self.translate = translate if scale is not None: assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ "scale should be a list or tuple and it must be of length 2." for s in scale: if s <= 0: raise ValueError("scale values should be positive") self.scale = scale if shear is not None: if isinstance(shear, numbers.Number): if shear < 0: raise ValueError("If shear is a single number, it must be positive.") self.shear = (-shear, shear) else: assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ "shear should be a list or tuple and it must be of length 2." self.shear = shear else: self.shear = shear self.resample = resample self.fillcolor = fillcolor self.labeled = labeled @staticmethod def get_params(degrees, translate, scale_ranges, shears, img_size): """Get parameters for affine transformation Returns: sequence: params to be passed to the affine transformation """ angle = np.random.uniform(degrees[0], degrees[1]) if translate is not None: max_dx = translate[0] * img_size[0] max_dy = translate[1] * img_size[1] translations = (np.round(np.random.uniform(-max_dx, max_dx)), np.round(np.random.uniform(-max_dy, max_dy))) else: translations = (0, 0) if scale_ranges is not None: scale = np.random.uniform(scale_ranges[0], scale_ranges[1]) else: scale = 1.0 if shears is not None: shear = np.random.uniform(shears[0], shears[1]) else: shear = 0.0 return angle, translations, scale, shear def sample_augment(self, input_data, params): input_data = F.affine(input_data, *params, resample=self.resample, fillcolor=self.fillcolor) return input_data def label_augment(self, gt_data, params): gt_data = self.sample_augment(gt_data, params) np_gt_data = np.array(gt_data) np_gt_data[np_gt_data >= 0.5] = 1.0 np_gt_data[np_gt_data < 0.5] = 0.0 gt_data = Image.fromarray(np_gt_data, mode='F') return gt_data def __call__(self, sample): """ img (PIL Image): Image to be transformed. Returns: PIL Image: Affine transformed image. """ rdict = {} input_data = sample['input'] if isinstance(input_data, list): input_data_size = input_data[0].size else: input_data_size = input_data.size params = self.get_params(self.degrees, self.translate, self.scale, self.shear, input_data_size) if isinstance(input_data, list): ret_input = [self.sample_augment(item, params) for item in input_data] else: ret_input = self.sample_augment(input_data, params) rdict['input'] = ret_input if self.labeled: gt_data = sample['gt'] if isinstance(gt_data, list): ret_gt = [self.label_augment(item, params) for item in gt_data] else: ret_gt = self.label_augment(gt_data, params) rdict['gt'] = ret_gt sample.update(rdict) return sample class RandomTensorChannelShift(MTTransform): def __init__(self, shift_range): self.shift_range = shift_range @staticmethod def get_params(shift_range): sampled_value = np.random.uniform(shift_range[0], shift_range[1]) return sampled_value def sample_augment(self, input_data, params): np_input_data = np.array(input_data) np_input_data += params input_data = Image.fromarray(np_input_data, mode='F') return input_data def __call__(self, sample): input_data = sample['input'] params = self.get_params(self.shift_range) if isinstance(input_data, list): #ret_input = [self.sample_augment(item, params) # for item in input_data] # Augment just the image, not the mask # TODO: fix it later ret_input = [] ret_input.append(self.sample_augment(input_data[0], params)) ret_input.append(input_data[1]) else: ret_input = self.sample_augment(input_data, params) rdict = { 'input': ret_input, } sample.update(rdict) return sample class ElasticTransform(MTTransform): def __init__(self, alpha_range, sigma_range, p=0.5, labeled=True): self.alpha_range = alpha_range self.sigma_range = sigma_range self.labeled = labeled self.p = p @staticmethod def get_params(alpha, sigma): alpha = np.random.uniform(alpha[0], alpha[1]) sigma = np.random.uniform(sigma[0], sigma[1]) return alpha, sigma @staticmethod def elastic_transform(image, alpha, sigma): shape = image.shape dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') indices = np.reshape(x+dx, (-1, 1)), np.reshape(y+dy, (-1, 1)) return map_coordinates(image, indices, order=1).reshape(shape) def sample_augment(self, input_data, params): param_alpha, param_sigma = params np_input_data = np.array(input_data) np_input_data = self.elastic_transform(np_input_data, param_alpha, param_sigma) input_data = Image.fromarray(np_input_data, mode='F') return input_data def label_augment(self, gt_data, params): param_alpha, param_sigma = params np_gt_data = np.array(gt_data) np_gt_data = self.elastic_transform(np_gt_data, param_alpha, param_sigma) np_gt_data[np_gt_data >= 0.5] = 1.0 np_gt_data[np_gt_data < 0.5] = 0.0 gt_data = Image.fromarray(np_gt_data, mode='F') return gt_data def __call__(self, sample): rdict = {} if np.random.random() < self.p: input_data = sample['input'] params = self.get_params(self.alpha_range, self.sigma_range) if isinstance(input_data, list): ret_input = [self.sample_augment(item, params) for item in input_data] else: ret_input = self.sample_augment(input_data, params) rdict['input'] = ret_input if self.labeled: gt_data = sample['gt'] if isinstance(gt_data, list): ret_gt = [self.label_augment(item, params) for item in gt_data] else: ret_gt = self.label_augment(gt_data, params) rdict['gt'] = ret_gt sample.update(rdict) return sample # TODO: Resample should keep state after changing state. # By changing pixel dimensions, we should be # able to return later to the original space. class Resample(MTTransform): def __init__(self, wspace, hspace, interpolation=Image.BILINEAR, labeled=True): self.hspace = hspace self.wspace = wspace self.interpolation = interpolation self.labeled = labeled def __call__(self, sample): rdict = {} input_data = sample['input'] input_metadata = sample['input_metadata'] # Voxel dimension in mm hzoom, wzoom = input_metadata["zooms"] hshape, wshape = input_metadata["data_shape"] hfactor = hzoom / self.hspace wfactor = wzoom / self.wspace hshape_new = int(hshape * hfactor) wshape_new = int(wshape * wfactor) input_data = input_data.resize((wshape_new, hshape_new), resample=self.interpolation) rdict['input'] = input_data if self.labeled: gt_data = sample['gt'] gt_metadata = sample['gt_metadata'] gt_data = gt_data.resize((wshape_new, hshape_new), resample=self.interpolation) np_gt_data = np.array(gt_data) np_gt_data[np_gt_data >= 0.5] = 1.0 np_gt_data[np_gt_data < 0.5] = 0.0 gt_data = Image.fromarray(np_gt_data, mode='F') rdict['gt'] = gt_data sample.update(rdict) return sample class AdditiveGaussianNoise(MTTransform): def __init__(self, mean=0.0, std=0.01): self.mean = mean self.std = std def __call__(self, sample): rdict = {} input_data = sample['input'] noise = np.random.normal(self.mean, self.std, input_data.size) noise = noise.astype(np.float32) np_input_data = np.array(input_data) np_input_data += noise input_data = Image.fromarray(np_input_data, mode='F') rdict['input'] = input_data sample.update(rdict) return sample