diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index 651325b1b2134de50e19eaec164e2a81a185a1ad..385ca93470888ba413da1701b5680dafd82d0dc2 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -76,7 +76,7 @@ class MuMapDataset(Dataset):
         self.reconstructions = {}
         self.mu_maps = {}
 
-    def split_copy(self, split_name: str) -> MuMapDataset:
+    def split_copy(self, split_name: str):
         return MuMapDataset(
             dataset_dir=self.dir,
             csv_file=os.path.basename(self.csv_file),
diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py
index fbed83e5aa62621aa94294ef69380648fcce998d..74872696de73f18bfaca7e73390cec811fd78a21 100644
--- a/mu_map/dataset/patches.py
+++ b/mu_map/dataset/patches.py
@@ -45,7 +45,7 @@ class MuMapPatchDataset(MuMapDataset):
         self.patches = []
         self.generate_patches()
 
-    def split_copy(self, split_name: str) ->MuMapPatchDataset:
+    def split_copy(self, split_name: str):
         return MuMapPatchDataset(
             dataset_dir=self.dir,
             patches_per_image=self.patches_per_image,
diff --git a/mu_map/training/default.py b/mu_map/training/distance.py
similarity index 83%
rename from mu_map/training/default.py
rename to mu_map/training/distance.py
index 14898a06ac0a5f88ed39a6f344ac4e1e0d0e0e0e..e607dffd86d9227e0da770b467233c1e37db8f53 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/distance.py
@@ -3,6 +3,7 @@ from typing import Dict
 
 import torch
 
+from mu_map.dataset.default import MuMapDataset
 from mu_map.logging import get_logger
 from mu_map.training.lib import TrainingParams, AbstractTraining
 from mu_map.training.loss import WeightedLoss
@@ -17,22 +18,23 @@ class Training(AbstractTraining):
         device: torch.device,
         snapshot_dir: str,
         snapshot_epoch: int,
         params: TrainingParams,
         loss_func: WeightedLoss,
+        logger,
     ):
-        super().__init__(epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch)
+        super().__init__(epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger)
         self.training_params.append(params)
 
         self.loss_func = loss_func
         self.model = params.model
 
     def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
-        outputs = self.model(inputs)
-        loss = self.loss_func(outputs, mu_maps)
+        mu_maps_predicted = self.model(recons)
+        loss = self.loss_func(mu_maps_predicted, mu_maps)
         loss.backward()
         return loss.item()
 
     def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
-        outpus = self.model(inputs)
-        loss = torch.nn.functional.loss.l1(outpus, mu_maps)
+        mu_maps_predicted = self.model(recons)
+        loss = torch.nn.functional.l1_loss(mu_maps_predicted, mu_maps)
         return loss.item()
@@ -144,6 +146,11 @@ if __name__ == "__main__":
         default="l1",
         help="define the loss function used for training, e.g. 0.75*l1+0.25*gdl",
     )
0.75*l1+0.25*gdl", ) + parser.add_argument( + "--decay_lr", + action="store_true", + help="decay the learning rate", + ) parser.add_argument( "--lr", type=float, default=0.001, help="the initial learning rate for training" ) @@ -203,8 +209,6 @@ if __name__ == "__main__": torch.manual_seed(args.seed) np.random.seed(args.seed) - model = UNet(in_channels=1, features=args.features) - model = model.to(device) transform_normalization = None if args.input_norm == "mean": @@ -214,41 +218,38 @@ if __name__ == "__main__": elif args.input_norm == "gaussian": transform_normalization = GaussianNormTransform() - data_loaders = {} - for split in ["train", "validation"]: - dataset = MuMapPatchDataset( - args.dataset_dir, - patches_per_image=args.number_of_patches, - patch_size=args.patch_size, - patch_offset=args.patch_offset, - shuffle=not args.no_shuffle, - split_name=split, - transform_normalization=transform_normalization, - logger=logger, - ) - data_loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=args.batch_size, - shuffle=True, - pin_memory=True, - num_workers=1, + dataset = MuMapPatchDataset( + args.dataset_dir, + patches_per_image=args.number_of_patches, + patch_size=args.patch_size, + patch_offset=args.patch_offset, + shuffle=not args.no_shuffle, + transform_normalization=transform_normalization, + logger=logger, + ) + + model = UNet(in_channels=1, features=args.features) + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999)) + lr_scheduler = ( + torch.optim.lr_scheduler.StepLR( + optimizer, step_size=args.lr_decay_factor, gamma=args.lr_decay_factor ) - data_loaders[split] = data_loader + if args.decay_lr + else None + ) + params = TrainingParams(name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler) criterion = WeightedLoss.from_str(args.loss_func) - logger.debug(f"Criterion: {criterion}") training = Training( - model=model, - data_loaders=data_loaders, epochs=args.epochs, + dataset=dataset, + batch_size=args.batch_size, device=device, - loss_func=criterion, - lr=args.lr, - lr_decay_factor=args.lr_decay_factor, - lr_decay_epoch=args.lr_decay_epoch, snapshot_dir=args.snapshot_dir, snapshot_epoch=args.snapshot_epoch, + params=params, + loss_func=criterion, logger=logger, ) training.run() diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py index 00811cd63d651220f13f9fbf857536b63a5d872a..dc92337f1d3a28aea466bd39595d26bd5d09bcf6 100644 --- a/mu_map/training/lib.py +++ b/mu_map/training/lib.py @@ -1,12 +1,14 @@ from dataclasses import dataclass import os -from typing import Dict +from typing import Dict, Optional +import sys import torch from torch import Tensor from mu_map.dataset.default import MuMapDataset + @dataclass class TrainingParams: name: str @@ -14,8 +16,8 @@ class TrainingParams: optimizer: torch.optim.Optimizer lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] -class AbstractTraining: +class AbstractTraining: def __init__( self, epochs: int, @@ -24,7 +26,7 @@ class AbstractTraining: device: torch.device, snapshot_dir: str, snapshot_epoch: int, - logger=None, + logger, # TODO make optional? 
diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
index 00811cd63d651220f13f9fbf857536b63a5d872a..dc92337f1d3a28aea466bd39595d26bd5d09bcf6 100644
--- a/mu_map/training/lib.py
+++ b/mu_map/training/lib.py
@@ -1,12 +1,14 @@
 from dataclasses import dataclass
 import os
-from typing import Dict
+from typing import Dict, Optional
+import sys
 
 import torch
 from torch import Tensor
 
 from mu_map.dataset.default import MuMapDataset
 
+
 @dataclass
 class TrainingParams:
     name: str
@@ -14,8 +16,8 @@ class TrainingParams:
     optimizer: torch.optim.Optimizer
     lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
 
-class AbstractTraining:
 
+class AbstractTraining:
     def __init__(
         self,
         epochs: int,
@@ -24,7 +26,7 @@
         device: torch.device,
         snapshot_dir: str,
         snapshot_epoch: int,
-        logger=None,
+        logger,  # TODO: make optional?
     ):
         self.epochs = epochs
         self.batch_size = batch_size
@@ -37,7 +39,16 @@
         self.logger = logger
 
         self.training_params = []
-
+        self.data_loaders = {
+            split_name: torch.utils.data.DataLoader(
+                dataset.split_copy(split_name),
+                batch_size=self.batch_size,
+                shuffle=True,
+                pin_memory=True,
+                num_workers=1,
+            )
+            for split_name in ["train", "validation"]
+        }
 
     def run(self) -> float:
         loss_val_min = sys.maxsize
@@ -88,7 +99,7 @@
             for param in self.training_params:
                 param.optimizer.zero_grad()
 
-            loss = loss + self._train_batch(self, inputs, targets)
+            loss = loss + self._train_batch(inputs, targets)
 
             for param in self.training_params:
                 param.optimizer.step()
@@ -110,12 +121,14 @@
                 inputs = inputs.to(self.device)
                 targets = targets.to(self.device)
 
-                loss = loss + self._eval_batch(self, inputs, targets)
+                loss = loss + self._eval_batch(inputs, targets)
 
         return loss / len(data_loader)
 
     def store_snapshot(self, prefix: str):
         for param in self.training_params:
-            snapshot_file = os.path.join(self.snapshot_dir, f"{prefix}_{param.name}.pth")
+            snapshot_file = os.path.join(
+                self.snapshot_dir, f"{prefix}_{param.name}.pth"
+            )
             self.logger.debug(f"Store snapshot at {snapshot_file}")
             torch.save(param.model.state_dict(), snapshot_file)
 
@@ -124,4 +137,3 @@
     def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
         return 0
-
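To illustrate the contract this patch establishes in lib.py: AbstractTraining now owns the
train/validation loaders (built via split_copy), zeroes and steps every optimizer registered
in training_params, and delegates only the forward/backward pass to subclasses. A minimal
hypothetical subclass, assuming the module layout above (the class name L1Training and the
keyword-argument passthrough are illustrative, not part of the patch):

    import torch

    from mu_map.training.lib import AbstractTraining, TrainingParams

    class L1Training(AbstractTraining):  # hypothetical example subclass
        def __init__(self, params: TrainingParams, **kwargs):
            super().__init__(**kwargs)
            # registered params are zeroed/stepped by the base training loop
            self.training_params.append(params)
            self.model = params.model

        def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
            # forward/backward only; zero_grad() and optimizer.step() happen in run()
            loss = torch.nn.functional.l1_loss(self.model(recons), mu_maps)
            loss.backward()
            return loss.item()

        def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
            return torch.nn.functional.l1_loss(self.model(recons), mu_maps).item()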