Commit ce73746b authored by Tamino Huxohl

update training code

parent 79d54434
@@ -46,7 +46,9 @@ class MuMapDataset(Dataset):
         self.transform_normalization = transform_normalization
         self.transform_augmentation = transform_augmentation
-        self.logger = logger if logger is not None else get_logger(name=MuMapDataset.__name__)
+        self.logger = (
+            logger if logger is not None else get_logger(name=MuMapDataset.__name__)
+        )
 
         self.bed_contours_file = (
             os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
@@ -56,6 +58,7 @@
         )
         # read CSV file and from that access DICOM files
+        self.split_name = split_name
         self.table = pd.read_csv(self.csv_file)
         if split_name:
             self.table = split_csv(self.table, self.split_file)[split_name]
@@ -73,6 +76,22 @@
         self.reconstructions = {}
         self.mu_maps = {}
 
+    def split_copy(self, split_name: str) -> MuMapDataset:
+        return MuMapDataset(
+            dataset_dir=self.dir,
+            csv_file=os.path.basename(self.csv_file),
+            split_file=os.path.basename(self.split_file),
+            split_name=split_name,
+            images_dir=os.path.basename(self.images_dir),
+            bed_contours_file=self.bed_contours_file,
+            discard_mu_map_slices=self.discard_mu_map_slices,
+            align=self.align,
+            scatter_correction=self.scatter_correction,
+            transform_normalization=self.transform_normalization,
+            transform_augmentation=self.transform_augmentation,
+            logger=self.logger,
+        )
+
     def load_image(self, _id: int):
         row = self.table[self.table[headers.id] == _id].iloc[0]
         _id = row[headers.id]
......
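The new split_copy builds a sibling dataset that differs only in its split. A minimal usage sketch, assuming hypothetical file names and defaults for all remaining options:

# Sketch only: paths and file names below are hypothetical placeholders.
dataset = MuMapDataset(
    dataset_dir="data/initial",
    csv_file="meta.csv",
    split_file="split.csv",
)
# Derive train/validation views that share every other option.
dataset_train = dataset.split_copy("train")
dataset_val = dataset.split_copy("validation")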
@@ -34,6 +34,7 @@ class MuMapPatchDataset(MuMapDataset):
     ):
         super().__init__(dataset_dir, **kwargs)
+        self.kwargs = kwargs
 
         self.patches_per_image = patches_per_image
         self.patch_size = patch_size
         self.patch_size_z = patch_size_z
@@ -44,6 +45,17 @@ class MuMapPatchDataset(MuMapDataset):
         self.patches = []
         self.generate_patches()
 
+    def split_copy(self, split_name: str) -> MuMapPatchDataset:
+        # override the stored split name so the copy loads the requested split
+        kwargs = {**self.kwargs, "split_name": split_name}
+        return MuMapPatchDataset(
+            dataset_dir=self.dir,
+            patches_per_image=self.patches_per_image,
+            patch_size=self.patch_size,
+            patch_size_z=self.patch_size_z,
+            patch_offset=self.patch_offset,
+            shuffle=self.shuffle,
+            **kwargs,
+        )
+
     def generate_patches(self):
         """
         Pre-compute patches for each image.
......
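Storing the raw kwargs on the instance is what lets the subclass override of split_copy forward every base-class option without re-listing it. A sketch with the same hypothetical file names as above:

# Sketch only: patch parameters are copied explicitly, everything else
# (csv_file, split_file, transforms, ...) travels through **kwargs.
patches = MuMapPatchDataset(
    "data/initial",
    patches_per_image=100,
    patch_size=32,
    csv_file="meta.csv",
    split_file="split.csv",
)
patches_val = patches.split_copy("validation")
assert patches_val.patch_size == patches.patch_size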
+import argparse
 from typing import Optional, List
 import torch
@@ -151,35 +153,94 @@ class UNet(nn.Module):
         return self.get_submodule("out_conv")(x)
 
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser, prefix: str = ""):
+        prefix = f"{prefix}_" if prefix != "" else ""
+        parser.add_argument(
+            f"--{prefix}in_channels",
+            type=int,
+            default=1,
+            help="the number of input channels of the model",
+        )
+        parser.add_argument(
+            f"--{prefix}out_channels",
+            type=int,
+            default=1,
+            help="the number of output channels of the model",
+        )
+        parser.add_argument(
+            f"--{prefix}features",
+            type=int,
+            nargs="+",
+            default=[64, 128, 256, 512],
+            help="the number of features of each layer in the UNet hierarchy",
+        )
+        parser.add_argument(
+            f"--{prefix}no_batch_norm",
+            action="store_true",
+            help="do not apply batch normalization",
+        )
+        parser.add_argument(
+            f"--{prefix}dropout", type=float, default=0.15, help="the dropout rate"
+        )
+
+    @classmethod
+    def from_args(cls, args, prefix: str = ""):
+        prefix = f"{prefix}_" if prefix != "" else ""
+        _args = vars(args)
+        return cls(
+            in_channels=_args[f"{prefix}in_channels"],
+            out_channels=_args[f"{prefix}out_channels"],
+            features=_args[f"{prefix}features"],
+            batch_norm=not _args[f"{prefix}no_batch_norm"],
+            dropout=_args[f"{prefix}dropout"],
+        )
+
 
 if __name__ == "__main__":
-    import torch
-
-    net = UNet(features=[64, 128, 256, 512])
-    print(net)
-    _inputs = torch.rand((1, 1, 64, 128, 128))
-    _outputs = net(_inputs)
-
-    import time
-    print(f"Transform {_inputs.shape} to {_outputs.shape}")
-
-    device = torch.device("cuda")
-    torch.set_grad_enabled(False)
-
-    iterations = 100
-    for batch_size in range(128, 129):
-        since = time.time()
-        for i in range(iterations):
-            print(f"{str(batch_size):>2}/17 - {str(i+1):>3}/{iterations}", end="\r")
-            # _inputs = torch.rand((batch_size, 1, 64, 128, 128))
-            _inputs = torch.rand((batch_size, 1, 32, 32, 32))
-            _inputs = _inputs.to(device)
-            _outputs = net(_inputs)
-        _took = time.time() - since
-        print(f"Batches of size {batch_size} take {_took:.3f}s on average")
+    import torch
+    import time
+
+    from tqdm import trange
+
+    parser = argparse.ArgumentParser(
+        description="Test the performance of a UNet model",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--input_size",
+        type=int,
+        nargs=3,
+        default=[64, 128, 128],
+        help="the size of input images which are passed through the model",
+    )
+    parser.add_argument(
+        "--device",
+        choices=["cpu", "cuda"],
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="the device on which computations are performed",
+    )
+    parser.add_argument(
+        "--iterations", type=int, default=100, help="number of iterations to perform"
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=64,
+        help="the batch size for which performance is measured",
+    )
+    UNet.add_arguments(parser.add_argument_group("Model"))
+    args = parser.parse_args()
+    print(args)
+
+    device = torch.device(args.device)
+    net = UNet.from_args(args)
+    net = net.to(device)
+
+    for i in trange(args.iterations):
+        _inputs = torch.rand((args.batch_size, args.in_channels, *args.input_size))
+        _inputs = _inputs.to(device)
+        _outputs = net(_inputs)
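add_arguments and from_args form a round trip: the prefix lets several networks register their options on one parser without name clashes, and from_args reads them back under the same prefixed names. A sketch of that round trip; the generator/critic prefixes are made up for illustration:

import argparse

parser = argparse.ArgumentParser()
UNet.add_arguments(parser, prefix="generator")  # adds --generator_in_channels, ...
UNet.add_arguments(parser, prefix="critic")     # adds --critic_in_channels, ...
args = parser.parse_args(["--generator_features", "32", "64", "128"])

generator = UNet.from_args(args, prefix="generator")
critic = UNet.from_args(args, prefix="critic")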
@@ -4,107 +4,39 @@ from typing import Dict
 
 import torch
 
-from mu_map.logging import get_logger
+from mu_map.training.lib import TrainingParams, AbstractTraining
 from mu_map.training.loss import WeightedLoss
 
 
-class Training:
+class Training(AbstractTraining):
     def __init__(
         self,
-        model: torch.nn.Module,
-        data_loaders: Dict[str, torch.utils.data.DataLoader],
         epochs: int,
+        dataset: MuMapDataset,
+        batch_size: int,
         device: torch.device,
-        loss_func: WeightedLoss,
-        lr: float,
-        lr_decay_factor: float,
-        lr_decay_epoch: int,
         snapshot_dir: str,
         snapshot_epoch: int,
-        logger=None,
+        params: TrainingParams,
+        loss_func: WeightedLoss,
     ):
-        self.model = model
-        self.data_loaders = data_loaders
-        self.epochs = epochs
-        self.device = device
-        self.lr = lr
-        self.lr_decay_factor = lr_decay_factor
-        self.lr_decay_epoch = lr_decay_epoch
-        self.snapshot_dir = snapshot_dir
-        self.snapshot_epoch = snapshot_epoch
-        self.logger = logger if logger is not None else get_logger()
-
-        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
-        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
-            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
-        )
+        super().__init__(
+            epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch
+        )
+        self.training_params.append(params)
+        self.model = params.model
         self.loss_func = loss_func
 
-    def run(self):
-        for epoch in range(1, self.epochs + 1):
-            logger.debug(
-                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
-            )
-
-            self._run_epoch(self.data_loaders["train"], phase="train")
-
-            loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
-            logger.info(
-                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss train: {loss_training:.6f}"
-            )
-            loss_validation = self._run_epoch(
-                self.data_loaders["validation"], phase="val"
-            )
-            logger.info(
-                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss validation: {loss_validation:.6f}"
-            )
-
-            _previous = self.lr_scheduler.get_last_lr()[0]
-            self.lr_scheduler.step()
-            logger.debug(
-                f"Update learning rate from {_previous:.6f} to {self.lr_scheduler.get_last_lr()[0]:.6f}"
-            )
-
-            if epoch % self.snapshot_epoch == 0:
-                self.store_snapshot(epoch)
-
-            logger.debug(
-                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
-            )
-
-    def _run_epoch(self, data_loader, phase):
-        logger.debug(f"Run epoch in phase {phase}")
-        self.model.train() if phase == "train" else self.model.eval()
-
-        epoch_loss = 0
-        loss_updates = 0
-        for i, (inputs, labels) in enumerate(data_loader):
-            print(
-                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
-                end="\r",
-            )
-            inputs = inputs.to(self.device)
-            labels = labels.to(self.device)
-
-            self.optimizer.zero_grad()
-            with torch.set_grad_enabled(phase == "train"):
-                outputs = self.model(inputs)
-                loss = self.loss_func(outputs, labels)
-
-                if phase == "train":
-                    loss.backward()
-                    self.optimizer.step()
-
-            epoch_loss += loss.item()
-            loss_updates += 1
-        return epoch_loss / loss_updates
-
-    def store_snapshot(self, epoch):
-        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
-        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
-        logger.debug(f"Store snapshot at {snapshot_file}")
-        torch.save(self.model.state_dict(), snapshot_file)
+    def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
+        outputs = self.model(recons)
+        loss = self.loss_func(outputs, mu_maps)
+        loss.backward()
+        return loss.item()
+
+    def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
+        outputs = self.model(recons)
+        loss = torch.nn.functional.l1_loss(outputs, mu_maps)
+        return loss.item()
 
 
 if __name__ == "__main__":
......
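The rewrite reduces Training to the two hooks that AbstractTraining appears to require: _train_batch computes the loss and back-propagates (optimizer steps are presumably applied by the base class via the registered TrainingParams), while _eval_batch only scores a batch. A hypothetical minimal subclass illustrating that contract:

import torch
from mu_map.training.lib import AbstractTraining

class L1Training(AbstractTraining):
    # Hypothetical subclass; the constructor signature is assumed from the
    # super().__init__ call in the diff above.
    def __init__(self, model: torch.nn.Module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model

    def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
        # compute gradients only; stepping the optimizer is left to the shared loop
        loss = torch.nn.functional.l1_loss(self.model(recons), mu_maps)
        loss.backward()
        return loss.item()

    def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
        # evaluation never calls backward
        return torch.nn.functional.l1_loss(self.model(recons), mu_maps).item()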