Skip to content
Snippets Groups Projects
unet.py 5.55 KiB
from typing import List, Optional, Sequence

import torch
import torch.nn as nn


class TwoConv(nn.Sequential):
    """
    Sequential stack of two 3x3x3 convolutions, each followed by a ReLU.

    The children are, in order:
    ``Conv3d -> [BatchNorm3d] -> ReLU -> [Dropout3d] -> Conv3d -> [BatchNorm3d] -> ReLU``
    where the bracketed layers are only present when requested.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        batch_norm: bool = True,
        dropout: Optional[float] = None,
    ):
        """
        Build the two-convolution stack.

        :param in_channels: channel count of the tensor fed into the first convolution
        :param out_channels: number of feature maps produced by each of the two convolutions
        :param batch_norm: insert a BatchNorm3d right after every convolution
        :param dropout: if not None, probability of a Dropout3d placed between the
            first ReLU and the second convolution
        """
        modules = []
        # Stage 0 maps in_channels -> out_channels, stage 1 maps out_channels -> out_channels.
        for stage, stage_in_channels in enumerate((in_channels, out_channels)):
            if stage == 1 and dropout is not None:
                modules.append(nn.Dropout3d(p=dropout))
            modules.append(
                nn.Conv3d(
                    in_channels=stage_in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=1,
                    padding="same",
                )
            )
            if batch_norm:
                modules.append(nn.BatchNorm3d(num_features=out_channels))
            modules.append(nn.ReLU(inplace=True))

        # Positional children are registered as "0", "1", ... exactly like successive append() calls.
        super().__init__(*modules)


class UNet(nn.Module):
    """
    Create a UNet for three dimensional inputs as used in the paper by Shi et al. (2020).
    Differences to the default UNet are:
      * the usage of padding,
      * batch normalization is applied after each convolution, and
      * dropout is applied to the bottleneck layer.

    The network has ``len(features) - 1`` encoder/decoder levels. Each encoder level
    halves every spatial dimension with a 2x2x2 max pooling, so every spatial
    dimension of the input must be divisible by ``2 ** (len(features) - 1)`` for the
    skip connections to line up with the up-sampled tensors.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        features: Sequence[int] = (64, 128, 256, 512),
        batch_norm: bool = True,
        dropout: Optional[float] = 0.15,
    ):
        """
        Initialize the UNet.

        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param features: number of features computed by the convolutions of each layer;
            the last entry is the bottleneck, all previous entries are encoder/decoder levels
        :param batch_norm: if batch normalization should be added after each convolution
            (including the bottleneck)
        :param dropout: dropout probability used for dropout at the bottleneck layer
        :raises ValueError: if fewer than two feature sizes are given
        """
        super().__init__()

        if len(features) < 2:
            raise ValueError(
                "features needs at least two entries (one encoder level and the "
                f"bottleneck), got {len(features)}"
            )

        # Copy so that neither the caller's sequence nor a shared default is aliased.
        self.features = list(features)
        features = self.features

        # Indices of the encoder/decoder levels (the bottleneck is features[-1]).
        self.layers = list(range(len(features) - 1))
        for i in self.layers:
            level_in = features[i - 1] if i > 0 else in_channels

            self.add_module(
                f"down_{i + 1}_conv",
                TwoConv(
                    in_channels=level_in, out_channels=features[i], batch_norm=batch_norm
                ),
            )
            self.add_module(f"down_{i + 1}_pool", nn.MaxPool3d(kernel_size=2, stride=2))

        self.add_module(
            "bottleneck",
            TwoConv(
                in_channels=features[-2],
                out_channels=features[-1],
                batch_norm=batch_norm,
                dropout=dropout,
            ),
        )

        for i in reversed(self.layers):
            self.add_module(
                f"up_{i + 1}_up",
                nn.ConvTranspose3d(
                    in_channels=features[i + 1],
                    out_channels=features[i],
                    kernel_size=2,
                    stride=2,
                ),
            )
            # forward() concatenates the up-sampled tensor (features[i] channels) with the
            # skip connection of the same level (features[i] channels).
            self.add_module(
                f"up_{i + 1}_conv",
                TwoConv(
                    in_channels=2 * features[i],
                    out_channels=features[i],
                    batch_norm=batch_norm,
                ),
            )

        self.add_module(
            "out_conv",
            nn.Conv3d(
                in_channels=features[0],
                out_channels=out_channels,
                stride=1,
                kernel_size=1,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the UNet to a batched volume.

        :param x: tensor of shape (batch, in_channels, D, H, W); each spatial dimension
            must be divisible by ``2 ** (len(features) - 1)``
        :return: tensor of shape (batch, out_channels, D, H, W)
        """
        skips = []
        for level in range(1, len(self.features)):
            x = self.get_submodule(f"down_{level}_conv")(x)
            skips.append(x)
            x = self.get_submodule(f"down_{level}_pool")(x)

        x = self.get_submodule("bottleneck")(x)

        for level in range(len(self.features) - 1, 0, -1):
            x = self.get_submodule(f"up_{level}_up")(x)
            # Concatenate along the channel axis with the matching encoder output.
            x = torch.cat((x, skips[level - 1]), dim=1)
            x = self.get_submodule(f"up_{level}_conv")(x)

        return self.get_submodule("out_conv")(x)


if __name__ == "__main__":
    import time

    net = UNet(features=[64, 128, 256, 512])
    print(net)

    _inputs = torch.rand((1, 1, 64, 128, 128))
    _outputs = net(_inputs)

    print(f"Transform {_inputs.shape} to {_outputs.shape}")

    # Fall back to the CPU so the benchmark also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    iterations = 100
    batch_sizes = [128]
    input_shape = (1, 32, 32, 32)

    def _synchronize():
        # CUDA kernels are launched asynchronously; wait for them so that the
        # measured wall-clock time covers the actual computation.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    for batch_size in batch_sizes:
        # Untimed warm-up pass so one-off setup costs are not part of the measurement.
        net(torch.rand((batch_size, *input_shape)).to(device))
        _synchronize()

        since = time.perf_counter()
        for i in range(iterations):
            print(f"batch size {batch_size} - {i + 1:>3}/{iterations}", end="\r")
            _inputs = torch.rand((batch_size, *input_shape)).to(device)
            _outputs = net(_inputs)
        _synchronize()
        _took = (time.perf_counter() - since) / iterations
        print(f"Batches of size {batch_size} take {_took:.3f}s on average")