Skip to content
Snippets Groups Projects
Commit 8f7bbd3a authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

fix cgan training and patch dataset

parent 56e09b73
No related branches found
No related tags found
No related merge requests found
...@@ -86,7 +86,7 @@ class MuMapPatchDataset(MuMapDataset): ...@@ -86,7 +86,7 @@ class MuMapPatchDataset(MuMapDataset):
ps = self.patch_size ps = self.patch_size
ps_z = self.patch_size_z ps_z = self.patch_size_z
recon, mu_map = super().__getitem__(index) recon, mu_map = super().getitem_by_id(_id)
recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0) recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0) mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)
......
from dataclasses import dataclass
import os import os
from typing import Dict, Optional from typing import Dict, Optional
import sys import sys
...@@ -12,8 +13,6 @@ from mu_map.logging import get_logger ...@@ -12,8 +13,6 @@ from mu_map.logging import get_logger
LABEL_REAL = 1.0 LABEL_REAL = 1.0
LABEL_FAKE = 0.0 LABEL_FAKE = 0.0
from dataclass import dataclass
@dataclass @dataclass
class TrainingParams: class TrainingParams:
...@@ -56,7 +55,7 @@ class cGANTraining: ...@@ -56,7 +55,7 @@ class cGANTraining:
self.weight_criterion_adv = weight_criterion_adv self.weight_criterion_adv = weight_criterion_adv
self.criterion_adv = torch.nn.MSELoss(reduction="mean") self.criterion_adv = torch.nn.MSELoss(reduction="mean")
self.criterion_dist = self.loss_func_dist self.criterion_dist = loss_func_dist
def run(self): def run(self):
loss_val_min = sys.maxsize loss_val_min = sys.maxsize
...@@ -77,7 +76,7 @@ class cGANTraining: ...@@ -77,7 +76,7 @@ class cGANTraining:
if loss_val < loss_val_min: if loss_val < loss_val_min:
loss_val_min = loss_val loss_val_min = loss_val
logger.info( logger.info(
"Store snapshot val_min of epoch {str_epoch} with minimal validation loss" f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
) )
self.store_snapshot("val_min") self.store_snapshot("val_min")
if epoch % self.snapshot_epoch == 0: if epoch % self.snapshot_epoch == 0:
...@@ -93,15 +92,13 @@ class cGANTraining: ...@@ -93,15 +92,13 @@ class cGANTraining:
return loss_val return loss_val
def _train_epoch(self): def _train_epoch(self):
logger.debug(f"Train epoch")
# setup training mode # setup training mode
torch.set_grad_enabled(True) torch.set_grad_enabled(True)
self.params_d.model.train() self.params_d.model.train()
self.params_g.model.train() self.params_g.model.train()
data_loader = self.data_loaders["train"] data_loader = self.data_loaders["train"]
for i, (recons, mu_maps) in enumerate(data_loader): for i, (recons, mu_maps_real) in enumerate(data_loader):
print( print(
f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r", end="\r",
...@@ -109,7 +106,7 @@ class cGANTraining: ...@@ -109,7 +106,7 @@ class cGANTraining:
batch_size = recons.shape[0] batch_size = recons.shape[0]
recons = recons.to(self.device) recons = recons.to(self.device)
mu_maps = mu_maps.to(self.device) mu_maps_real = mu_maps_real.to(self.device)
self.params_d.optimizer.zero_grad() self.params_d.optimizer.zero_grad()
self.params_g.optimizer.zero_grad() self.params_g.optimizer.zero_grad()
...@@ -154,7 +151,7 @@ class cGANTraining: ...@@ -154,7 +151,7 @@ class cGANTraining:
loss_g.backward() loss_g.backward()
self.params_g.optimizer.step() self.params_g.optimizer.step()
def _eval_epoch(self, epoch, split_name): def _eval_epoch(self, split_name):
# setup evaluation mode # setup evaluation mode
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
self.discriminator = self.discriminator.eval() self.discriminator = self.discriminator.eval()
...@@ -171,7 +168,7 @@ class cGANTraining: ...@@ -171,7 +168,7 @@ class cGANTraining:
recons = recons.to(self.device) recons = recons.to(self.device)
mu_maps = mu_maps.to(self.device) mu_maps = mu_maps.to(self.device)
outputs = self.params_g(recons) outputs = self.params_g.model(recons)
loss += torch.nn.functional.l1_loss(outputs, mu_maps) loss += torch.nn.functional.l1_loss(outputs, mu_maps)
updates += 1 updates += 1
...@@ -259,7 +256,7 @@ if __name__ == "__main__": ...@@ -259,7 +256,7 @@ if __name__ == "__main__":
help="do not shuffle patches in the dataset", help="do not shuffle patches in the dataset",
) )
parser.add_argument( parser.add_argument(
"scatter_correction", "--scatter_correction",
action="store_true", action="store_true",
help="use the scatter corrected reconstructions in the dataset", help="use the scatter corrected reconstructions in the dataset",
) )
...@@ -450,7 +447,7 @@ if __name__ == "__main__": ...@@ -450,7 +447,7 @@ if __name__ == "__main__":
) )
dist_criterion = WeightedLoss.from_str(args.dist_loss_func) dist_criterion = WeightedLoss.from_str(args.dist_loss_func)
logger.debug(f"Use distance criterion: {criterion}") logger.debug(f"Use distance criterion: {dist_criterion}")
training = cGANTraining( training = cGANTraining(
data_loaders=data_loaders, data_loaders=data_loaders,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment