Skip to content
Snippets Groups Projects
Commit 3e0a078e authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

revice documentation of model module

parent 13a7f37e
No related branches found
No related tags found
No related merge requests found
"""
Module containing neural networks which can be used as a discriminator in cGAN training.
"""
from functools import reduce from functools import reduce
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
...@@ -7,7 +10,7 @@ import torch.nn as nn ...@@ -7,7 +10,7 @@ import torch.nn as nn
class Conv(nn.Sequential): class Conv(nn.Sequential):
""" """
A wrapper around a 3D convolutional layer that also contains batch normalization, a ReLU activation function and a max pooling. A wrapper around a 3D convolutional layer that also contains batch normalization, a ReLU activation function and max pooling.
""" """
def __init__( def __init__(
...@@ -85,7 +88,9 @@ class Discriminator(nn.Module): ...@@ -85,7 +88,9 @@ class Discriminator(nn.Module):
self.conv = nn.Sequential(*conv_layers) self.conv = nn.Sequential(*conv_layers)
# input is halved by each convolutional layer # input is halved by each convolutional layer
self.fc_input_size = map(lambda x: x // 2 ** (len(conv_features)), self.input_size) self.fc_input_size = map(
lambda x: x // 2 ** (len(conv_features)), self.input_size
)
self.fc_input_size = reduce(lambda x, y: x * y, self.fc_input_size) self.fc_input_size = reduce(lambda x, y: x * y, self.fc_input_size)
fc_layers = [] fc_layers = []
......
"""
Module containing a 3D U-Net.
"""
import argparse import argparse
from typing import Optional, List from typing import Optional, List
...@@ -22,10 +25,16 @@ class TwoConv(nn.Sequential): ...@@ -22,10 +25,16 @@ class TwoConv(nn.Sequential):
""" """
Create a sequential module consisting of two convolutions with ReLU activations. Create a sequential module consisting of two convolutions with ReLU activations.
:param in_channels: the number of channels the first convolution has to deal with Parameters
:param out_channels: the number of features computed by both convolutions ----------
:param batch_norm: if batch normalization should be applied after each convolution in_channels: int
:param dropout: optional dropout probability used for a dropout layer between both convolutions the number of channels the first convolution has to deal with
out_channels: int
the number of features computed by both convolutions
batch_norm: bool, optional
if batch normalization should be applied after each convolution
dropout: float, optional
dropout probability used for a dropout layer between both convolutions
""" """
super().__init__() super().__init__()
...@@ -61,8 +70,8 @@ class TwoConv(nn.Sequential): ...@@ -61,8 +70,8 @@ class TwoConv(nn.Sequential):
class UNet(nn.Module): class UNet(nn.Module):
""" """
Create a UNet for three dimensional inputs as used in the paper by Shi et al. (2020). Create a 3D U-Net as used by Shi et al. (2020).
Differences to the default UNet are: Differences to the default U-Net are:
* the usage of padding, * the usage of padding,
* batch normalization is applied after each convolution, and * batch normalization is applied after each convolution, and
* dropout is applied to the bottleneck layer. * dropout is applied to the bottleneck layer.
...@@ -74,16 +83,22 @@ class UNet(nn.Module): ...@@ -74,16 +83,22 @@ class UNet(nn.Module):
out_channels: int = 1, out_channels: int = 1,
features: List[int] = [64, 128, 256, 512], features: List[int] = [64, 128, 256, 512],
batch_norm: bool = True, batch_norm: bool = True,
dropout: Optional[float] = 0.15, dropout: float = 0.15,
): ):
""" """
Initialize the UNet. Initialize the UNet.
:param in_channels: number of input channels Parameters
:param out_channels: number of output channels ----------
:param features: number of features computed by the convolutions of each layer in_channels: int
:param batch_norm: if batch normalization should be added after each convolution number of input channels
:param dropout: dropout probability used for dropout at the bottleneck layer out_channels: int
number of output channels
features: list of int, optional
number of features computed by the convolutions of each layer
batch_norm: bool, optional
if batch normalization should be added after each convolution
dropout: float, optional, dropout probability used for dropout at the bottleneck layer
""" """
super().__init__() super().__init__()
...@@ -155,6 +170,18 @@ class UNet(nn.Module): ...@@ -155,6 +170,18 @@ class UNet(nn.Module):
@classmethod @classmethod
def add_arguments(cls, parser: argparse.ArgumentParser, prefix: str = ""): def add_arguments(cls, parser: argparse.ArgumentParser, prefix: str = ""):
"""
Add arguments to an argument parser to create a U-Net from command
line arguments.
Parameters
----------
parser: argparse.ArgumentParser
the parser to which the arguments are added
prefix: str, optional
prefix for the added arguments, which is useful if there are naming
conflicts
"""
prefix = f"{prefix}_" if prefix != "" else "" prefix = f"{prefix}_" if prefix != "" else ""
parser.add_argument( parser.add_argument(
f"--{prefix}in_channels", f"--{prefix}in_channels",
...@@ -185,7 +212,18 @@ class UNet(nn.Module): ...@@ -185,7 +212,18 @@ class UNet(nn.Module):
) )
@classmethod @classmethod
def from_args(cls, args, prefix: str = ""): def from_args(cls, args: argparse.Namespace, prefix: str = ""):
"""
Create a U-Net from command line arguments added by the `add_arguments`
method.
Parameters
----------
args: argparse.Namespace
the command line arguments
prefix: str, optional
needs to be the same as given to the `add_arguments` method.
"""
prefix = f"{prefix}_" if prefix != "" else "" prefix = f"{prefix}_" if prefix != "" else ""
_args = vars(args) _args = vars(args)
return cls( return cls(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment