Commit 837a67db authored by Tamino Huxohl

rename default training to distance training and fix issues with its use of the abstract training class

parent bda67c1e
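The commit splits training into a generic AbstractTraining base class, which owns the data loaders and the epoch/batch loop, and a concrete distance training that only implements the per-batch work. The sketch below is a minimal, hypothetical illustration of that pattern; the names AbstractTrainingSketch and DistanceTrainingSketch, the toy Conv2d model, and the random data are placeholders, not the repository's API.

import torch


class AbstractTrainingSketch:
    """Owns the training loop; subclasses implement the per-batch work."""

    def __init__(self, epochs: int, data_loader, device: torch.device):
        self.epochs = epochs
        self.data_loader = data_loader
        self.device = device
        self.optimizers = []  # filled by subclasses, loosely mirrors TrainingParams

    def run(self):
        for epoch in range(1, self.epochs + 1):
            loss = 0.0
            for inputs, targets in self.data_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                for optimizer in self.optimizers:
                    optimizer.zero_grad()
                loss += self._train_batch(inputs, targets)
                for optimizer in self.optimizers:
                    optimizer.step()
            print(f"epoch {epoch:03d} - loss {loss / len(self.data_loader):.4f}")

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        return 0.0  # overridden by subclasses


class DistanceTrainingSketch(AbstractTrainingSketch):
    """Distance training: predict targets from inputs and minimize an L1 distance."""

    def __init__(self, model, epochs, data_loader, device, lr=1e-3):
        super().__init__(epochs, data_loader, device)
        self.model = model.to(device)
        self.optimizers.append(torch.optim.Adam(self.model.parameters(), lr=lr))

    def _train_batch(self, inputs, targets):
        outputs = self.model(inputs)
        loss = torch.nn.functional.l1_loss(outputs, targets)
        loss.backward()
        return loss.item()


if __name__ == "__main__":
    # Toy data and model only to make the sketch runnable.
    data = torch.utils.data.TensorDataset(
        torch.randn(16, 1, 8, 8), torch.randn(16, 1, 8, 8)
    )
    loader = torch.utils.data.DataLoader(data, batch_size=4, shuffle=True)
    model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)
    DistanceTrainingSketch(model, epochs=2, data_loader=loader, device=torch.device("cpu")).run()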
@@ -76,7 +76,7 @@ class MuMapDataset(Dataset):
         self.reconstructions = {}
         self.mu_maps = {}
 
-    def split_copy(self, split_name: str) -> MuMapDataset:
+    def split_copy(self, split_name: str):
         return MuMapDataset(
             dataset_dir=self.dir,
             csv_file=os.path.basename(self.csv_file),
...
@@ -45,7 +45,7 @@ class MuMapPatchDataset(MuMapDataset):
         self.patches = []
         self.generate_patches()
 
-    def split_copy(self, split_name: str) -> MuMapPatchDataset:
+    def split_copy(self, split_name: str):
        return MuMapPatchDataset(
             dataset_dir=self.dir,
             patches_per_image=self.patches_per_image,
...
@@ -3,6 +3,7 @@ from typing import Dict
 
 import torch
 
+from mu_map.dataset.default import MuMapDataset
 from mu_map.logging import get_logger
 from mu_map.training.lib import TrainingParams, AbstractTraining
 from mu_map.training.loss import WeightedLoss
@@ -17,25 +18,25 @@ class Training(AbstractTraining):
         device: torch.device,
         snapshot_dir: str,
         snapshot_epoch: int,
         params: TrainingParams,
         loss_func: WeightedLoss,
+        logger,
     ):
-        super().__init__(epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch)
+        super().__init__(epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger)
         self.training_params.append(params)
         self.loss_func = loss_func
         self.model = params.model
 
     def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
-        outputs = self.model(inputs)
-        loss = self.loss_func(outputs, mu_maps)
+        mu_maps_predicted = self.model(recons)
+        loss = self.loss_func(mu_maps_predicted, mu_maps)
         loss.backward()
         return loss.item()
 
     def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
-        outpus = self.model(inputs)
-        loss = torch.nn.functional.loss.l1(outpus, mu_maps)
+        mu_maps_predicted = self.model(recons)
+        loss = torch.nn.functional.l1_loss(mu_maps_predicted, mu_maps)
         return loss.item()
@@ -144,6 +145,11 @@ if __name__ == "__main__":
         default="l1",
         help="define the loss function used for training, e.g. 0.75*l1+0.25*gdl",
     )
+    parser.add_argument(
+        "--decay_lr",
+        action="store_true",
+        help="decay the learning rate",
+    )
     parser.add_argument(
         "--lr", type=float, default=0.001, help="the initial learning rate for training"
     )
@@ -203,8 +209,6 @@ if __name__ == "__main__":
     torch.manual_seed(args.seed)
     np.random.seed(args.seed)
 
-    model = UNet(in_channels=1, features=args.features)
-    model = model.to(device)
-
     transform_normalization = None
     if args.input_norm == "mean":
@@ -214,41 +218,38 @@ if __name__ == "__main__":
     elif args.input_norm == "gaussian":
         transform_normalization = GaussianNormTransform()
 
-    data_loaders = {}
-    for split in ["train", "validation"]:
-        dataset = MuMapPatchDataset(
-            args.dataset_dir,
-            patches_per_image=args.number_of_patches,
-            patch_size=args.patch_size,
-            patch_offset=args.patch_offset,
-            shuffle=not args.no_shuffle,
-            split_name=split,
-            transform_normalization=transform_normalization,
-            logger=logger,
-        )
-        data_loader = torch.utils.data.DataLoader(
-            dataset=dataset,
-            batch_size=args.batch_size,
-            shuffle=True,
-            pin_memory=True,
-            num_workers=1,
-        )
-        data_loaders[split] = data_loader
+    dataset = MuMapPatchDataset(
+        args.dataset_dir,
+        patches_per_image=args.number_of_patches,
+        patch_size=args.patch_size,
+        patch_offset=args.patch_offset,
+        shuffle=not args.no_shuffle,
+        transform_normalization=transform_normalization,
+        logger=logger,
+    )
+
+    model = UNet(in_channels=1, features=args.features)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
+    lr_scheduler = (
+        torch.optim.lr_scheduler.StepLR(
+            optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
+        )
+        if args.decay_lr
+        else None
+    )
+    params = TrainingParams(name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
 
     criterion = WeightedLoss.from_str(args.loss_func)
+    logger.debug(f"Criterion: {criterion}")
 
     training = Training(
-        model=model,
-        data_loaders=data_loaders,
         epochs=args.epochs,
+        dataset=dataset,
+        batch_size=args.batch_size,
         device=device,
-        loss_func=criterion,
-        lr=args.lr,
-        lr_decay_factor=args.lr_decay_factor,
-        lr_decay_epoch=args.lr_decay_epoch,
         snapshot_dir=args.snapshot_dir,
         snapshot_epoch=args.snapshot_epoch,
+        params=params,
+        loss_func=criterion,
         logger=logger,
     )
     training.run()
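The new --decay_lr flag makes the learning-rate schedule opt-in: only when it is set is a StepLR scheduler attached to the optimizer, otherwise lr_scheduler stays None. The following is a small, hypothetical illustration of how StepLR multiplies the learning rate by gamma every step_size epochs; the parameter, optimizer, and numeric values are made up for the example.

import torch

# A throwaway parameter and optimizer just to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.001, betas=(0.5, 0.999))

# Decay the learning rate by a factor of 0.1 every 10 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

for epoch in range(25):
    optimizer.step()   # normally follows a backward pass
    scheduler.step()   # advance the schedule once per epoch
    if (epoch + 1) % 10 == 0:
        print(epoch + 1, optimizer.param_groups[0]["lr"])
# prints roughly: 10 0.0001, 20 1e-05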
 from dataclasses import dataclass
 import os
-from typing import Dict
+from typing import Dict, Optional
+import sys
 
 import torch
 from torch import Tensor
 
 from mu_map.dataset.default import MuMapDataset
 
 @dataclass
 class TrainingParams:
     name: str
@@ -14,8 +16,8 @@ class TrainingParams:
     optimizer: torch.optim.Optimizer
+    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
 
 class AbstractTraining:
     def __init__(
         self,
         epochs: int,
@@ -24,7 +26,7 @@ class AbstractTraining:
         device: torch.device,
         snapshot_dir: str,
         snapshot_epoch: int,
-        logger=None,
+        logger,  # TODO make optional?
     ):
         self.epochs = epochs
         self.batch_size = batch_size
@@ -37,7 +39,21 @@ class AbstractTraining:
         self.logger = logger
 
         self.training_params = []
+        self.data_loaders = dict(
+            [
+                (
+                    split_name,
+                    torch.utils.data.DataLoader(
+                        dataset.split_copy(split_name),
+                        batch_size=self.batch_size,
+                        shuffle=True,
+                        pin_memory=True,
+                        num_workers=1,
+                    ),
+                )
+                for split_name in ["train", "validation"]
+            ]
+        )
 
     def run(self) -> float:
         loss_val_min = sys.maxsize
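AbstractTraining now builds one DataLoader per split by asking the dataset for a copy restricted to that split, instead of the caller wiring the loaders up by hand. A minimal sketch of the split_copy idea follows; ToySplitDataset and its even/odd split rule are hypothetical stand-ins for the repository's MuMapDataset.

import torch
from torch.utils.data import Dataset, DataLoader


class ToySplitDataset(Dataset):
    """Knows about all samples, but only exposes those belonging to one split."""

    def __init__(self, n: int = 20, split_name: str = "train"):
        self.n = n
        self.split_name = split_name
        # Hypothetical split rule: even indices are train, odd are validation.
        self.indices = [
            i for i in range(n) if (i % 2 == 0) == (split_name == "train")
        ]

    def split_copy(self, split_name: str):
        # Same dataset definition, restricted to the requested split.
        return ToySplitDataset(n=self.n, split_name=split_name)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        i = self.indices[idx]
        return torch.tensor([float(i)]), torch.tensor([float(i)])


dataset = ToySplitDataset()
data_loaders = {
    split_name: DataLoader(
        dataset.split_copy(split_name),
        batch_size=4,
        shuffle=True,
        pin_memory=True,
        num_workers=1,
    )
    for split_name in ["train", "validation"]
}
print({k: len(v.dataset) for k, v in data_loaders.items()})  # {'train': 10, 'validation': 10}

The dict([...]) construction in the commit is equivalent to the dict comprehension used here; both keep split handling inside the dataset rather than in the training script.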
@@ -88,7 +104,7 @@ class AbstractTraining:
             for param in self.training_params:
                 param.optimizer.zero_grad()
 
-            loss = loss + self._train_batch(self, inputs, targets)
+            loss = loss + self._train_batch(inputs, targets)
 
             for param in self.training_params:
                 param.optimizer.step()
@@ -110,12 +126,14 @@ class AbstractTraining:
                 inputs = inputs.to(self.device)
                 targets = targets.to(self.device)
 
-                loss = loss + self._eval_batch(self, inputs, targets)
+                loss = loss + self._eval_batch(inputs, targets)
 
         return loss / len(data_loader)
 
     def store_snapshot(self, prefix: str):
         for param in self.training_params:
-            snapshot_file = os.path.join(self.snapshot_dir, f"{prefix}_{param.name}.pth")
+            snapshot_file = os.path.join(
+                self.snapshot_dir, f"{prefix}_{param.name}.pth"
+            )
             self.logger.debug(f"Store snapshot at {snapshot_file}")
             torch.save(param.model.state_dict(), snapshot_file)
@@ -124,4 +142,3 @@ class AbstractTraining:
     def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
         return 0
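store_snapshot serializes one state_dict per registered TrainingParams entry under {prefix}_{name}.pth. Restoring such a snapshot later follows the usual PyTorch pattern; the sketch below uses a stand-in Conv2d where the project would rebuild its UNet with the training-time arguments, and the prefix value and temporary directory are illustrative only.

import os
import tempfile

import torch

# Stand-in for the trained model; in the real project this would be the UNet
# constructed with the same arguments as during training.
model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)

# Mimic store_snapshot: one state_dict per named TrainingParams entry.
snapshot_dir = tempfile.mkdtemp()
prefix, name = "050", "Model"
snapshot_file = os.path.join(snapshot_dir, f"{prefix}_{name}.pth")
torch.save(model.state_dict(), snapshot_file)

# Restoring the snapshot: load the state_dict into a freshly built model.
restored = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)
restored.load_state_dict(torch.load(snapshot_file, map_location="cpu"))
restored.eval()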