Skip to content
Snippets Groups Projects
Commit c2967e1b authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

add patch discriminator and use it in GAN Training

parent 57d641f4
No related branches found
No related tags found
No related merge requests found
...@@ -70,16 +70,37 @@ class Discriminator(nn.Module): ...@@ -70,16 +70,37 @@ class Discriminator(nn.Module):
return x return x
class PatchDiscriminator(nn.Module):
def __init__(self, in_channels: int = 2, input_size: int = 32):
super().__init__()
self.conv = nn.Sequential(
nn.Conv3d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv3d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm3d(num_features=128),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv3d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm3d(num_features=256),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv3d(in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1),
)
def forward(self, x: torch.Tensor):
return self.conv(x)
if __name__ == "__main__": if __name__ == "__main__":
input_size = 16 input_size = 32
net = Discriminator(input_size=input_size) # net = Discriminator(input_size=input_size)
net = PatchDiscriminator(input_size=input_size)
print(net) print(net)
_inputs = torch.rand((4, 1, input_size, input_size, input_size)) _inputs = torch.rand((4, 2, input_size, input_size, input_size))
_outputs = net(_inputs) _outputs = net(_inputs)
_targets = torch.full((4, 1), 1.0) _targets = torch.full(_outputs.shape, 1.0)
criterion = torch.nn.MSELoss() criterion = torch.nn.MSELoss()
loss = criterion(_outputs, _targets) loss = criterion(_outputs, _targets)
print(loss.item()) print(loss.item())
......
...@@ -164,8 +164,6 @@ class cGANTraining: ...@@ -164,8 +164,6 @@ class cGANTraining:
def _step(self, recons, mu_maps_real): def _step(self, recons, mu_maps_real):
batch_size = recons.shape[0] batch_size = recons.shape[0]
labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device)
labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device)
with torch.set_grad_enabled(True): with torch.set_grad_enabled(True):
self.optimizer_d.zero_grad() self.optimizer_d.zero_grad()
...@@ -177,11 +175,13 @@ class cGANTraining: ...@@ -177,11 +175,13 @@ class cGANTraining:
# compute discriminator loss for fake mu maps # compute discriminator loss for fake mu maps
inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1) inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
outputs_d_fake = self.discriminator(inputs_d_fake.detach()) # note the detach, so that gradients are not computed for the generator outputs_d_fake = self.discriminator(inputs_d_fake.detach()) # note the detach, so that gradients are not computed for the generator
labels_fake = torch.full((outputs_d_fake.shape), LABEL_FAKE, device=self.device)
loss_d_fake = self.criterion_d(outputs_d_fake, labels_fake) loss_d_fake = self.criterion_d(outputs_d_fake, labels_fake)
# compute discriminator loss for real mu maps # compute discriminator loss for real mu maps
inputs_d_real = torch.cat((recons, mu_maps_real), dim=1) inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
outputs_d_real = self.discriminator(inputs_d_real) # note the detach, so that gradients are not computed for the generator outputs_d_real = self.discriminator(inputs_d_real) # note the detach, so that gradients are not computed for the generator
labels_real = torch.full((outputs_d_fake.shape), LABEL_REAL, device=self.device)
loss_d_real = self.criterion_d(outputs_d_real, labels_real) loss_d_real = self.criterion_d(outputs_d_real, labels_real)
# update discriminator # update discriminator
...@@ -255,7 +255,7 @@ if __name__ == "__main__": ...@@ -255,7 +255,7 @@ if __name__ == "__main__":
from mu_map.dataset.transform import ScaleTransform from mu_map.dataset.transform import ScaleTransform
from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.logging import add_logging_args, get_logger_by_args
from mu_map.models.unet import UNet from mu_map.models.unet import UNet
from mu_map.models.discriminator import Discriminator from mu_map.models.discriminator import Discriminator, PatchDiscriminator
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Train a UNet model to predict μ-maps from reconstructed scatter images", description="Train a UNet model to predict μ-maps from reconstructed scatter images",
...@@ -421,7 +421,8 @@ if __name__ == "__main__": ...@@ -421,7 +421,8 @@ if __name__ == "__main__":
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
discriminator = Discriminator(in_channels=2, input_size=args.patch_size) # discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
discriminator = PatchDiscriminator(in_channels=2, input_size=args.patch_size)
discriminator = discriminator.to(device) discriminator = discriminator.to(device)
generator = UNet(in_channels=1, features=args.features) generator = UNet(in_channels=1, features=args.features)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment