Skip to content
Snippets Groups Projects
Commit 7981050c authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

fixed for cgan training

parent 82017f08
No related branches found
No related tags found
No related merge requests found
......@@ -36,16 +36,12 @@ class cGANTraining:
weight_criterion_adv: float,
logger=None,
):
self.generator = generator
self.discriminator = discriminator
self.data_loaders = data_loaders
self.epochs = epochs
self.device = device
self.snapshot_dir = snapshot_dir
self.snapshot_epoch = snapshot_epoch
self.logger = logger if logger is not None else get_logger()
self.params_g = params_generator
......@@ -61,21 +57,21 @@ class cGANTraining:
loss_val_min = sys.maxsize
for epoch in range(1, self.epochs + 1):
str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
self._train_epoch()
loss_train = self._eval_epoch("train")
logger.info(
self.logger.info(
f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
)
loss_val = self._eval_epoch("validation")
logger.info(
self.logger.info(
f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
)
if loss_val < loss_val_min:
loss_val_min = loss_val
logger.info(
self.logger.info(
f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
)
self.store_snapshot("val_min")
......@@ -83,13 +79,12 @@ class cGANTraining:
self._store_snapshot(epoch)
if self.params_d.lr_scheduler is not None:
logger.debug("Step LR scheduler of discriminator")
self.logger.debug("Step LR scheduler of discriminator")
self.params_d.lr_scheduler.step()
if self.params_g.lr_scheduler is not None:
logger.debug("Step LR scheduler of generator")
self.logger.debug("Step LR scheduler of generator")
self.params_g.lr_scheduler.step()
return loss_val
return loss_val_min
def _train_epoch(self):
# setup training mode
......@@ -154,8 +149,8 @@ class cGANTraining:
def _eval_epoch(self, split_name):
# setup evaluation mode
torch.set_grad_enabled(False)
self.discriminator = self.discriminator.eval()
self.generator = self.generator.eval()
self.params_d.model = self.params_d.model.eval()
self.params_g.model = self.params_g.model.eval()
data_loader = self.data_loaders[split_name]
loss = 0.0
......@@ -181,9 +176,9 @@ class cGANTraining:
def store_snapshot(self, prefix: str):
snapshot_file_d = os.path.join(self.snapshot_dir, f"{prefix}_discriminator.pth")
snapshot_file_g = os.path.join(self.snapshot_dir, f"{prefix}_generator.pth")
logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}")
torch.save(self.discriminator.state_dict(), snapshot_file_d)
torch.save(self.generator.state_dict(), snapshot_file_g)
self.logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}")
torch.save(self.params_d.model.state_dict(), snapshot_file_d)
torch.save(self.params_g.model.state_dict(), snapshot_file_g)
if __name__ == "__main__":
......@@ -372,7 +367,7 @@ if __name__ == "__main__":
logger = get_logger_by_args(args)
logger.info(args)
args.seed = args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1)
args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {args.seed}")
random.seed(args.seed)
torch.manual_seed(args.seed)
......@@ -418,7 +413,7 @@ if __name__ == "__main__":
)
lr_scheduler = (
torch.optim.lr_scheduler.StepLR(
optimizer, step_size=args.lr_decay_factor, gamma=args.lr_decay_factor
optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
)
if args.decay_lr
else None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment