Newer
Older

Tamino Huxohl
committed

Tamino Huxohl
committed
"""
A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function.
"""
def __init__(self, in_channels: int, out_channels: int):
    """
    Create a convolutional layer with batch normalization and a ReLU activation function.

    :param in_channels: number of channels received as input
    :param out_channels: number of filters and consequently channels in the output
    """
    # The parent initializer has to run before any sub-module is registered.
    # NOTE(review): assumes the enclosing class derives from a torch module
    # container that provides `append` (e.g. nn.Sequential) - its header is
    # not visible here, confirm.
    super().__init__()

    # kernel size 3 with stride 1 and "same" padding
    self.append(
        nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding="same",
        )
    )
    self.append(nn.BatchNorm3d(num_features=out_channels))
    self.append(nn.ReLU(inplace=True))

Tamino Huxohl
committed

Tamino Huxohl
committed
"""
Create the discriminator as specified by Shi et al. (2020).
It consists of three convolutional layers with max pooling, followed by three fully connected layers.
"""
def __init__(self, in_channels: int = 1, input_size: int = 16):
    """
    Create the discriminator.

    :param in_channels: number of channels received as input
    :param input_size: size of the input images, which is required to compute
                       the number of features in the first fully connected
                       layer (the computation below uses the same size for all
                       three spatial dimensions)
    """
    # must run before sub-modules are assigned as attributes
    super().__init__()

    # the input is halved three times (// 2 ** 3) and we deal with 3D inputs (** 3)
    fc_input_size = (input_size // 2 ** 3) ** 3

    self.conv = nn.Sequential(
        Conv(in_channels=in_channels, out_channels=32),
        nn.MaxPool3d(kernel_size=2, stride=2),
        Conv(in_channels=32, out_channels=64),
        nn.MaxPool3d(kernel_size=2, stride=2),
        Conv(in_channels=64, out_channels=128),
        nn.MaxPool3d(kernel_size=2, stride=2),
    )
    # 128 channels leave the last convolution, each with fc_input_size values
    self.fully_connected = nn.Sequential(
        nn.Linear(in_features=128 * fc_input_size, out_features=512),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=512, out_features=128),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=128, out_features=1),
    )

Tamino Huxohl
committed
def forward(self, x: torch.Tensor):
    """
    Run the convolutional stage, flatten its result and feed it to the fully connected stage.

    :param x: input passed to `self.conv`
    :return: output of `self.fully_connected`
    """
    feature_maps = self.conv(x)
    # keep dimension 0 and collapse all remaining dimensions into one
    flat_features = torch.flatten(feature_maps, start_dim=1)
    return self.fully_connected(flat_features)
class PatchDiscriminator(nn.Module):
    """
    Fully convolutional 3D discriminator.

    It stacks four 3D convolutions with kernel size 4, stride 2 and padding 1.
    The first three are followed by a LeakyReLU (slope 0.2), the second and third
    additionally use batch normalization. The last convolution maps to a single
    channel and its result is returned without a further activation.
    """

    def __init__(self, in_channels: int = 2, input_size: int = 32):
        """
        Create the discriminator.

        :param in_channels: number of channels received as input
        :param input_size: not used when building the layers; kept as part of the constructor signature
        """
        super().__init__()

        def strided_conv(n_in: int, n_out: int) -> nn.Conv3d:
            # all convolutions share kernel size, stride and padding
            return nn.Conv3d(in_channels=n_in, out_channels=n_out, kernel_size=4, stride=2, padding=1)

        def activation() -> nn.LeakyReLU:
            return nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # first stage: no batch normalization
        stages = [strided_conv(in_channels, 64), activation()]
        # intermediate stages: convolution, batch normalization, activation
        for n_in, n_out in ((64, 128), (128, 256)):
            stages.extend(
                [
                    strided_conv(n_in, n_out),
                    nn.BatchNorm3d(num_features=n_out),
                    activation(),
                ]
            )
        # final stage: reduce to a single output channel
        stages.append(strided_conv(256, 1))

        self.conv = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        """
        Apply the convolutional stack to the input.

        :param x: input passed to the convolutional stack
        :return: output of the last convolution
        """
        return self.conv(x)

Tamino Huxohl
committed
# Quick smoke test: push random inputs through the network and compute an MSE loss
# against an all-ones target.
# NOTE(review): `input_size` has to be defined before this point - its
# definition is not visible in this part of the file.

# net = Discriminator(input_size=input_size)
net = PatchDiscriminator(input_size=input_size)

# batch of 4 random volumes with 2 channels each
_inputs = torch.rand((4, 2, input_size, input_size, input_size))
# forward pass; `_outputs` is read by every statement below
_outputs = net(_inputs)

# target of ones with the same shape as the network output
_targets = torch.full(_outputs.shape, 1.0)
criterion = torch.nn.MSELoss()
loss = criterion(_outputs, _targets)
print(loss.item())
print(f"Transform {_inputs.shape} to {_outputs.shape}")