import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
class NoPoolASPP(Module):
    """
    .. image:: _static/img/nopool_aspp_arch.png
        :align: center
        :scale: 25%

    An ASPP-based model without initial pooling layers.

    The input goes through four 3x3 convolution stages (dilation 1, 1, 2, 2),
    then through five parallel branches (a 3x3 + 1x1 branch and four 3x3
    branches with dilation rates 6, 12, 18 and 24) plus a global average
    pooling branch. The six feature maps are concatenated, reduced by a 1x1
    "amortization" convolution and mapped to a single-channel sigmoid output.
    Every convolution is padded so the spatial size of the input is preserved.

    Every stage is ``conv -> ReLU -> BatchNorm -> Dropout2d`` (ReLU is applied
    before batch normalization).

    :param drop_rate: dropout rate.
    :param bn_momentum: batch normalization momentum.
    :param base_num_filters: number of filters of every convolution except the
        amortization layer (``2 * base_num_filters``) and the prediction layer
        (1 output channel).

    .. seealso::
        Perone, C. S., et al (2018). Spinal cord gray matter
        segmentation using deep dilated convolutions.
        Scientific Reports link:
        https://www.nature.com/articles/s41598-018-24304-3
    """

    def __init__(self, drop_rate=0.4, bn_momentum=0.1,
                 base_num_filters=64):
        super().__init__()
        nf = base_num_filters
        # Feature extractor: single-channel input, two plain 3x3 convs...
        self.conv1a = nn.Conv2d(1, nf, kernel_size=3, padding=1)
        self.conv1a_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.conv1a_drop = nn.Dropout2d(drop_rate)
        self.conv1b = nn.Conv2d(nf, nf, kernel_size=3, padding=1)
        self.conv1b_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.conv1b_drop = nn.Dropout2d(drop_rate)
        # ...followed by two dilation-2 3x3 convs (padding == dilation keeps H, W).
        self.conv2a = nn.Conv2d(nf, nf, kernel_size=3, padding=2, dilation=2)
        self.conv2a_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.conv2a_drop = nn.Dropout2d(drop_rate)
        self.conv2b = nn.Conv2d(nf, nf, kernel_size=3, padding=2, dilation=2)
        self.conv2b_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.conv2b_drop = nn.Dropout2d(drop_rate)
        # Branch 1: 3x3 convolution followed by a 1x1 convolution
        self.branch1a = nn.Conv2d(nf, nf, kernel_size=3, padding=1)
        self.branch1a_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.branch1a_drop = nn.Dropout2d(drop_rate)
        self.branch1b = nn.Conv2d(nf, nf, kernel_size=1)
        self.branch1b_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.branch1b_drop = nn.Dropout2d(drop_rate)
        # Branch 2: 3x3, dilation rate 6
        self.branch2a = nn.Conv2d(nf, nf, kernel_size=3, padding=6, dilation=6)
        self.branch2a_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.branch2a_drop = nn.Dropout2d(drop_rate)
        self.branch2b = nn.Conv2d(nf, nf, kernel_size=3, padding=6, dilation=6)
        self.branch2b_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.branch2b_drop = nn.Dropout2d(drop_rate)
        # Branch 3: 3x3, dilation rate 12
        self.branch3a = nn.Conv2d(nf, nf, kernel_size=3, padding=12, dilation=12)
        self.branch3a_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.branch3a_drop = nn.Dropout2d(drop_rate)
        self.branch3b = nn.Conv2d(nf, nf, kernel_size=3, padding=12, dilation=12)
        self.branch3b_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.branch3b_drop = nn.Dropout2d(drop_rate)
        # Branch 4: 3x3, dilation rate 18
        self.branch4a = nn.Conv2d(nf, nf, kernel_size=3, padding=18, dilation=18)
        self.branch4a_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.branch4a_drop = nn.Dropout2d(drop_rate)
        self.branch4b = nn.Conv2d(nf, nf, kernel_size=3, padding=18, dilation=18)
        self.branch4b_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.branch4b_drop = nn.Dropout2d(drop_rate)
        # Branch 5: 3x3, dilation rate 24
        self.branch5a = nn.Conv2d(nf, nf, kernel_size=3, padding=24, dilation=24)
        self.branch5a_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.branch5a_drop = nn.Dropout2d(drop_rate)
        self.branch5b = nn.Conv2d(nf, nf, kernel_size=3, padding=24, dilation=24)
        self.branch5b_bn = nn.BatchNorm2d(nf, momentum=bn_momentum)
        self.branch5b_drop = nn.Dropout2d(drop_rate)
        # Concatenation of the five branches + the global-pool map: 6 * nf channels.
        self.concat_drop = nn.Dropout2d(drop_rate)
        self.concat_bn = nn.BatchNorm2d(6 * nf, momentum=bn_momentum)
        # 1x1 "amortization" conv reducing 6 * nf -> 2 * nf channels.
        self.amort = nn.Conv2d(6 * nf, nf * 2, kernel_size=1)
        self.amort_bn = nn.BatchNorm2d(nf * 2, momentum=bn_momentum)
        self.amort_drop = nn.Dropout2d(drop_rate)
        self.prediction = nn.Conv2d(nf * 2, 1, kernel_size=1)

    def _conv_relu_bn_drop(self, x, name):
        """Apply the ``conv -> ReLU -> BatchNorm -> Dropout2d`` stage *name*.

        The stage's modules are the attributes ``name``, ``name + '_bn'`` and
        ``name + '_drop'`` registered in ``__init__``.
        """
        conv = getattr(self, name)
        bn = getattr(self, name + '_bn')
        drop = getattr(self, name + '_drop')
        return drop(bn(F.relu(conv(x))))

    def forward(self, x):
        """Model forward pass.

        :param x: input batch; ``conv1a`` takes a single input channel, so
            presumably shaped ``(N, 1, H, W)``.
        :return: sigmoid probabilities shaped ``(N, 1, H, W)``.
        """
        # Feature extractor: two plain and two dilation-2 3x3 stages.
        for name in ('conv1a', 'conv1b', 'conv2a', 'conv2b'):
            x = self._conv_relu_bn_drop(x, name)

        # Five parallel ASPP branches, each two stages ('a' then 'b'). They
        # are evaluated in order 1..5 so dropout RNG consumption matches a
        # fully unrolled implementation.
        branches = []
        for idx in range(1, 6):
            out = self._conv_relu_bn_drop(x, f'branch{idx}a')
            out = self._conv_relu_bn_drop(out, f'branch{idx}b')
            branches.append(out)

        # Image-level features: average over H and W, then broadcast back to
        # x's full shape so it can be concatenated along channels.
        global_pool = F.avg_pool2d(x, kernel_size=x.size()[2:])
        branches.append(global_pool.expand(x.size()))

        concatenation = torch.cat(branches, dim=1)
        concatenation = self.concat_bn(concatenation)
        concatenation = self.concat_drop(concatenation)

        amort = self._conv_relu_bn_drop(concatenation, 'amort')
        # torch.sigmoid: same result as the deprecated F.sigmoid.
        return torch.sigmoid(self.prediction(amort))
class DownConv(Module):
    """Two stacked ``3x3 conv -> ReLU -> BatchNorm -> Dropout2d`` stages.

    Both convolutions use ``padding=1`` so the spatial size is unchanged;
    this block itself performs no downsampling.

    :param in_feat: number of input channels.
    :param out_feat: number of output channels of both convolutions.
    :param drop_rate: dropout rate.
    :param bn_momentum: batch normalization momentum.
    """

    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
        super().__init__()
        # Stage 1: in_feat -> out_feat.
        self.conv1 = nn.Conv2d(in_feat, out_feat, kernel_size=3, padding=1)
        self.conv1_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
        self.conv1_drop = nn.Dropout2d(drop_rate)
        # Stage 2: out_feat -> out_feat.
        self.conv2 = nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1)
        self.conv2_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
        self.conv2_drop = nn.Dropout2d(drop_rate)

    def forward(self, x):
        """Run both stages in order and return the resulting feature map."""
        stages = (
            (self.conv1, self.conv1_bn, self.conv1_drop),
            (self.conv2, self.conv2_bn, self.conv2_drop),
        )
        for conv, norm, dropout in stages:
            x = dropout(norm(F.relu(conv(x))))
        return x
class UpConv(Module):
    """Upsample, concatenate with a skip connection, then apply ``DownConv``.

    :param in_feat: channels after concatenation, i.e. channels of the
        upsampled input ``x`` plus channels of the skip tensor ``y``.
    :param out_feat: number of output channels.
    :param drop_rate: dropout rate.
    :param bn_momentum: batch normalization momentum.
    """

    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
        super().__init__()
        # align_corners=False is PyTorch's default for bilinear mode; passing it
        # explicitly keeps results identical and avoids the warning some
        # PyTorch versions emit when it is left unspecified.
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.downconv = DownConv(in_feat, out_feat, drop_rate, bn_momentum)

    def forward(self, x, y):
        """Upsample ``x`` to ``y``'s spatial size, concat on channels, convolve.

        :param x: coarse feature map to upsample.
        :param y: skip-connection feature map; its spatial size is the target.
        :return: output of ``self.downconv`` on the concatenation.
        """
        target_size = tuple(y.shape[2:])
        if target_size == tuple(2 * s for s in x.shape[2:]):
            x = self.up1(x)
        else:
            # Odd sizes on the way down are floored by max pooling, so a plain
            # 2x upsample would not match the skip tensor and torch.cat would
            # fail; resize directly to the skip tensor's size instead.
            x = F.interpolate(x, size=target_size, mode='bilinear',
                              align_corners=False)
        x = torch.cat([x, y], dim=1)
        return self.downconv(x)
class Unet(Module):
    """A reference U-Net model.

    Three ``DownConv`` levels (64, 128, 256 channels), each followed by 2x2
    max pooling, a 256-channel bottom ``DownConv``, three ``UpConv`` levels
    with skip connections, and a final 3x3 convolution to one channel
    followed by a sigmoid.

    :param drop_rate: dropout rate used by every ``DownConv``/``UpConv``.
    :param bn_momentum: batch normalization momentum.

    .. seealso::
        Ronneberger, O., et al (2015). U-Net: Convolutional
        Networks for Biomedical Image Segmentation
        ArXiv link: https://arxiv.org/abs/1505.04597
    """

    def __init__(self, drop_rate=0.4, bn_momentum=0.1):
        super().__init__()
        # Downsampling path (single-channel input)
        self.conv1 = DownConv(1, 64, drop_rate, bn_momentum)
        self.mp1 = nn.MaxPool2d(2)
        self.conv2 = DownConv(64, 128, drop_rate, bn_momentum)
        self.mp2 = nn.MaxPool2d(2)
        self.conv3 = DownConv(128, 256, drop_rate, bn_momentum)
        self.mp3 = nn.MaxPool2d(2)
        # Bottom
        self.conv4 = DownConv(256, 256, drop_rate, bn_momentum)
        # Upsampling path: in_feat = upsampled channels + skip channels
        self.up1 = UpConv(512, 256, drop_rate, bn_momentum)  # 256 + 256
        self.up2 = UpConv(384, 128, drop_rate, bn_momentum)  # 256 + 128
        self.up3 = UpConv(192, 64, drop_rate, bn_momentum)   # 128 + 64
        self.conv9 = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Model forward pass.

        :param x: input batch; ``conv1`` takes a single input channel, so
            presumably shaped ``(N, 1, H, W)``. H and W go through three 2x2
            max poolings, so they must be large enough (and, depending on
            ``UpConv``'s upsampling, divisible enough) for the skip
            connections to line up.
        :return: sigmoid probabilities with one output channel.
        """
        x1 = self.conv1(x)
        x2 = self.mp1(x1)
        x3 = self.conv2(x2)
        x4 = self.mp2(x3)
        x5 = self.conv3(x4)
        x6 = self.mp3(x5)
        # Bottom
        x7 = self.conv4(x6)
        # Up-sampling with skip connections from the matching encoder level
        x8 = self.up1(x7, x5)
        x9 = self.up2(x8, x3)
        x10 = self.up3(x9, x1)
        x11 = self.conv9(x10)
        # torch.sigmoid: same result as the deprecated F.sigmoid.
        return torch.sigmoid(x11)