Commit c2967e1b authored by Tamino Huxohl

add patch discriminator and use it in GAN Training

parent 57d641f4
@@ -70,16 +70,37 @@ class Discriminator(nn.Module):
         return x
 
 
+class PatchDiscriminator(nn.Module):
+    def __init__(self, in_channels: int = 2, input_size: int = 32):
+        super().__init__()
+
+        self.conv = nn.Sequential(
+            nn.Conv3d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm3d(num_features=128),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Conv3d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm3d(num_features=256),
+            nn.LeakyReLU(negative_slope=0.2, inplace=True),
+            nn.Conv3d(in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1),
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.conv(x)
+
+
 if __name__ == "__main__":
-    input_size = 16
+    input_size = 32
 
-    net = Discriminator(input_size=input_size)
+    # net = Discriminator(input_size=input_size)
+    net = PatchDiscriminator(input_size=input_size)
     print(net)
 
-    _inputs = torch.rand((4, 1, input_size, input_size, input_size))
+    _inputs = torch.rand((4, 2, input_size, input_size, input_size))
     _outputs = net(_inputs)
-    _targets = torch.full((4, 1), 1.0)
+    _targets = torch.full(_outputs.shape, 1.0)
 
     criterion = torch.nn.MSELoss()
     loss = criterion(_outputs, _targets)
     print(loss.item())
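
A quick shape check for the PatchDiscriminator above (a minimal sketch, not part of the commit; patch_output_size is a hypothetical helper): four Conv3d layers with kernel_size=4, stride=2 and padding=1 halve each spatial dimension, so a 32³ patch with 2 input channels yields a 2×2×2 grid of per-patch scores. This is why the targets in the __main__ block are now created with _outputs.shape instead of the fixed (4, 1) shape.

    def patch_output_size(input_size: int, n_layers: int = 4) -> int:
        # spatial size after n_layers of Conv3d(kernel_size=4, stride=2, padding=1)
        size = input_size
        for _ in range(n_layers):
            size = (size + 2 * 1 - 4) // 2 + 1  # standard conv output-size formula
        return size

    print(patch_output_size(32))  # 2 -> discriminator output of shape (batch, 1, 2, 2, 2)
    print(patch_output_size(16))  # 1 -> discriminator output of shape (batch, 1, 1, 1, 1)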
@@ -164,8 +164,6 @@ class cGANTraining:
 
     def _step(self, recons, mu_maps_real):
         batch_size = recons.shape[0]
-        labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device)
-        labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device)
 
         with torch.set_grad_enabled(True):
             self.optimizer_d.zero_grad()
@@ -177,11 +175,13 @@ class cGANTraining:
             # compute discriminator loss for fake mu maps
             inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
             outputs_d_fake = self.discriminator(inputs_d_fake.detach())  # note the detach, so that gradients are not computed for the generator
+            labels_fake = torch.full(outputs_d_fake.shape, LABEL_FAKE, device=self.device)
             loss_d_fake = self.criterion_d(outputs_d_fake, labels_fake)
 
             # compute discriminator loss for real mu maps
             inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
             outputs_d_real = self.discriminator(inputs_d_real)  # no detach needed: real μ-maps do not pass through the generator
+            labels_real = torch.full(outputs_d_real.shape, LABEL_REAL, device=self.device)
             loss_d_real = self.criterion_d(outputs_d_real, labels_real)
 
             # update discriminator
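
The fixed per-sample labels removed in the previous hunk (shape (batch_size, 1)) are replaced here by labels built from the discriminator output, because the patch discriminator returns one score per patch cell rather than a single scalar per sample. A minimal sketch of that pairing, assuming LABEL_REAL = 1.0, LABEL_FAKE = 0.0 and an MSE criterion as in the discriminator's __main__ block (both values are assumptions, not confirmed by this diff):

    import torch

    LABEL_REAL, LABEL_FAKE = 1.0, 0.0   # assumed label values
    criterion_d = torch.nn.MSELoss()    # assumed LSGAN-style criterion

    outputs_d_fake = torch.rand((4, 1, 2, 2, 2))                # per-patch scores for fake μ-maps
    labels_fake = torch.full(outputs_d_fake.shape, LABEL_FAKE)  # one target per patch cell
    loss_d_fake = criterion_d(outputs_d_fake, labels_fake)      # drives every patch score towards LABEL_FAKE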
@@ -255,7 +255,7 @@ if __name__ == "__main__":
     from mu_map.dataset.transform import ScaleTransform
     from mu_map.logging import add_logging_args, get_logger_by_args
     from mu_map.models.unet import UNet
-    from mu_map.models.discriminator import Discriminator
+    from mu_map.models.discriminator import Discriminator, PatchDiscriminator
 
     parser = argparse.ArgumentParser(
         description="Train a UNet model to predict μ-maps from reconstructed scatter images",
@@ -421,7 +421,8 @@ if __name__ == "__main__":
     torch.manual_seed(args.seed)
     np.random.seed(args.seed)
 
-    discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
+    # discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
+    discriminator = PatchDiscriminator(in_channels=2, input_size=args.patch_size)
     discriminator = discriminator.to(device)
 
     generator = UNet(in_channels=1, features=args.features)
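
Since input_size is never used by the layer stack, the PatchDiscriminator is fully convolutional and its output grid simply scales with --patch_size. A short usage sketch (the 64³ case is illustrative only, not a configuration used in this commit):

    import torch
    from mu_map.models.discriminator import PatchDiscriminator

    disc = PatchDiscriminator(in_channels=2)
    for patch_size in (16, 32, 64):
        x = torch.rand((1, 2, patch_size, patch_size, patch_size))
        print(patch_size, disc(x).shape)  # spatial output: 1, 2 and 4 cells per dimension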