MedicalTorch is an open-source framework for PyTorch, implementing an extensive set of loaders, pre-processors and datasets for medical imaging.

Rationale & Philosophy

Many repositories already provide code for parsing medical imaging data in PyTorch, but several problems remain:

  • many of these repositories don’t contain a single word of documentation;
  • many are just for classification datasets;
  • the majority of them aren’t maintained;
  • most of them limit your freedom because of design mistakes;
  • many models in these repositories are locked in monolithic code that you cannot repurpose for your own goals;
  • many of them don’t contain a single line of testing;
  • many lack examples of how to use them.

The idea of this framework is to provide an elegant design to solve the issues of medical imaging in PyTorch. The design principles of this framework are the following:

  • easily reusable components;
  • well-documented code and APIs;
  • documentation with examples and manuals;
  • extensive testing coverage;
  • easy to integrate into your pipeline;
  • support for a variety of medical imaging sources;
  • as close as possible to the PyTorch design.

With that in mind, there is a long road ahead and contributions are always welcome.

Changelog

This section describes what is new in each release of the project.

Release v.0.2

This is release v.0.2 of medicaltorch! It is mainly a bug-fix release that addresses many important issues found by users and adds new tutorials. Thanks a lot to the contributors of this release!

Changes in this version:

  • Fixed issue with missing re import, thanks @cclauss (issue #5);
  • Fixed issue with async/non_blocking keywords, thanks @cclauss (issue #4);
  • Added CircleCI continuous integration for testing;
  • Added a new tutorial, thanks @omarsar (issue #10);
  • Fixed issue with tensorboardx requirement (issue #11);
  • Fixed issue with requirements file (issue #15);
  • Added new tutorial notebook with examples for the DataLoader creation, thanks @MohitTare (issue #13).

Release v.0.1

This is the first release of the project.

Getting Started

In this section you’ll find a tutorial to learn more about MedicalTorch.

Installation

To install MedicalTorch, use pip (recommended method) or easy_install:

pip install medicaltorch

Tutorials

This notebook provides a walk-through on how to use medicaltorch for spinal cord gray matter segmentation. More tutorials are coming soon.

Architecture

Under construction.

Examples

In this section you can see various examples using the MedicalTorch API.

U-Net with GM Segmentation Challenge

Please note that this example requires TensorboardX to write statistics into a TensorBoard format. You can install it with:

pip install tensorboardx

The example is described below:

from collections import defaultdict
import time
import os

import numpy as np

from tqdm import tqdm

from tensorboardX import SummaryWriter

from medicaltorch import datasets as mt_datasets
from medicaltorch import models as mt_models
from medicaltorch import transforms as mt_transforms
from medicaltorch import losses as mt_losses
from medicaltorch import metrics as mt_metrics
from medicaltorch import filters as mt_filters

import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import autograd, optim
import torch.backends.cudnn as cudnn
import torch.nn as nn

import torchvision.utils as vutils

cudnn.benchmark = True


def threshold_predictions(predictions, thr=0.999):
    # Work on a copy so the caller's array is left untouched
    # (slicing a NumPy array with [:] only creates a view).
    thresholded_preds = predictions.copy()
    low_values_indices = thresholded_preds < thr
    thresholded_preds[low_values_indices] = 0
    high_values_indices = thresholded_preds >= thr
    thresholded_preds[high_values_indices] = 1
    return thresholded_preds


def run_main():
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Here we assume that the SC GM Challenge data is inside the folder
    # "../data" and it was previously resampled.
    gmdataset_train = mt_datasets.SCGMChallenge2DTrain(root_dir="../data",
                                                       subj_ids=range(1, 9),
                                                       transform=train_transform,
                                                       slice_filter_fn=mt_filters.SliceFilter())

    # The validation split reads from the same "../data" folder, using the
    # held-out subjects 9 and 10.
    gmdataset_val = mt_datasets.SCGMChallenge2DTrain(root_dir="../data",
                                                     subj_ids=range(9, 11),
                                                     transform=val_transform)

    train_loader = DataLoader(gmdataset_train, batch_size=16,
                              shuffle=True, pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=1)

    val_loader = DataLoader(gmdataset_val, batch_size=16,
                            shuffle=True, pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = mt_models.Unet(drop_rate=0.4, bn_momentum=0.1)
    model.cuda()

    num_epochs = 200
    initial_lr = 0.001

    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    writer = SummaryWriter(log_dir="log_exp")
    for epoch in tqdm(range(1, num_epochs+1)):
        start_time = time.time()

        scheduler.step()

        lr = scheduler.get_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        model.train()
        train_loss_total = 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]

            var_input = input_samples.cuda()
            var_gt = gt_samples.cuda(non_blocking=True)

            preds = model(var_input)

            loss = mt_losses.dice_loss(preds, var_gt)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

            if epoch % 5 == 0:
                grid_img = vutils.make_grid(input_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Input', grid_img, epoch)

                grid_img = vutils.make_grid(preds.data.cpu(),
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Predictions', grid_img, epoch)

                grid_img = vutils.make_grid(gt_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Ground Truth', grid_img, epoch)

        train_loss_total_avg = train_loss_total / num_steps

        model.eval()
        val_loss_total = 0.0
        num_steps = 0

        metric_fns = [mt_metrics.dice_score,
                      mt_metrics.hausdorff_score,
                      mt_metrics.precision_score,
                      mt_metrics.recall_score,
                      mt_metrics.specificity_score,
                      mt_metrics.intersection_over_union,
                      mt_metrics.accuracy_score]

        metric_mgr = mt_metrics.MetricManager(metric_fns)

        for i, batch in enumerate(val_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]

            with torch.no_grad():
                var_input = input_samples.cuda()
                var_gt = gt_samples.cuda(non_blocking=True)

                preds = model(var_input)
                loss = mt_losses.dice_loss(preds, var_gt)
                val_loss_total += loss.item()

            # Metrics computation
            gt_npy = gt_samples.numpy().astype(np.uint8)
            gt_npy = gt_npy.squeeze(axis=1)

            preds = preds.data.cpu().numpy()
            preds = threshold_predictions(preds)
            preds = preds.astype(np.uint8)
            preds = preds.squeeze(axis=1)

            metric_mgr(preds, gt_npy)

            num_steps += 1

        metrics_dict = metric_mgr.get_results()
        metric_mgr.reset()

        writer.add_scalars('metrics', metrics_dict, epoch)

        val_loss_total_avg = val_loss_total / num_steps

        writer.add_scalars('losses', {
                                'val_loss': val_loss_total_avg,
                                'train_loss': train_loss_total_avg
                            }, epoch)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))


if __name__ == '__main__':
    run_main()
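
The example relies on CosineAnnealingLR to decay the learning rate over the 200 epochs. As a rough illustration of the schedule this produces, the closed-form cosine annealing formula can be sketched in plain Python. The function name cosine_annealing_lr and the eta_min=0.0 default are assumptions made for this sketch; it is not the PyTorch implementation.

```python
import math


def cosine_annealing_lr(initial_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing (hypothetical helper for illustration)."""
    # The cosine term goes from 1 (epoch 0) down to 0 (epoch == t_max).
    cosine = (1.0 + math.cos(math.pi * epoch / t_max)) / 2.0
    return eta_min + (initial_lr - eta_min) * cosine


if __name__ == "__main__":
    # Same hyper-parameters as the example: 200 epochs, initial rate of 0.001.
    for epoch in (0, 50, 100, 150, 200):
        print(epoch, cosine_annealing_lr(0.001, epoch, 200))
```

The rate starts at the initial value, reaches half of it midway through training, and approaches eta_min at the last epoch.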

Dataloader Tutorial for NIFTI images

The tutorial for creating a dataloader using medicaltorch can be found here.
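
As a rough sketch of the first step such a dataloader needs, the snippet below builds the list of (input filename, ground truth filename) tuples that MRI2DSegmentationDataset expects as filename_pairs. The helper build_filename_pairs and the "_image.nii.gz" / "_mask.nii.gz" naming convention are assumptions made for illustration; adapt them to however your files are named.

```python
def build_filename_pairs(filenames,
                         input_suffix="_image.nii.gz",
                         gt_suffix="_mask.nii.gz"):
    """Pair each input volume with its ground truth by shared prefix.

    Hypothetical helper: the suffixes are an assumed naming convention.
    """
    available = set(filenames)
    pairs = []
    for name in sorted(filenames):
        if not name.endswith(input_suffix):
            continue
        gt_name = name[:-len(input_suffix)] + gt_suffix
        # Skip inputs that have no matching ground truth file.
        if gt_name in available:
            pairs.append((name, gt_name))
    return pairs


if __name__ == "__main__":
    files = ["sub01_image.nii.gz", "sub01_mask.nii.gz",
             "sub02_image.nii.gz", "sub02_mask.nii.gz",
             "sub03_image.nii.gz"]
    print(build_filename_pairs(files))
```

The resulting list can then be passed as the filename_pairs argument of MRI2DSegmentationDataset, and the dataset wrapped in a DataLoader with collate_fn=mt_datasets.mt_collate as in the U-Net example.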

API Documentation

All modules listed below are under the “medicaltorch” namespace.

Contents:

medicaltorch.datasets – Datasets

class medicaltorch.datasets.MRI2DSegmentationDataset(filename_pairs, slice_axis=2, cache=True, transform=None, slice_filter_fn=None, canonical=False)

This is a generic class for 2D (slice-wise) segmentation datasets.

Parameters:
  • filename_pairs – a list of tuples in the format (input filename, ground truth filename).
  • slice_axis – axis to make the slicing (default axial).
  • cache – if the data should be cached in memory or not.
  • transform – transformations to apply.

compute_mean_std(verbose=False)

Compute the mean and standard deviation of the entire dataset.

Parameters: verbose – if True, it will show a progress bar.
Returns: tuple (mean, std dev)
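
To make the idea concrete, here is a minimal sketch of how a mean and standard deviation over an entire dataset can be accumulated slice by slice, without holding every slice in memory at once. It is an illustration only, not the library's implementation: the helper name dataset_mean_std, the flat-list representation of slices, and the use of the population standard deviation are all assumptions.

```python
import math


def dataset_mean_std(slices):
    """Return (mean, std) over all pixels of all slices.

    Hypothetical helper: each slice is an iterable of pixel values, and the
    population standard deviation is used.
    """
    count = 0
    total = 0.0
    total_sq = 0.0
    for pixels in slices:
        for value in pixels:
            count += 1
            total += value
            total_sq += value * value
    if count == 0:
        raise ValueError("the dataset contains no pixels")
    mean = total / count
    # Var[X] = E[X^2] - E[X]^2; clamp tiny negative values from rounding.
    variance = max(total_sq / count - mean * mean, 0.0)
    return mean, math.sqrt(variance)


if __name__ == "__main__":
    print(dataset_mean_std([[0.0, 2.0], [4.0, 6.0]]))
```

A (mean, std) pair like this is what the Normalize(mean, std) transform takes as arguments.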

set_transform(transform)

This method will replace the current transformation for the dataset.

Parameters: transform – the new transformation.

class medicaltorch.datasets.SCGMChallenge2DTest(root_dir, slice_axis=2, site_ids=None, subj_ids=None, cache=True, transform=None, slice_filter_fn=None, canonical=False)

This is the Spinal Cord Gray Matter Challenge dataset.

Parameters:
  • root_dir – the directory containing the test dataset.
  • site_ids – a list of site ids to filter (e.g. [1, 3]).
  • subj_ids – the list of subject ids to filter.
  • transform – the transformations that should be applied.
  • cache – if the data should be cached in memory or not.
  • slice_axis – axis to make the slicing (default axial).

Note

This dataset assumes that you only have one class in your ground truth mask (with 0s and 1s). It also doesn’t automatically resample the dataset.

See also

Prados, F., et al (2017). Spinal cord grey matter segmentation challenge. NeuroImage, 152, 312–329. https://doi.org/10.1016/j.neuroimage.2017.03.010

Challenge Website: http://cmictig.cs.ucl.ac.uk/spinal-cord-grey-matter-segmentation-challenge

class medicaltorch.datasets.SCGMChallenge2DTrain(root_dir, slice_axis=2, site_ids=None, subj_ids=None, rater_ids=None, cache=True, transform=None, slice_filter_fn=None, canonical=False, labeled=True)

This is the Spinal Cord Gray Matter Challenge dataset.

Parameters:
  • root_dir – the directory containing the training dataset.
  • site_ids – a list of site ids to filter (e.g. [1, 3]).
  • subj_ids – the list of subject ids to filter.
  • rater_ids – the list of the rater ids to filter.
  • transform – the transformations that should be applied.
  • cache – if the data should be cached in memory or not.
  • slice_axis – axis to make the slicing (default axial).

Note

This dataset assumes that you only have one class in your ground truth mask (with 0s and 1s). It also doesn’t automatically resample the dataset.

See also

Prados, F., et al (2017). Spinal cord grey matter segmentation challenge. NeuroImage, 152, 312–329. https://doi.org/10.1016/j.neuroimage.2017.03.010

Challenge Website: http://cmictig.cs.ucl.ac.uk/spinal-cord-grey-matter-segmentation-challenge

class medicaltorch.datasets.SegmentationPair2D(input_filename, gt_filename, cache=True, canonical=False)

This class is used to build 2D segmentation datasets. It represents a pair of data volumes (the input data and the ground truth data).

Parameters:
  • input_filename – the input filename (supported by nibabel).
  • gt_filename – the ground-truth filename.
  • cache – if the data should be cached in memory or not.
  • canonical – canonical reordering of the volume axes.

get_pair_data()

Return the tuple (input, ground truth) with the data content as NumPy arrays.

get_pair_shapes()

Return the tuple (input, ground truth) representing both the input and ground truth shapes.

get_pair_slice(slice_index, slice_axis=2)

Return the specified slice from (input, ground truth).

Parameters:
  • slice_index – the slice number.
  • slice_axis – axis to make the slicing.

medicaltorch.transforms – Transformations

class medicaltorch.transforms.CenterCrop2D(size, labeled=True)

Make a center crop of a specified size.

Parameters:
  • size – the size of the crop.
  • labeled – whether the sample has a ground truth, as in a segmentation task. When this is True (default), the crop will also be applied to the ground truth.
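
A minimal sketch of what a center crop does, and of why the same crop has to be applied to the ground truth, is shown below. It works on plain lists of rows rather than on images or tensors; the helper names and the (height, width) ordering of size are assumptions made for this illustration and may not match the library.

```python
def center_crop_2d(image, size):
    """Crop a 2D image (list of rows) around its center.

    Hypothetical helper; `size` is (height, width) by assumption.
    """
    crop_h, crop_w = size
    height, width = len(image), len(image[0])
    if crop_h > height or crop_w > width:
        raise ValueError("crop size is larger than the image")
    top = (height - crop_h) // 2
    left = (width - crop_w) // 2
    return [row[left:left + crop_w] for row in image[top:top + crop_h]]


def center_crop_pair(image, mask, size):
    """Apply the same crop to the input and to its ground truth mask."""
    return center_crop_2d(image, size), center_crop_2d(mask, size)


if __name__ == "__main__":
    image = [[row * 4 + col for col in range(4)] for row in range(4)]
    mask = [[1 if 1 <= row <= 2 else 0 for _ in range(4)] for row in range(4)]
    print(center_crop_pair(image, mask, (2, 2)))
```

Cropping both arrays with identical offsets keeps every input pixel aligned with its label.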

class medicaltorch.transforms.Normalize(mean, std)

Normalize a tensor image with mean and standard deviation.

Parameters:
  • mean – mean value.
  • std – standard deviation value.

class medicaltorch.transforms.NormalizeInstance

Normalize a tensor image with mean and standard deviation estimated from the sample itself.

This transform takes no parameters: the mean and standard deviation are computed from each sample.
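
The per-sample normalization described here can be sketched as follows. This is an illustration on a flat list of pixel values, not the library's code: the helper name normalize_instance is hypothetical, and the sketch uses the population standard deviation, whereas the library may use a different estimator.

```python
import statistics


def normalize_instance(pixels):
    """Standardize one sample using its own mean and standard deviation.

    Hypothetical helper: uses the population standard deviation; the
    library may use a different estimator.
    """
    mean = statistics.fmean(pixels)
    std = statistics.pstdev(pixels)
    if std == 0:
        raise ValueError("constant sample: standard deviation is zero")
    return [(value - mean) / std for value in pixels]


if __name__ == "__main__":
    print(normalize_instance([1.0, 3.0]))
```

Unlike Normalize(mean, std), no dataset-wide statistics are needed, so each sample ends up with zero mean regardless of its original intensity range.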

class medicaltorch.transforms.ToTensor(labeled=True)

Convert a PIL image or numpy array to a PyTorch tensor.

medicaltorch.metrics – Metrics

medicaltorch.metrics.numeric_score(prediction, groundtruth)

Computation of statistical numerical scores:

  • FP = False Positives
  • FN = False Negatives
  • TP = True Positives
  • TN = True Negatives

Returns: tuple (FP, FN, TP, TN)
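
The four counts can be illustrated with a small pure-Python sketch on flat binary masks. It is an illustrative re-implementation under the assumption that both masks contain only 0s and 1s, not the library's code; the names numeric_score_sketch and precision_recall are hypothetical.

```python
def numeric_score_sketch(prediction, groundtruth):
    """Count FP, FN, TP and TN for two flat binary (0/1) masks.

    Illustrative re-implementation; not the library's code.
    """
    if len(prediction) != len(groundtruth):
        raise ValueError("masks must have the same number of elements")
    fp = fn = tp = tn = 0
    for pred, gt in zip(prediction, groundtruth):
        if pred == 1 and gt == 1:
            tp += 1
        elif pred == 1 and gt == 0:
            fp += 1
        elif pred == 0 and gt == 1:
            fn += 1
        else:
            tn += 1
    return fp, fn, tp, tn


def precision_recall(prediction, groundtruth):
    """Derive precision and recall from the four counts."""
    fp, fn, tp, _ = numeric_score_sketch(prediction, groundtruth)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


if __name__ == "__main__":
    pred = [1, 1, 0, 0, 1]
    gt = [1, 0, 0, 1, 1]
    print(numeric_score_sketch(pred, gt))
    print(precision_recall(pred, gt))
```

Scores such as precision, recall, specificity and accuracy are all ratios of these four counts.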

medicaltorch.models – Models

class medicaltorch.models.NoPoolASPP(drop_rate=0.4, bn_momentum=0.1, base_num_filters=64)

An ASPP-based model without initial pooling layers.

Parameters:
  • drop_rate – dropout rate.
  • bn_momentum – batch normalization momentum.

See also

Perone, C. S., et al (2017). Spinal cord gray matter segmentation using deep dilated convolutions. Nature Scientific Reports link: https://www.nature.com/articles/s41598-018-24304-3

forward(x)

Model forward pass.

Parameters: x – input data.

class medicaltorch.models.Unet(drop_rate=0.4, bn_momentum=0.1)

A reference U-Net model.

See also

Ronneberger, O., et al (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation ArXiv link: https://arxiv.org/abs/1505.04597

medicaltorch.losses – Losses

class medicaltorch.losses.MaskedDiceLoss(ignore_value=-100.0)

A masked version of the Dice loss.

Parameters: ignore_value – the value to ignore.

medicaltorch.losses.dice_loss(input, target)

Dice loss.

Parameters:
  • input – The input (predicted)
  • target – The target (ground truth)
Returns: the Dice score between 0 and 1.
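
For readers unfamiliar with the Dice score, a minimal sketch of a soft Dice coefficient on flat lists is given below. It is an illustration rather than the library's implementation: the helper name dice_coefficient and the smooth term, which avoids a division by zero when both masks are empty, are assumptions.

```python
def dice_coefficient(prediction, target, smooth=1.0):
    """Soft Dice coefficient between two flat lists of values in [0, 1].

    Hypothetical helper; the `smooth` term is an assumption that avoids a
    division by zero when both masks are empty.
    """
    intersection = sum(p * t for p, t in zip(prediction, target))
    denominator = sum(prediction) + sum(target)
    return (2.0 * intersection + smooth) / (denominator + smooth)


if __name__ == "__main__":
    print(dice_coefficient([1, 1, 0, 0], [1, 0, 0, 0]))
    print(dice_coefficient([1, 0, 1], [1, 0, 1]))
```

A coefficient of this kind is higher for better overlap, so a loss built on it has to decrease as the coefficient increases, for example by negating it or subtracting it from 1.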

Contributors

We are very thankful to all our contributors! For a complete list of contributors, please see the official Contributors list in the GitHub repository.

Contribute or Report a bug

MedicalTorch is an open-source project created and maintained by Christian S. Perone. You can contribute by donating or by helping with a pull request or a bug report. You can get the source code of the project on the GitHub project page.

License

Apache 2.0 License:

Copyright 2018 Christian S. Perone

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
