
MedicalTorch is an open-source framework for PyTorch, implementing an extensive set of loaders, pre-processors and datasets for medical imaging.
Rationale & Philosophy
There are many repositories with code for parsing medical imaging data in PyTorch; however, some challenges remain:
- many of these repositories don’t contain a single word of documentation;
- many only support classification datasets;
- the majority of them aren’t maintained;
- most of them restrict how you can use them because of design mistakes;
- many models in these repositories are locked in monolithic code that you cannot repurpose for your own goals;
- many of them don’t contain a single line of testing;
- many lack examples showing how to use them.
The idea of this framework is to provide an elegant design that solves these issues for medical imaging in PyTorch. The design principles of this framework are the following:
- easily reusable components;
- well-documented code and APIs;
- documentation with examples and manuals;
- extensive test coverage;
- easy integration into your pipeline;
- support for a variety of medical imaging sources;
- as close as possible to the PyTorch design.
With that in mind, there is a long road ahead and contributions are always welcome.
Changelog
In this section you’ll find information about what’s new in recent releases of the project.
Release v.0.2
This is release v.0.2 of medicaltorch! It is mainly a bug-fix release that addresses many important issues reported by users, and it also adds new tutorials. Many thanks to the contributors of this release!
Changes in this version:
- Fixed issue with missing re import, thanks @cclauss (issue #5);
- Fixed issue with async/non_blocking keywords, thanks @cclauss (issue #4);
- Added CircleCI continuous integration for testing;
- Added a new tutorial, thanks @omarsar (issue #10);
- Fixed issue with tensorboardx requirement (issue #11);
- Fixed issue with requirements file (issue #15);
- Added a new tutorial notebook with examples of DataLoader creation, thanks @MohitTare (issue #13).
Release v.0.1
This is the first release of the project.
Getting Started
In this section you’ll find a tutorial to learn more about MedicalTorch.
Installation
To install MedicalTorch, use pip (recommended method) or easy_install:
pip install medicaltorch
Tutorials
A tutorial notebook provides a walk-through on how to use medicaltorch for spinal cord gray matter segmentation. More tutorials are coming soon.
Architecture
Under construction.
Examples
In this section you can see various examples using the MedicalTorch API.
U-Net with GM Segmentation Challenge
Please note that this example requires tensorboardX to write statistics in TensorBoard format. You can install it with:
pip install tensorboardx
The example is described below:
import time

import numpy as np
from tqdm import tqdm
from tensorboardX import SummaryWriter

from medicaltorch import datasets as mt_datasets
from medicaltorch import models as mt_models
from medicaltorch import transforms as mt_transforms
from medicaltorch import losses as mt_losses
from medicaltorch import metrics as mt_metrics
from medicaltorch import filters as mt_filters

import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import optim
import torch.backends.cudnn as cudnn
import torchvision.utils as vutils

# Let cuDNN benchmark and pick the fastest convolution algorithms
# (the input size is fixed by the 200x200 center crop).
cudnn.benchmark = True


def threshold_predictions(predictions, thr=0.999):
    # Binarize predicted probabilities: values >= thr become 1, all others 0.
    # Returns a new array instead of modifying `predictions` in place.
    return (predictions >= thr).astype(predictions.dtype)


def run_main():
    # Training-time augmentation, then tensor conversion and
    # per-instance normalization.
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Validation uses the same crop and normalization, without augmentation.
    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Here we assume that the SC GM Challenge data is inside the folder
    # "../data" and that it was previously resampled.
    # Subjects 1-8 are used for training.
    gmdataset_train = mt_datasets.SCGMChallenge2DTrain(root_dir="../data",
                                                       subj_ids=range(1, 9),
                                                       transform=train_transform,
                                                       slice_filter_fn=mt_filters.SliceFilter())

    # Subjects 9-10 are held out for validation.
    gmdataset_val = mt_datasets.SCGMChallenge2DTrain(root_dir="../data",
                                                     subj_ids=range(9, 11),
                                                     transform=val_transform)

    train_loader = DataLoader(gmdataset_train, batch_size=16,
                              shuffle=True, pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=1)

    # No need to shuffle for validation: metrics are aggregated over all batches.
    val_loader = DataLoader(gmdataset_val, batch_size=16,
                            shuffle=False, pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = mt_models.Unet(drop_rate=0.4, bn_momentum=0.1)
    model.cuda()

    num_epochs = 200
    initial_lr = 0.001

    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    writer = SummaryWriter(log_dir="log_exp")
    for epoch in tqdm(range(1, num_epochs + 1)):
        start_time = time.time()

        # Log the learning rate used during this epoch.
        lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar('learning_rate', lr, epoch)

        model.train()
        train_loss_total = 0.0
        num_steps = 0
        for batch in train_loader:
            input_samples, gt_samples = batch["input"], batch["gt"]

            var_input = input_samples.cuda()
            var_gt = gt_samples.cuda(non_blocking=True)

            preds = model(var_input)

            loss = mt_losses.dice_loss(preds, var_gt)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

            # Every 5 epochs, write image grids of the inputs, predictions
            # and ground truth to TensorBoard.
            if epoch % 5 == 0:
                grid_img = vutils.make_grid(input_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Input', grid_img, epoch)

                grid_img = vutils.make_grid(preds.detach().cpu(),
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Predictions', grid_img, epoch)

                grid_img = vutils.make_grid(gt_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Ground Truth', grid_img, epoch)

        train_loss_total_avg = train_loss_total / num_steps

        model.eval()
        val_loss_total = 0.0
        num_steps = 0

        metric_fns = [mt_metrics.dice_score,
                      mt_metrics.hausdorff_score,
                      mt_metrics.precision_score,
                      mt_metrics.recall_score,
                      mt_metrics.specificity_score,
                      mt_metrics.intersection_over_union,
                      mt_metrics.accuracy_score]

        metric_mgr = mt_metrics.MetricManager(metric_fns)

        for batch in val_loader:
            input_samples, gt_samples = batch["input"], batch["gt"]

            with torch.no_grad():
                var_input = input_samples.cuda()
                var_gt = gt_samples.cuda(non_blocking=True)

                preds = model(var_input)
                loss = mt_losses.dice_loss(preds, var_gt)
                val_loss_total += loss.item()

            # Metrics computation on binarized predictions (channel axis removed)
            gt_npy = gt_samples.numpy().astype(np.uint8)
            gt_npy = gt_npy.squeeze(axis=1)

            preds = preds.cpu().numpy()
            preds = threshold_predictions(preds)
            preds = preds.astype(np.uint8)
            preds = preds.squeeze(axis=1)

            metric_mgr(preds, gt_npy)

            num_steps += 1

        metrics_dict = metric_mgr.get_results()
        metric_mgr.reset()

        writer.add_scalars('metrics', metrics_dict, epoch)

        val_loss_total_avg = val_loss_total / num_steps

        writer.add_scalars('losses', {
            'val_loss': val_loss_total_avg,
            'train_loss': train_loss_total_avg
        }, epoch)

        # Step the scheduler once per epoch, after the optimizer updates.
        scheduler.step()

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))


if __name__ == '__main__':
    run_main()
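The validation step above binarizes the model outputs with threshold_predictions and then scores them against the ground truth, with the Dice score as the first metric. As a rough illustration of what a Dice score on binary masks measures, here is a minimal NumPy sketch. The helpers binarize and dice_coefficient are hypothetical and written only for this illustration; they are not necessarily how mt_metrics.dice_score works (for instance, in how it handles two empty masks). The example script uses a threshold of 0.999, while the demo below uses 0.5 for readability.

```python
import numpy as np


def binarize(predictions, thr=0.5):
    """Return a uint8 mask with 1 where predictions >= thr and 0 elsewhere."""
    return (np.asarray(predictions) >= thr).astype(np.uint8)


def dice_coefficient(pred, gt):
    """Dice = 2 * |P & G| / (|P| + |G|) for two binary masks (hypothetical helper)."""
    pred = np.asarray(pred).astype(bool)
    gt = np.asarray(gt).astype(bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        # Both masks empty: scored as perfect agreement here (a convention, not a rule).
        return 1.0
    return float(2.0 * np.logical_and(pred, gt).sum() / denom)


if __name__ == "__main__":
    probs = np.array([[0.9, 0.2], [0.7, 0.1]])
    gt = np.array([[1, 0], [0, 0]], dtype=np.uint8)
    pred = binarize(probs, thr=0.5)
    print(pred.tolist())               # [[1, 0], [1, 0]]
    print(dice_coefficient(pred, gt))  # about 0.667, i.e. 2 * 1 / (2 + 1)
```

A Dice score of 1 means perfect overlap; the sketch returns a plain Python float so it can be logged directly.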
API Documentation
All modules listed below are under the “medicaltorch” namespace.
Contents:
medicaltorch.datasets – Datasets
class medicaltorch.datasets.MRI2DSegmentationDataset(filename_pairs, slice_axis=2, cache=True, transform=None, slice_filter_fn=None, canonical=False)
This is a generic class for 2D (slice-wise) segmentation datasets.
Parameters:
- filename_pairs – a list of tuples in the format (input filename, ground truth filename).
- slice_axis – axis along which the volumes are sliced (default: 2, axial).
- cache – whether the data should be cached in memory.
- transform – transformations to apply.
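MRI2DSegmentationDataset turns volume pairs into 2D samples by slicing along slice_axis, optionally discarding slices through slice_filter_fn. The sketch below illustrates that idea with plain NumPy arrays standing in for loaded volumes. The function names slice_volume_pair and non_empty_mask are hypothetical, and so is the filter signature (a callable receiving the input and ground-truth slices and returning True to keep the slice); this is not the library's internal implementation.

```python
import numpy as np


def slice_volume_pair(input_vol, gt_vol, slice_axis=2, slice_filter_fn=None):
    """Split a 3D (input, ground truth) pair into 2D slice pairs along slice_axis.

    Hypothetical helper: slice_filter_fn(input_slice, gt_slice) -> bool is an
    assumed signature used only for this illustration.
    """
    if input_vol.shape != gt_vol.shape:
        raise ValueError("input and ground truth volumes must have the same shape")
    pairs = []
    for idx in range(input_vol.shape[slice_axis]):
        # np.take with a scalar index drops slice_axis, giving a 2D array.
        input_slice = np.take(input_vol, idx, axis=slice_axis)
        gt_slice = np.take(gt_vol, idx, axis=slice_axis)
        if slice_filter_fn is not None and not slice_filter_fn(input_slice, gt_slice):
            continue
        pairs.append((input_slice, gt_slice))
    return pairs


def non_empty_mask(input_slice, gt_slice):
    """Example filter (hypothetical): keep only slices whose mask has foreground."""
    return bool(gt_slice.any())


if __name__ == "__main__":
    image = np.random.rand(4, 5, 3)
    mask = np.zeros((4, 5, 3), dtype=np.uint8)
    mask[1, 2, 0] = 1  # foreground only in the first slice along axis 2
    print(len(slice_volume_pair(image, mask)))  # 3
    print(len(slice_volume_pair(image, mask, slice_filter_fn=non_empty_mask)))  # 1
    print(slice_volume_pair(image, mask)[0][0].shape)  # (4, 5)
```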
class medicaltorch.datasets.SCGMChallenge2DTest(root_dir, slice_axis=2, site_ids=None, subj_ids=None, cache=True, transform=None, slice_filter_fn=None, canonical=False)
This is the Spinal Cord Gray Matter Challenge test dataset.
Parameters:
- root_dir – the directory containing the test dataset.
- site_ids – a list of site ids to filter (e.g. [1, 3]).
- subj_ids – a list of subject ids to filter.
- transform – the transformations to apply.
- cache – whether the data should be cached in memory.
- slice_axis – axis along which the volumes are sliced (default: 2, axial).
Note
This dataset assumes that your ground truth mask contains only one class (0s and 1s). It also does not resample the dataset automatically.
See also
Prados, F., et al. (2017). Spinal cord grey matter segmentation challenge. NeuroImage, 152, 312–329. https://doi.org/10.1016/j.neuroimage.2017.03.010
Challenge Website: http://cmictig.cs.ucl.ac.uk/spinal-cord-grey-matter-segmentation-challenge
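Because this dataset assumes a single-class ground truth mask containing only 0s and 1s, it can help to check masks before building a dataset from them. A minimal NumPy sketch; is_binary_mask is a hypothetical helper written for this illustration and is not part of medicaltorch.

```python
import numpy as np


def is_binary_mask(mask):
    """Return True if every voxel in mask is exactly 0 or 1."""
    values = np.unique(np.asarray(mask))
    return bool(np.isin(values, [0, 1]).all())


if __name__ == "__main__":
    print(is_binary_mask(np.array([[0, 1], [1, 0]])))          # True
    print(is_binary_mask(np.array([[0, 2], [1, 0]])))          # False (more than one label)
    print(is_binary_mask(np.array([[0.0, 0.5], [1.0, 0.0]])))  # False (soft values)
```

A mask that fails this check (for example, a multi-label or interpolated mask) would need to be binarized before it matches the assumption stated in the note.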
class medicaltorch.datasets.SCGMChallenge2DTrain(root_dir, slice_axis=2, site_ids=None, subj_ids=None, rater_ids=None, cache=True, transform=None, slice_filter_fn=None, canonical=False, labeled=True)
This is the Spinal Cord Gray Matter Challenge training dataset.
Parameters:
- root_dir – the directory containing the training dataset.
- site_ids – a list of site ids to filter (e.g. [1, 3]).
- subj_ids – a list of subject ids to filter.
- rater_ids – a list of rater ids to filter.
- transform – the transformations to apply.
- cache – whether the data should be cached in memory.
- slice_axis – axis along which the volumes are sliced (default: 2, axial).
Note
This dataset assumes that your ground truth mask contains only one class (0s and 1s). It also does not resample the dataset automatically.
See also
Prados, F., et al. (2017). Spinal cord grey matter segmentation challenge. NeuroImage, 152, 312–329. https://doi.org/10.1016/j.neuroimage.2017.03.010
Challenge Website: http://cmictig.cs.ucl.ac.uk/spinal-cord-grey-matter-segmentation-challenge
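The site_ids, subj_ids and rater_ids arguments narrow the training set down to particular acquisitions and annotators. The sketch below shows that kind of filtering on file names. It assumes challenge-style names such as site1-sc01-mask-r1.nii.gz; both that naming pattern and the helper select_masks are assumptions made for this illustration, not a description of how SCGMChallenge2DTrain actually reads its root_dir. Here a filter left as None accepts every value, which is assumed to mirror the None defaults in the signature.

```python
import re

# Assumed naming pattern (illustrative only): "site<S>-sc<NN>-mask-r<R>.nii.gz".
MASK_PATTERN = re.compile(r"site(\d+)-sc(\d+)-mask-r(\d+)\.nii\.gz$")


def select_masks(filenames, site_ids=None, subj_ids=None, rater_ids=None):
    """Keep mask files whose site/subject/rater ids pass the given filters."""
    selected = []
    for name in filenames:
        match = MASK_PATTERN.search(name)
        if match is None:
            continue  # not a mask file under the assumed naming scheme
        site, subj, rater = (int(group) for group in match.groups())
        if site_ids is not None and site not in site_ids:
            continue
        if subj_ids is not None and subj not in subj_ids:
            continue
        if rater_ids is not None and rater not in rater_ids:
            continue
        selected.append(name)
    return selected


if __name__ == "__main__":
    files = [
        "site1-sc01-image.nii.gz",
        "site1-sc01-mask-r1.nii.gz",
        "site1-sc01-mask-r2.nii.gz",
        "site3-sc09-mask-r1.nii.gz",
    ]
    print(select_masks(files, site_ids=[1], rater_ids=[2]))
    # ['site1-sc01-mask-r2.nii.gz']
```

Since the filters only use membership tests, a range such as range(1, 9) (as in the U-Net example) works as well as a list.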
class medicaltorch.datasets.SegmentationPair2D(input_filename, gt_filename, cache=True, canonical=False)
This class is used to build 2D segmentation datasets. It represents a pair of two data volumes (the input data and the ground truth data).
Parameters:
- input_filename – the input filename (any format supported by nibabel).
- gt_filename – the ground truth filename.
- cache – whether the data should be cached in memory.
- canonical – whether to apply canonical reordering of the volume axes.
get_pair_data()
Returns the tuple (input, ground truth) with the data content as numpy arrays.
medicaltorch.transforms – Transformations
class medicaltorch.transforms.CenterCrop2D(size, labeled=True)
Make a center crop of a specified size.
Parameters:
- labeled – whether the sample has a ground truth (e.g. a segmentation task). When this is True (default), the crop is also applied to the ground truth.
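When the crop is also applied to the ground truth, both arrays must be cut at exactly the same offsets so that they stay aligned. A NumPy sketch of such a paired center crop follows; center_crop_2d and center_crop_pair are hypothetical helpers, and the rounding of odd margins (extra pixel on the bottom/right) is a choice made here, not necessarily the one CenterCrop2D uses, which operates on the library's own sample format.

```python
import numpy as np


def center_crop_2d(array, size):
    """Crop the last two axes of array to size=(height, width) around the center."""
    target_h, target_w = size
    h, w = array.shape[-2:]
    if target_h > h or target_w > w:
        raise ValueError("crop size is larger than the image")
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return array[..., top:top + target_h, left:left + target_w]


def center_crop_pair(image, gt=None, size=(200, 200)):
    """Crop image and, when a ground truth is given, crop it at the same offsets."""
    cropped_image = center_crop_2d(image, size)
    if gt is None:
        return cropped_image, None
    if gt.shape[-2:] != image.shape[-2:]:
        raise ValueError("image and ground truth must share the same spatial size")
    return cropped_image, center_crop_2d(gt, size)


if __name__ == "__main__":
    image = np.arange(36).reshape(6, 6)
    mask = (image % 2).astype(np.uint8)
    image_c, mask_c = center_crop_pair(image, mask, size=(2, 2))
    print(image_c.tolist())  # [[14, 15], [20, 21]]
    print(mask_c.tolist())   # [[0, 1], [0, 1]]
```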
class medicaltorch.transforms.Normalize(mean, std)
Normalize a tensor image with mean and standard deviation.
Parameters:
- mean – mean value.
- std – standard deviation value.
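Normalize computes (x - mean) / std with fixed statistics, whereas the U-Net example pipeline uses NormalizeInstance. Assuming NormalizeInstance derives the statistics from each image itself (an assumption based on its name, not stated on this page), the difference can be sketched in NumPy. Both functions are hypothetical stand-ins, and the small eps guarding against a zero standard deviation is a choice made for this sketch.

```python
import numpy as np


def normalize(image, mean, std):
    """Fixed-statistics normalization: (x - mean) / std."""
    return (np.asarray(image, dtype=np.float64) - mean) / std


def normalize_instance(image, eps=1e-8):
    """Per-image normalization using the image's own mean and standard deviation."""
    image = np.asarray(image, dtype=np.float64)
    return (image - image.mean()) / (image.std() + eps)


if __name__ == "__main__":
    x = np.array([1.0, 2.0, 3.0])
    print(normalize(x, mean=2.0, std=0.5).tolist())  # [-2.0, 0.0, 2.0]
    z = normalize_instance(np.array([[1.0, 3.0], [5.0, 7.0]]))
    print(round(float(z.mean()), 6), round(float(z.std()), 6))
```

Per-instance normalization makes each slice roughly zero-mean and unit-variance regardless of scanner intensity ranges, at the cost of discarding absolute intensity information.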
medicaltorch.metrics – Metrics
medicaltorch.models – Models
class medicaltorch.models.NoPoolASPP(drop_rate=0.4, bn_momentum=0.1, base_num_filters=64)
An ASPP-based model without initial pooling layers.
Parameters:
- drop_rate – dropout rate.
- bn_momentum – batch normalization momentum.
See also
Perone, C. S., et al. (2017). Spinal cord gray matter segmentation using deep dilated convolutions. Scientific Reports. https://www.nature.com/articles/s41598-018-24304-3
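ASPP-style models rely on dilated (atrous) convolutions to enlarge the receptive field without pooling. For a kernel of size k and dilation rate d, the effective kernel size is k + (k - 1)(d - 1), and each stacked stride-1 layer adds (effective size - 1) to the receptive field. The sketch below only performs this arithmetic; the layer configurations in the demo are arbitrary examples, not the dilation rates used by NoPoolASPP.

```python
def effective_kernel_size(kernel_size, dilation):
    """Span covered by a dilated kernel: k + (k - 1) * (d - 1)."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)


def receptive_field(layers):
    """Receptive field of stacked stride-1 convolutions given (kernel, dilation) pairs."""
    rf = 1
    for kernel_size, dilation in layers:
        rf += effective_kernel_size(kernel_size, dilation) - 1
    return rf


if __name__ == "__main__":
    print(effective_kernel_size(3, 2))                # 5
    print(receptive_field([(3, 1), (3, 1), (3, 1)]))  # 7
    print(receptive_field([(3, 1), (3, 2), (3, 4)]))  # 15
```

With the same number of 3x3 layers and parameters, increasing the dilation rates roughly doubles the receptive field in this example while keeping full spatial resolution.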
class medicaltorch.models.Unet(drop_rate=0.4, bn_momentum=0.1)
A reference U-Net model.
See also
Ronneberger, O., et al. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. arXiv: https://arxiv.org/abs/1505.04597
medicaltorch.losses – Losses
Contributors
We are very thankful to all our contributors! For a complete list of contributors, please see the official Contributors list in the GitHub repository.
Contribute or Report a bug
MedicalTorch is an open-source project created and maintained by Christian S. Perone. You can contribute by donating, submitting a pull request, or reporting a bug. The source code is available on the GitHub project page.
License
Apache 2.0 License:
Copyright 2018 Christian S. Perone
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.