import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
class NoPoolASPP(Module):
    """
    .. image:: _static/img/nopool_aspp_arch.png
        :align: center
        :scale: 25%

    An ASPP-based model without initial pooling layers.

    Every convolution is padded so that it preserves the spatial size, so the
    prediction has the same height and width as the input. The first
    convolution expects a single input channel.

    :param drop_rate: dropout rate.
    :param bn_momentum: batch normalization momentum.
    :param base_num_filters: number of filters of each convolution. The
        concatenated ASPP features have ``6 * base_num_filters`` channels and
        the amortization layer outputs ``2 * base_num_filters`` channels.

    .. seealso::
        Perone, C. S., et al (2017). Spinal cord gray matter
        segmentation using deep dilated convolutions.
        Nature Scientific Reports link:
        https://www.nature.com/articles/s41598-018-24304-3
    """
    def __init__(self, drop_rate=0.4, bn_momentum=0.1,
                 base_num_filters=64):
        super().__init__()
        # Stem: two plain 3x3 convolutions, then two 3x3 convolutions with
        # dilation 2 (padding == dilation keeps the spatial size).
        self.conv1a = nn.Conv2d(1, base_num_filters, kernel_size=3, padding=1)
        self.conv1a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.conv1a_drop = nn.Dropout2d(drop_rate)
        self.conv1b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=1)
        self.conv1b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.conv1b_drop = nn.Dropout2d(drop_rate)
        self.conv2a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=2, dilation=2)
        self.conv2a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.conv2a_drop = nn.Dropout2d(drop_rate)
        self.conv2b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=2, dilation=2)
        self.conv2b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.conv2b_drop = nn.Dropout2d(drop_rate)

        # Branch 1: a 3x3 convolution followed by a 1x1 convolution
        self.branch1a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=1)
        self.branch1a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.branch1a_drop = nn.Dropout2d(drop_rate)
        self.branch1b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=1)
        self.branch1b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.branch1b_drop = nn.Dropout2d(drop_rate)

        # Branch 2: two 3x3 convolutions with dilation rate 6
        self.branch2a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=6, dilation=6)
        self.branch2a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.branch2a_drop = nn.Dropout2d(drop_rate)
        self.branch2b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=6, dilation=6)
        self.branch2b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.branch2b_drop = nn.Dropout2d(drop_rate)

        # Branch 3: two 3x3 convolutions with dilation rate 12
        self.branch3a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=12, dilation=12)
        self.branch3a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.branch3a_drop = nn.Dropout2d(drop_rate)
        self.branch3b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=12, dilation=12)
        self.branch3b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.branch3b_drop = nn.Dropout2d(drop_rate)

        # Branch 4: two 3x3 convolutions with dilation rate 18
        self.branch4a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=18, dilation=18)
        self.branch4a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.branch4a_drop = nn.Dropout2d(drop_rate)
        self.branch4b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=18, dilation=18)
        self.branch4b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.branch4b_drop = nn.Dropout2d(drop_rate)

        # Branch 5: two 3x3 convolutions with dilation rate 24
        self.branch5a = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=24, dilation=24)
        self.branch5a_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.branch5a_drop = nn.Dropout2d(drop_rate)
        self.branch5b = nn.Conv2d(base_num_filters, base_num_filters, kernel_size=3, padding=24, dilation=24)
        self.branch5b_bn = nn.BatchNorm2d(base_num_filters, momentum=bn_momentum)
        self.branch5b_drop = nn.Dropout2d(drop_rate)

        # Head: 5 branches + the global-pool feature = 6 * base_num_filters channels,
        # reduced by a 1x1 "amortization" conv, then a 1x1 prediction conv.
        self.concat_drop = nn.Dropout2d(drop_rate)
        self.concat_bn = nn.BatchNorm2d(6*base_num_filters, momentum=bn_momentum)
        self.amort = nn.Conv2d(6*base_num_filters, base_num_filters*2, kernel_size=1)
        self.amort_bn = nn.BatchNorm2d(base_num_filters*2, momentum=bn_momentum)
        self.amort_drop = nn.Dropout2d(drop_rate)
        self.prediction = nn.Conv2d(base_num_filters*2, 1, kernel_size=1)

    def forward(self, x):
        """Model forward pass.

        Each unit applies convolution -> ReLU -> batch norm -> dropout.

        :param x: input tensor of shape ``(N, 1, H, W)``.
        :return: sigmoid probabilities of shape ``(N, 1, H, W)``.
        """
        x = F.relu(self.conv1a(x))
        x = self.conv1a_bn(x)
        x = self.conv1a_drop(x)
        x = F.relu(self.conv1b(x))
        x = self.conv1b_bn(x)
        x = self.conv1b_drop(x)
        x = F.relu(self.conv2a(x))
        x = self.conv2a_bn(x)
        x = self.conv2a_drop(x)
        x = F.relu(self.conv2b(x))
        x = self.conv2b_bn(x)
        x = self.conv2b_drop(x)

        # Branch 1: 3x3 convolution followed by a 1x1 convolution
        branch1 = F.relu(self.branch1a(x))
        branch1 = self.branch1a_bn(branch1)
        branch1 = self.branch1a_drop(branch1)
        branch1 = F.relu(self.branch1b(branch1))
        branch1 = self.branch1b_bn(branch1)
        branch1 = self.branch1b_drop(branch1)

        # Branch 2: 3x3 rate 6
        branch2 = F.relu(self.branch2a(x))
        branch2 = self.branch2a_bn(branch2)
        branch2 = self.branch2a_drop(branch2)
        branch2 = F.relu(self.branch2b(branch2))
        branch2 = self.branch2b_bn(branch2)
        branch2 = self.branch2b_drop(branch2)

        # Branch 3: 3x3 rate 12
        branch3 = F.relu(self.branch3a(x))
        branch3 = self.branch3a_bn(branch3)
        branch3 = self.branch3a_drop(branch3)
        branch3 = F.relu(self.branch3b(branch3))
        branch3 = self.branch3b_bn(branch3)
        branch3 = self.branch3b_drop(branch3)

        # Branch 4: 3x3 rate 18
        branch4 = F.relu(self.branch4a(x))
        branch4 = self.branch4a_bn(branch4)
        branch4 = self.branch4a_drop(branch4)
        branch4 = F.relu(self.branch4b(branch4))
        branch4 = self.branch4b_bn(branch4)
        branch4 = self.branch4b_drop(branch4)

        # Branch 5: 3x3 rate 24
        branch5 = F.relu(self.branch5a(x))
        branch5 = self.branch5a_bn(branch5)
        branch5 = self.branch5a_drop(branch5)
        branch5 = F.relu(self.branch5b(branch5))
        branch5 = self.branch5b_bn(branch5)
        branch5 = self.branch5b_drop(branch5)

        # Global average pooling: averaging over the whole spatial extent
        # yields an (N, C, 1, 1) image-level feature; expand() broadcasts it
        # back to (N, C, H, W) so it can be concatenated with the branches.
        global_pool = F.avg_pool2d(x, kernel_size=x.size()[2:])
        global_pool = global_pool.expand(x.size())

        concatenation = torch.cat([branch1,
                                   branch2,
                                   branch3,
                                   branch4,
                                   branch5,
                                   global_pool], dim=1)
        concatenation = self.concat_bn(concatenation)
        concatenation = self.concat_drop(concatenation)

        amort = F.relu(self.amort(concatenation))
        amort = self.amort_bn(amort)
        amort = self.amort_drop(amort)

        predictions = self.prediction(amort)
        predictions = torch.sigmoid(predictions)
        return predictions
class DownConv(Module):
    """Two 3x3 convolution units, each conv -> ReLU -> batch norm -> dropout.

    Padding of 1 keeps the spatial size unchanged; only the number of
    channels changes from ``in_feat`` to ``out_feat``.

    :param in_feat: number of input channels.
    :param out_feat: number of output channels.
    :param drop_rate: dropout rate.
    :param bn_momentum: batch normalization momentum.
    """
    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_feat, out_feat, kernel_size=3, padding=1)
        self.conv1_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
        self.conv1_drop = nn.Dropout2d(drop_rate)
        self.conv2 = nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1)
        self.conv2_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
        self.conv2_drop = nn.Dropout2d(drop_rate)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv1_bn(x)
        x = self.conv1_drop(x)
        x = F.relu(self.conv2(x))
        x = self.conv2_bn(x)
        x = self.conv2_drop(x)
        return x


class UpConv(Module):
    """Upsample ``x`` to the size of a skip tensor ``y``, concatenate, then :class:`DownConv`.

    :param in_feat: channels after concatenation, i.e. channels of ``x`` plus
        channels of ``y``.
    :param out_feat: number of output channels.
    :param drop_rate: dropout rate.
    :param bn_momentum: batch normalization momentum.
    """
    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
        super().__init__()
        # A plain function stored as an attribute (not a submodule, so it has
        # no parameters and does not appear in the state dict).
        self.up1 = nn.functional.interpolate
        self.downconv = DownConv(in_feat, out_feat, drop_rate, bn_momentum)

    def forward(self, x, y):
        """Upsample ``x``, concatenate it with ``y`` along channels, convolve.

        :param x: lower-resolution feature map ``(N, Cx, h, w)``.
        :param y: skip-connection feature map ``(N, Cy, H, W)``.
        :return: tensor of shape ``(N, out_feat, H, W)``.
        """
        # Resize to the skip connection's exact spatial size rather than a
        # fixed scale_factor=2. For even sizes the result is identical (with
        # align_corners=True the sampling grid depends only on the input and
        # output sizes). Odd-sized skips, produced when MaxPool2d floors, no
        # longer make torch.cat fail.
        x = self.up1(x, size=y.shape[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x, y], dim=1)
        x = self.downconv(x)
        return x


class Unet(Module):
    """A reference U-Net model.

    Three max-pooling steps down, a bottom :class:`DownConv`, and three
    :class:`UpConv` steps back up with skip connections. The input has a
    single channel; the output is a single-channel sigmoid map with the same
    spatial size as the input.

    :param drop_rate: dropout rate.
    :param bn_momentum: batch normalization momentum.

    .. seealso::
        Ronneberger, O., et al (2015). U-Net: Convolutional
        Networks for Biomedical Image Segmentation
        ArXiv link: https://arxiv.org/abs/1505.04597
    """
    def __init__(self, drop_rate=0.4, bn_momentum=0.1):
        super().__init__()
        # Downsampling path
        self.conv1 = DownConv(1, 64, drop_rate, bn_momentum)
        self.mp1 = nn.MaxPool2d(2)
        self.conv2 = DownConv(64, 128, drop_rate, bn_momentum)
        self.mp2 = nn.MaxPool2d(2)
        self.conv3 = DownConv(128, 256, drop_rate, bn_momentum)
        self.mp3 = nn.MaxPool2d(2)
        # Bottom
        self.conv4 = DownConv(256, 256, drop_rate, bn_momentum)
        # Upsampling path: input channels are (upsampled + skip) channels,
        # 512 = 256 + 256, 384 = 256 + 128, 192 = 128 + 64.
        self.up1 = UpConv(512, 256, drop_rate, bn_momentum)
        self.up2 = UpConv(384, 128, drop_rate, bn_momentum)
        self.up3 = UpConv(192, 64, drop_rate, bn_momentum)
        self.conv9 = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Model forward pass.

        :param x: input tensor of shape ``(N, 1, H, W)``.
        :return: sigmoid probabilities of shape ``(N, 1, H, W)``.
        """
        x1 = self.conv1(x)
        x2 = self.mp1(x1)
        x3 = self.conv2(x2)
        x4 = self.mp2(x3)
        x5 = self.conv3(x4)
        x6 = self.mp3(x5)
        # Bottom
        x7 = self.conv4(x6)
        # Up-sampling, each step joined with the matching encoder output
        x8 = self.up1(x7, x5)
        x9 = self.up2(x8, x3)
        x10 = self.up3(x9, x1)
        x11 = self.conv9(x10)
        preds = torch.sigmoid(x11)
        return preds
class UNet3D(nn.Module):
    """A reference of 3D U-Net model.

    Implementation origin :
    https://github.com/shiba24/3d-unet/blob/master/pytorch/model.py

    The encoder applies three ``MaxPool3d(2)`` steps. The decoder applies three
    stride-2 transposed convolutions, and each output is concatenated with the
    matching encoder feature map. The depth, height and width of the input
    must therefore be divisible by 8, or the concatenations fail. The last layer
    (``dc0``) is built with :meth:`down_conv`, so the output also passes
    through a LeakyReLU.

    :param in_channel: number of channels of the input volume.
    :param n_classes: number of output channels of the last layer.

    .. seealso::
        Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox
        and Olaf Ronneberger (2016). 3D U-Net: Learning Dense Volumetric
        Segmentation from Sparse Annotation
        ArXiv link: https://arxiv.org/pdf/1606.06650.pdf
    """
    def __init__(self, in_channel, n_classes):
        # Initialise nn.Module before assigning any attributes.
        super().__init__()
        self.in_channel = in_channel
        self.n_classes = n_classes

        # Encoder
        self.ec0 = self.down_conv(self.in_channel, 32, bias=False, batchnorm=False)
        self.ec1 = self.down_conv(32, 64, bias=False, batchnorm=False)
        self.ec2 = self.down_conv(64, 64, bias=False, batchnorm=False)
        self.ec3 = self.down_conv(64, 128, bias=False, batchnorm=False)
        self.ec4 = self.down_conv(128, 128, bias=False, batchnorm=False)
        self.ec5 = self.down_conv(128, 256, bias=False, batchnorm=False)
        self.ec6 = self.down_conv(256, 256, bias=False, batchnorm=False)
        self.ec7 = self.down_conv(256, 512, bias=False, batchnorm=False)

        self.pool0 = nn.MaxPool3d(2)
        self.pool1 = nn.MaxPool3d(2)
        self.pool2 = nn.MaxPool3d(2)

        # Decoder: in_channels of dc8/dc5/dc2 are (upsampled + skip) channels.
        self.dc9 = self.up_conv(512, 512, kernel_size=2, stride=2, bias=False)
        self.dc8 = self.down_conv(256 + 512, 256, bias=False)
        self.dc7 = self.down_conv(256, 256, bias=False)
        self.dc6 = self.up_conv(256, 256, kernel_size=2, stride=2, bias=False)
        self.dc5 = self.down_conv(128 + 256, 128, bias=False)
        self.dc4 = self.down_conv(128, 128, bias=False)
        self.dc3 = self.up_conv(128, 128, kernel_size=2, stride=2, bias=False)
        self.dc2 = self.down_conv(64 + 128, 64, bias=False)
        self.dc1 = self.down_conv(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.dc0 = self.down_conv(64, n_classes, kernel_size=1, stride=1, padding=0, bias=False)

    def down_conv(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                  bias=True, batchnorm=False):
        """Build ``Conv3d`` [-> ``BatchNorm3d``] -> ``LeakyReLU``.

        :param in_channels: input channels of the convolution.
        :param out_channels: output channels of the convolution.
        :param kernel_size: convolution kernel size.
        :param stride: convolution stride.
        :param padding: convolution padding.
        :param bias: whether the convolution has a bias term.
        :param batchnorm: insert a batch-norm layer after the convolution.
        :return: an ``nn.Sequential`` of the layers above.
        """
        conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, bias=bias)
        if batchnorm:
            # Conv3d outputs 5-D tensors, which BatchNorm2d rejects; the
            # volumetric variant is required here.
            return nn.Sequential(conv, nn.BatchNorm3d(out_channels), nn.LeakyReLU())
        return nn.Sequential(conv, nn.LeakyReLU())

    def up_conv(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0,
                output_padding=0, bias=True):
        """Build ``ConvTranspose3d`` -> ``LeakyReLU``.

        With the defaults (kernel 2, stride 2, no padding) each spatial
        dimension is doubled.

        :return: an ``nn.Sequential`` of the two layers.
        """
        layer = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                               padding=padding, output_padding=output_padding, bias=bias),
            nn.LeakyReLU())
        return layer

    def forward(self, x):
        """Model forward pass.

        :param x: input tensor of shape ``(N, in_channel, D, H, W)`` with
            ``D``, ``H`` and ``W`` divisible by 8.
        :return: tensor of shape ``(N, n_classes, D, H, W)``.
        """
        # Intermediate tensors are deleted as soon as they are no longer
        # needed; this may lower peak memory when autograd is not keeping
        # them alive (e.g. under torch.no_grad()).
        e0 = self.ec0(x)
        syn0 = self.ec1(e0)
        e1 = self.pool0(syn0)
        e2 = self.ec2(e1)
        syn1 = self.ec3(e2)
        del e0, e1, e2

        e3 = self.pool1(syn1)
        e4 = self.ec4(e3)
        syn2 = self.ec5(e4)
        del e3, e4

        e5 = self.pool2(syn2)
        e6 = self.ec6(e5)
        e7 = self.ec7(e6)
        del e5, e6

        d9 = torch.cat((self.dc9(e7), syn2), dim=1)
        del e7, syn2

        d8 = self.dc8(d9)
        d7 = self.dc7(d8)
        del d9, d8

        d6 = torch.cat((self.dc6(d7), syn1), dim=1)
        del d7, syn1

        d5 = self.dc5(d6)
        d4 = self.dc4(d5)
        del d6, d5

        d3 = torch.cat((self.dc3(d4), syn0), dim=1)
        del d4, syn0

        d2 = self.dc2(d3)
        d1 = self.dc1(d2)
        del d3, d2

        d0 = self.dc0(d1)
        return d0