diff --git a/mu_map/data/preprocessing.py b/mu_map/data/preprocessing.py deleted file mode 100644 index dc57c6f1de0980238302c460af2d8d9278f8826b..0000000000000000000000000000000000000000 --- a/mu_map/data/preprocessing.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch - - -def norm_max(tensor: torch.Tensor): - return (tensor - tensor.min()) / (tensor.max() - tensor.min()) - - -class MaxNorm: - def __call__(self, tensor: torch.Tensor): - return norm_max(tensor) - - -def norm_mean(tensor: torch.Tensor): - return tensor / tensor.mean() - - -class MeanNorm: - def __call__(self, tensor: torch.Tensor): - return norm_mean(tensor) - - -def norm_gaussian(tensor: torch.Tensor): - return (tensor - tensor.mean()) / tensor.std() - - -class GaussianNorm: - def __call__(self, tensor: torch.Tensor): - return norm_gaussian(tensor) - - -__all__ = [ - norm_max.__name__, - norm_mean.__name__, - norm_gaussian.__name__, - MaxNorm.__name__, - MeanNorm.__name__, - GaussianNorm.__name__, -] diff --git a/mu_map/data/split.py b/mu_map/data/split.py index b21dda72dc67fa462f3409aedf8108a3a205ea44..a685e07a15497ef492180fa2e08f36cb69a210c9 100644 --- a/mu_map/data/split.py +++ b/mu_map/data/split.py @@ -1,6 +1,6 @@ import pandas as pd -from typing import List +from typing import Dict, List def parse_split_str(_str: str, delimitier: str = "/") -> List[float]: @@ -23,7 +23,7 @@ def parse_split_str(_str: str, delimitier: str = "/") -> List[float]: return split -def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]: +def split_csv(data: pd.DataFrame, split_csv: str) -> Dict[str, pd.DataFrame]: """ Split a data frames based on a file defining a split. @@ -40,7 +40,9 @@ def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]: lambda patient_ids: data[data["patient_id"].isin(patient_ids)], split_patient_ids, ) - return list(splits) + splits = zip(split_names, splits) + splits = dict(splits) + return splits if __name__ == "__main__": diff --git a/mu_map/dataset/__init__.py b/mu_map/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mu_map/data/datasets.py b/mu_map/dataset/default.py similarity index 76% rename from mu_map/data/datasets.py rename to mu_map/dataset/default.py index e92ad5acaa1a92f5877fdc36867256d8f9615b24..8291615b2802dc2c23685b18a8bdefac2a170b27 100644 --- a/mu_map/data/datasets.py +++ b/mu_map/dataset/default.py @@ -5,11 +5,15 @@ import cv2 as cv import pandas as pd import pydicom import numpy as np +import torch from torch.utils.data import Dataset from mu_map.data.prepare import headers from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours from mu_map.data.review_mu_map import discard_slices +from mu_map.data.split import split_csv +from mu_map.dataset.transform import Transform +from mu_map.logging import get_logger def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray: @@ -43,16 +47,26 @@ class MuMapDataset(Dataset): self, dataset_dir: str, csv_file: str = "meta.csv", + split_file: str = "split.csv", + split_name: str = None, images_dir: str = "images", bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME, discard_mu_map_slices: bool = True, align: bool = True, + transform_normalization: Transform = Transform(), + transform_augmentation: Transform = Transform(), + logger=None, ): super().__init__() self.dir = dataset_dir self.dir_images = os.path.join(dataset_dir, images_dir) self.csv_file = os.path.join(dataset_dir, csv_file) + 
self.split_file = os.path.join(dataset_dir, split_file) + + self.transform_normalization = transform_normalization + self.transform_augmentation = transform_augmentation + self.logger = logger if logger is not None else get_logger() self.bed_contours_file = ( os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None @@ -63,6 +77,8 @@ class MuMapDataset(Dataset): # read CSV file and from that access DICOM files self.table = pd.read_csv(self.csv_file) + if split_name: + self.table = split_csv(self.table, self.split_file)[split_name] self.table["id"] = self.table["id"].apply(int) self.discard_mu_map_slices = discard_mu_map_slices @@ -73,7 +89,7 @@ class MuMapDataset(Dataset): self.pre_load_images() def pre_load_images(self): - print("Pre-loading images ...", end="\r") + self.logger.debug("Pre-loading images ...") for i in range(len(self.table)): row = self.table.iloc[i] _id = row["id"] @@ -86,14 +102,25 @@ class MuMapDataset(Dataset): bed_contour = self.bed_contours[row["id"]] for i in range(mu_map.shape[0]): mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1) - self.mu_maps[_id] = mu_map recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc]) recon = pydicom.dcmread(recon_file).pixel_array if self.align: recon = align_images(recon, mu_map) + + mu_map = mu_map.astype(np.float32) + mu_map = torch.from_numpy(mu_map) + mu_map = mu_map.unsqueeze(dim=0) + + recon = recon.astype(np.float32) + recon = torch.from_numpy(recon) + recon = recon.unsqueeze(dim=0) + + recon, mu_map = self.transform_normalization(recon, mu_map) + + self.mu_maps[_id] = mu_map self.reconstructions[_id] = recon - print("Pre-loading images done!") + self.logger.debug("Pre-loading images done!") def __getitem__(self, index: int): row = self.table.iloc[index] @@ -102,22 +129,7 @@ class MuMapDataset(Dataset): recon = self.reconstructions[_id] mu_map = self.mu_maps[_id] - # recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc]) - # mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map]) - - # recon = pydicom.dcmread(recon_file).pixel_array - # mu_map = pydicom.dcmread(mu_map_file).pixel_array - - # if self.discard_mu_map_slices: - # mu_map = discard_slices(row, mu_map) - - # if self.bed_contours: - # bed_contour = self.bed_contours[row["id"]] - # for i in range(mu_map.shape[0]): - # mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1) - - # if self.align: - # recon = align_images(recon, mu_map) + recon, mu_map = self.transform_augmentation(recon, mu_map) return recon, mu_map @@ -127,46 +139,9 @@ class MuMapDataset(Dataset): __all__ = [MuMapDataset.__name__] -if __name__ == "__main__": - import argparse - +def main(dataset): from mu_map.util import to_grayscale, COLOR_WHITE - parser = argparse.ArgumentParser( - description="Visualize the images of a MuMapDataset", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "dataset_dir", type=str, help="the directory from which the dataset is loaded" - ) - parser.add_argument( - "--unaligned", - action="store_true", - help="do not perform center alignment of reconstruction an mu-map slices", - ) - parser.add_argument( - "--show_bed", - action="store_true", - help="do not remove the bed contour from the mu map", - ) - parser.add_argument( - "--full_mu_map", - action="store_true", - help="do not remove broken slices of the mu map", - ) - args = parser.parse_args() - - align = not args.unaligned - discard_mu_map_slices = not args.full_mu_map - 
bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME - - dataset = MuMapDataset( - args.dataset_dir, - align=align, - discard_mu_map_slices=discard_mu_map_slices, - bed_contours_file=bed_contours_file, - ) - wname = "Dataset" cv.namedWindow(wname, cv.WINDOW_NORMAL) cv.resizeWindow(wname, 1600, 900) @@ -194,17 +169,20 @@ if __name__ == "__main__": im = 0 recon, mu_map = dataset[i] + recon = recon.squeeze().numpy() + mu_map = mu_map.squeeze().numpy() print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r") cv.imshow(wname, combine_images((recon, mu_map), (ir, im))) key = cv.waitKey(timeout) + running = 0 while True: ir = (ir + 1) % recon.shape[0] im = (im + 1) % mu_map.shape[0] - cv.imshow(wname, combine_images((recon, mu_map), (ir, im))) - + to_show = combine_images((recon, mu_map), (ir, im)) + cv.imshow(wname, to_show) key = cv.waitKey(timeout) if key == ord("n"): @@ -218,3 +196,59 @@ if __name__ == "__main__": elif key == 81: # left arrow key ir = max(ir - 2, 0) im = max(im - 2, 0) + elif key == ord("s"): + cv.imwrite(f"{running:03d}.png", to_show) + running += 1 + + +if __name__ == "__main__": + import argparse + + from mu_map.logging import add_logging_args, get_logger_by_args + + parser = argparse.ArgumentParser( + description="Visualize the images of a MuMapDataset", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "dataset_dir", type=str, help="the directory from which the dataset is loaded" + ) + parser.add_argument( + "--split", + type=str, + choices=["train", "validation", "test"], + help="choose the split of the data for the dataset", + ) + parser.add_argument( + "--unaligned", + action="store_true", + help="do not perform center alignment of reconstruction an mu-map slices", + ) + parser.add_argument( + "--show_bed", + action="store_true", + help="do not remove the bed contour from the mu map", + ) + parser.add_argument( + "--full_mu_map", + action="store_true", + help="do not remove broken slices of the mu map", + ) + add_logging_args(parser, defaults={"--loglevel": "DEBUG"}) + args = parser.parse_args() + + align = not args.unaligned + discard_mu_map_slices = not args.full_mu_map + bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME + logger = get_logger_by_args(args) + + dataset = MuMapDataset( + args.dataset_dir, + align=align, + discard_mu_map_slices=discard_mu_map_slices, + bed_contours_file=bed_contours_file, + split_name=args.split, + logger=logger, + ) + main(dataset) + diff --git a/mu_map/dataset/mock.py b/mu_map/dataset/mock.py new file mode 100644 index 0000000000000000000000000000000000000000..8cfd95122c21b805f574abc44a1155fd7a49ffcc --- /dev/null +++ b/mu_map/dataset/mock.py @@ -0,0 +1,27 @@ +from mu_map.dataset.default import MuMapDataset +from mu_map.dataset.normalization import MaxNormTransform + + +class MuMapMockDataset(MuMapDataset): + def __init__(self, dataset_dir: str = "data/initial/", num_images: int = 16, logger=None): + super().__init__(dataset_dir=dataset_dir, transform_normalization=MaxNormTransform(), logger=logger) + self.len = num_images + + def __getitem__(self, index: int): + recon, mu_map = super().__getitem__(0) + recon = recon[:, :32, :, :] + mu_map = mu_map[:, :32, :, :] + return recon, mu_map / 40206.0 + + def __len__(self): + return self.len + + +if __name__ == "__main__": + import cv2 as cv + import numpy as np + + from mu_map.dataset.default import main + + dataset = MuMapMockDataset() + main(dataset) diff --git 
a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfbfc6cbdf31c203ae30e595be956b9c41741c9 --- /dev/null +++ b/mu_map/dataset/normalization.py @@ -0,0 +1,53 @@ +from typing import Tuple + +from torch import Tensor + +from mu_map.dataset.transform import Transform + + +def norm_max(tensor: Tensor) -> Tensor: + return (tensor - tensor.min()) / (tensor.max() - tensor.min()) + + +class MaxNormTransform(Transform): + def __init__(self, max_vals: Tuple[float, float] = None): + self.max_vals = max_vals + + def __call__( + self, inputs: Tensor, outputs_expected: Tensor + ) -> Tuple[Tensor, Tensor]: + if self.max_vals: + return inputs / self.max_vals[0], outputs_expected / self.max_vals[1] + return norm_max(inputs), outputs_expected + + +def norm_mean(tensor: Tensor): + return tensor / tensor.mean() + + +class MeanNormTransform(Transform): + def __call__( + self, inputs: Tensor, outputs_expected: Tensor + ) -> Tuple[Tensor, Tensor]: + return norm_mean(inputs), outputs_expected + + +def norm_gaussian(tensor: Tensor): + return (tensor - tensor.mean()) / tensor.std() + + +class GaussianNormTransform(Transform): + def __call__( + self, inputs: Tensor, outputs_expected: Tensor + ) -> Tuple[Tensor, Tensor]: + return norm_gaussian(inputs), outputs_expected + + +__all__ = [ + norm_max.__name__, + norm_mean.__name__, + norm_gaussian.__name__, + MaxNormTransform.__name__, + MeanNormTransform.__name__, + GaussianNormTransform.__name__, +] diff --git a/mu_map/data/patch_dataset.py b/mu_map/dataset/patches.py similarity index 64% rename from mu_map/data/patch_dataset.py rename to mu_map/dataset/patches.py index cf6cb786551095c314e88f45ce40818488847f83..df9f2c971b991c2278a7e4af9a6376c601af6d71 100644 --- a/mu_map/data/patch_dataset.py +++ b/mu_map/dataset/patches.py @@ -1,7 +1,10 @@ import math import random -from mu_map.data.datasets import MuMapDataset +import numpy as np +import torch + +from mu_map.dataset.default import MuMapDataset class MuMapPatchDataset(MuMapDataset): @@ -16,44 +19,44 @@ class MuMapPatchDataset(MuMapDataset): self.generate_patches() def generate_patches(self): - for i, (recon, mu_map) in enumerate(zip(self.reconstructions, self.mu_maps)): + for _id in self.reconstructions: + recon = self.reconstructions[_id].squeeze() + mu_map = self.mu_maps[_id].squeeze() + assert ( recon.shape[0] == mu_map.shape[0] ), f"Reconstruction and MuMap were not aligned for patch dataset" - _id = self.table.iloc[i]["id"] - z_range = (0, max(recon.shape[0] - self.patch_size, 0)) # sometimes the mu_maps have fewer than 32 slices # in this case the z-axis will be padded to the patch size, but this means we only have a single option for z - y_range = (0, recon.shape[1] - self.patch_size) - x_range = (0, recon.shape[2] - self.patch_size) + y_range = (20, recon.shape[1] - self.patch_size - 20) + x_range = (20, recon.shape[2] - self.patch_size - 20) - padding = [(0, 0), (0, 0), (0, 0)] + padding = [0, 0, 0, 0, 0, 0, 0, 0] if recon.shape[0] < self.patch_size: diff = self.patch_size - recon.shape[0] - padding_bef = math.ceil(diff / 2) - padding_aft = math.floor(diff / 2) - padding[0] = (padding_bef, padding_aft) + padding[4] = math.ceil(diff / 2) + padding[5] = math.floor(diff / 2) for j in range(self.patches_per_image): z = random.randint(*z_range) y = random.randint(*y_range) x = random.randint(*x_range) - self.patches.append(_id, z, y, x) + self.patches.append((_id, z, y, x, padding)) - def __getitem___(self, index: int): + def 
__getitem__(self, index: int): _id, z, y, x, padding = self.patches[index] - s = self.patches + s = self.patch_size recon = self.reconstructions[_id] mu_map = self.mu_maps[_id] - recon = np.pad(recon, padding, mode="constant", constant_values=0) - mu_map = np.pad(mu_map, padding, mode="constant", constant_values=0) + recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0) + mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0) - recon = recon[z : z + s, y : y + s, x : x + s] - mu_map = mu_map[z : z + s, y : y + s, x : x + s] + recon = recon[:, z : z + s, y : y + s, x : x + s] + mu_map = mu_map[:, z : z + s, y : y + s, x : x + s] return recon, mu_map @@ -70,19 +73,18 @@ if __name__ == "__main__": cv.namedWindow(wname, cv.WINDOW_NORMAL) cv.resizeWindow(wname, 1600, 900) - dataset = MuMapPatchDataset("data/initial/", patches_per_image=5) + dataset = MuMapPatchDataset("data/initial/", patches_per_image=1) print(f"Images (Patches) in the dataset {len(dataset)}") def create_image(recon, mu_map, recon_orig, patch, _slice): - s = recon.shape[0] + s = dataset.patch_size _id, _, y, x, padding = patch - _recon_orig = np.pad(recon_orig, patch, mode="constant", constant_values=0) _recon_orig = recon_orig[_slice] _recon_orig = to_grayscale(_recon_orig) _recon_orig = grayscale_to_rgb(_recon_orig) - _recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), thickness=1) + _recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1) _recon_orig = cv.resize(_recon_orig, (512, 512), cv.INTER_AREA) _recon = recon[_slice] @@ -95,7 +97,7 @@ if __name__ == "__main__": _mu_map = cv.resize(_mu_map, (512, 512), cv.INTER_AREA) _mu_map = grayscale_to_rgb(_mu_map) - space = np.full((3, 512, 10), 239, np.uint8) + space = np.full((512, 10, 3), 239, np.uint8) return np.hstack((_recon, space, _mu_map, space, _recon_orig)) for i in range(len(dataset)): @@ -104,18 +106,24 @@ if __name__ == "__main__": patch = dataset.patches[i] _id, z, y, x, padding = patch print( - "Patch {str(i+1):>len(str(len(dataset)))}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[0][0], padding[0][0]}]" + f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[5], padding[6]}]" ) recon, mu_map = dataset[i] + recon = recon.squeeze().numpy() + mu_map = mu_map.squeeze().numpy() + recon_orig = dataset.reconstructions[_id] + recon_orig = torch.nn.functional.pad(recon_orig, padding, mode="constant", value=0) + recon_orig = recon_orig.squeeze().numpy() - cv.imshow(combine_images(recon, mu_map, recon_orig, patch, _i)) + cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i)) key = cv.waitKey(100) while True: _i = (_i + 1) % recon.shape[0] - cv.imshow(combine_images(recon, mu_map, recon_orig, patch, _i)) + cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i)) + key = cv.waitKey(100) if key == ord("n"): break diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2685e0cebd568b611c4ac4bf22f69304906b33 --- /dev/null +++ b/mu_map/dataset/transform.py @@ -0,0 +1,36 @@ +from typing import List, Tuple + + +from torch import Tensor + + +class Transform: + """ + Interface of a transformer. A transformer can be initialized and then applied to + an input tensor and expected output tensor as returned by a dataset. It can be + used for normalization and data augmentation. 
+ """ + + def __call__( + self, inputs: Tensor, outputs_expected: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Apply the transformer to a pair of inputs and expected outputs in a dataset. + """ + return inputs, outputs_expected + + +class SequenceTransform(Transform): + """ + A transformer that applies a sequence of transformers sequentially. + """ + + def __init__(self, transforms: List[Transform]): + self.transforms = transforms + + def __call__( + self, inputs: Tensor, outputs_expected: Tensor + ) -> Tuple[Tensor, Tensor]: + for transforms in self.transforms: + inputs, outputs_expected = transforms(inputs, outputs_expected) + return inputs, outputs_expected diff --git a/mu_map/logging.py b/mu_map/logging.py index 29bf6215e745399b151392f0a45ba8b44c2467c0..fb43420249f05405719411dd880b2e8d45ab157c 100644 --- a/mu_map/logging.py +++ b/mu_map/logging.py @@ -1,15 +1,16 @@ import argparse -import datetime +from dataclasses import dataclass +from datetime import datetime import logging from logging import Formatter, getLogger, StreamHandler from logging.handlers import WatchedFileHandler import os import shutil -from typing import Dict, Optional - +from typing import Dict, Optional, List +date_format="%m/%d/%Y %I:%M:%S" FORMATTER = Formatter( - fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S" + fmt="%(asctime)s - %(levelname)7s - %(message)s", datefmt=date_format ) @@ -35,7 +36,7 @@ def add_logging_args(parser: argparse.ArgumentParser, defaults: Dict[str, str]): def timestamp_filename(filename: str): - timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + timestamp = datetime.now().strftime("%Y-%m-%d-%H:%M:%S") basename, ext = os.path.splitext(filename) return f"{basename}_{timestamp}{ext}" @@ -71,6 +72,34 @@ def get_logger_by_args(args): return get_logger(args.logfile, args.loglevel) +@dataclass +class LogLine: + time: datetime + loglevel: str + message: str + + def __repr__(self): + return f"{self.time.strftime(date_format)} - {self.loglevel:>7} - {self.message}" + + +def parse_line(logline): + _split = logline.strip().split("-") + assert len(_split) >= 3, f"A logged line should consists of a least three elements with the format [TIME - LOGLEVEL - MESSAGE] but got [{logline.strip()}]" + + time_str = _split[0].strip() + time = datetime.strptime(time_str, date_format) + + loglevel = _split[1].strip() + + message = "-".join(_split[2:]).strip() + return LogLine(time=time, loglevel=loglevel, message=message) + +def parse_file(logfile: str) -> List[LogLine]: + with open(logfile, mode="r") as f: + lines = f.readlines() + lines = map(parse_line, lines) + return list(lines) + if __name__ == "__main__": parser = argparse.ArgumentParser() add_logging_args(parser, defaults={"--loglevel": "DEBUG", "--logfile": "tmp.log"}) diff --git a/mu_map/models/unet.py b/mu_map/models/unet.py index 9a4d5c875b4ecfa2678d358a3504db084dcac1a6..2d0f6a536784f242e3724860cfe29023af6e5fa0 100644 --- a/mu_map/models/unet.py +++ b/mu_map/models/unet.py @@ -1,5 +1,6 @@ from typing import Optional, List +import torch import torch.nn as nn diff --git a/mu_map/test.py b/mu_map/test.py index 1da35dfb4f733be69e8ff23f51d1fbc744d4b9bc..33c04b243422651967c4bd96bdff2b8ee945acaf 100644 --- a/mu_map/test.py +++ b/mu_map/test.py @@ -1,26 +1,143 @@ +import cv2 as cv +import matplotlib.pyplot as plt +import numpy as np import torch -from .data.preprocessing import * +from mu_map.dataset.default import MuMapDataset +from mu_map.dataset.mock import MuMapMockDataset +from mu_map.dataset.normalization 
import norm_max +from mu_map.models.unet import UNet +from mu_map.util import to_grayscale, COLOR_WHITE -means = torch.full((10, 10, 10), 5.0) -stds = torch.full((10, 10, 10), 10.0) -x = torch.normal(means, stds) +torch.set_grad_enabled(False) -print(f"Before: mean={x.mean():.3f} std={x.std():.3f}") +dataset = MuMapMockDataset("data/initial/") -y = norm_gaussian(x) -print(f" After: mean={y.mean():.3f} std={y.std():.3f}") -y = GaussianNorm()(x) -print(f" After: mean={y.mean():.3f} std={y.std():.3f}") +model = UNet(in_channels=1, features=[8, 16]) +device = torch.device("cpu") +weights = torch.load("tmp/10.pth", map_location=device) +model.load_state_dict(weights) +model = model.eval() +recon, mu_map = dataset[0] +recon = recon.unsqueeze(dim=0) +recon = norm_max(recon) -import cv2 as cv -import numpy as np +output = model(recon) +output = output * 40206.0 + +diff = ((mu_map - output) ** 2).mean() +print(f"Diff: {diff:.3f}") + +output = output.squeeze().numpy() +mu_map = mu_map.squeeze().numpy() + +assert output.shape[0] == mu_map.shape[0] + +wname = "Dataset" +cv.namedWindow(wname, cv.WINDOW_NORMAL) +cv.resizeWindow(wname, 1600, 900) +space = np.full((1024, 10), 239, np.uint8) + +def to_display_image(image, _slice): + _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max()) + _image = cv.resize(_image, (1024, 1024), cv.INTER_AREA) + _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}" + _image = cv.putText( + _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3 + ) + return _image + +def com(image1, image2, _slice): + image1 = to_display_image(image1, _slice) + image2 = to_display_image(image2, _slice) + space = np.full((image1.shape[0], 10), 239, np.uint8) + return np.hstack((image1, space, image2)) + + +i = 0 +while True: + x = com(output, mu_map, i) + cv.imshow(wname, x) + key = cv.waitKey(100) + + if key == ord("q"): + break + + i = (i + 1) % output.shape[0] + + + + +# dataset = MuMapDataset("data/initial") + +# # print(" Recon || MuMap") +# # print(" Min | Max | Average || Min | Max | Average") +# r_max = [] +# r_avg = [] + +# r_max_p = [] +# r_avg_p = [] + +# r_avg_x = [] + +# m_max = [] +# for recon, mu_map in dataset: + # r_max.append(recon.max()) + # r_avg.append(recon.mean()) + + # recon_p = recon[:, :, 16:112, 16:112] + # r_max_p.append(recon_p.max()) + # r_avg_p.append(recon_p.mean()) + + # r_avg_x.append(recon.sum() / (recon > 0.0).sum()) + # # r_min = f"{recon.min():5.3f}" + # # r_max = f"{recon.max():5.3f}" + # # r_avg = f"{recon.mean():5.3f}" + # # m_min = f"{mu_map.min():5.3f}" + # # m_max = f"{mu_map.max():5.3f}" + # # m_avg = f"{mu_map.mean():5.3f}" + # # print(f"{r_min} | {r_max} | {r_avg} || {m_min} | {m_max} | {m_avg}") + # m_max.append(mu_map.max()) + # # print(mu_map.max()) +# r_max = np.array(r_max) +# r_avg = np.array(r_avg) + +# r_max_p = np.array(r_max_p) +# r_avg_p = np.array(r_avg_p) + +# r_avg_x = np.array(r_avg_x) + +# m_max = np.array(m_max) +# fig, ax = plt.subplots() +# ax.scatter(r_max, m_max) + +# # fig, axs = plt.subplots(4, 3, figsize=(16, 12)) +# # axs[0, 0].hist(r_max) +# # axs[0, 0].set_title("Max") +# # axs[1, 0].hist(r_avg) +# # axs[1, 0].set_title("Mean") +# # axs[2, 0].hist(r_max / r_avg) +# # axs[2, 0].set_title("Max / Mean") +# # axs[3, 0].hist(recon.flatten()) +# # axs[3, 0].set_title("Example Reconstruction") + +# # axs[0, 1].hist(r_max_p) +# # axs[0, 1].set_title("Max") +# # axs[1, 1].hist(r_avg_p) +# # axs[1, 1].set_title("Mean") +# # axs[2, 1].hist(r_max_p / r_avg_p) +# # axs[2, 
1].set_title("Max / Mean") +# # axs[3, 1].hist(recon_p.flatten()) +# # axs[3, 1].set_title("Example Reconstruction") -x = np.zeros((512, 512), np.uint8) -cv.imshow("X", x) -key = cv.waitKey(0) -while key != ord("q"): - print(key) - key = cv.waitKey(0) +# # axs[0, 2].hist(r_max_p) +# # axs[0, 2].set_title("Max") +# # axs[1, 2].hist(r_avg_x) +# # axs[1, 2].set_title("Mean") +# # axs[2, 2].hist(r_max_p / r_avg_x) +# # axs[2, 2].set_title("Max / Mean") +# # axs[3, 2].hist(torch.masked_select(recon, (recon > 0.0))) +# # axs[3, 2].set_title("Example Reconstruction") +# plt.show() diff --git a/mu_map/training/default.py b/mu_map/training/default.py index 15330982e772dcb3298f0a99a6b8e77adebc0a97..adf89cefbd3a4a518893710f05f26769a0a81c92 100644 --- a/mu_map/training/default.py +++ b/mu_map/training/default.py @@ -1,44 +1,220 @@ +import os +from typing import Dict +import torch -class Training(): +from mu_map.logging import get_logger - def __init__(self, epochs): + +class Training: + def __init__( + self, + model: torch.nn.Module, + data_loaders: Dict[str, torch.utils.data.DataLoader], + epochs: int, + device: torch.device, + lr: float, + lr_decay_factor: float, + lr_decay_epoch: int, + snapshot_dir: str, + snapshot_epoch: int, + logger=None, + ): + self.model = model + self.data_loaders = data_loaders self.epochs = epochs + self.device = device + + self.lr = lr + self.lr_decay_factor = lr_decay_factor + self.lr_decay_epoch = lr_decay_epoch + + self.snapshot_dir = snapshot_dir + self.snapshot_epoch = snapshot_epoch + + self.logger = logger if logger is not None else get_logger() + + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) + self.lr_scheduler = torch.optim.lr_scheduler.StepLR( + self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor + ) + self.loss_func = torch.nn.MSELoss(reduction="mean") + def run(self): for epoch in range(1, self.epochs + 1): - self.run_epoch(self.data_loader["train"], phase="train") - loss_training = self.run_epoch(self.data_loader["train"], phase="eval") - loss_validation = self.run_epoch(self.data_loader["validation"], phase="eval") + logger.debug( + f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..." 
+ ) + self._run_epoch(self.data_loaders["train"], phase="train") + + loss_training = self._run_epoch(self.data_loaders["train"], phase="val") + logger.info( + f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}" + ) + loss_validation = self._run_epoch( + self.data_loaders["validation"], phase="val" + ) + logger.info( + f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}" + ) # ToDo: log outputs and time + _previous = self.lr_scheduler.get_last_lr()[0] self.lr_scheduler.step() + logger.debug( + f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}" + ) - if epoch % self.snapshot_epoch: + if epoch % self.snapshot_epoch == 0: self.store_snapshot(epoch) - + logger.debug( + f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}" + ) - def run_epoch(self, data_loader, phase): + def _run_epoch(self, data_loader, phase): + logger.debug(f"Run epoch in phase {phase}") self.model.train() if phase == "train" else self.model.eval() epoch_loss = 0 - for inputs, labels in self.data_loader: + loss_updates = 0 + for i, (inputs, labels) in enumerate(data_loader): + print( + f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", + end="\r", + ) inputs = inputs.to(self.device) labels = labels.to(self.device) self.optimizer.zero_grad() with torch.set_grad_enabled(phase == "train"): outputs = self.model(inputs) - loss = self.loss(outputs, labels) + loss = self.loss_func(outputs, labels) if phase == "train": loss.backward() - optimizer.step() - - epoch_loss += loss.item() / inputs.size[0] - return epoch_loss + self.optimizer.step() + epoch_loss += loss.item() + loss_updates += 1 + return epoch_loss / loss_updates def store_snapshot(self, epoch): - pass + snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth" + snapshot_file = os.path.join(self.snapshot_dir, snapshot_file) + logger.debug(f"Store snapshot at {snapshot_file}") + torch.save(self.model.state_dict(), snapshot_file) + + +if __name__ == "__main__": + import argparse + + from mu_map.dataset.mock import MuMapMockDataset + from mu_map.logging import add_logging_args, get_logger_by_args + from mu_map.models.unet import UNet + + parser = argparse.ArgumentParser( + description="Train a UNet model to predict μ-maps from reconstructed scatter images", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Model Args + parser.add_argument( + "--features", + type=int, + nargs="+", + default=[8, 16], + help="number of features in the layers of the UNet structure", + ) + + # Dataset Args + # parser.add_argument("--features", type=int, nargs="+", default=[8, 16], help="number of features in the layers of the UNet structure") + + # Training Args + parser.add_argument( + "--output_dir", + type=str, + default="train_data", + help="directory in which results (snapshots and logs) of this training are saved", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + help="the number of epochs for which the model is trained", + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0" if torch.cuda.is_available() else "cpu", + help="the device (cpu or gpu) with which the training is performed", + ) + parser.add_argument( + "--lr", type=float, default=0.1, help="the initial learning rate for training" + ) + parser.add_argument( + "--lr_decay_factor", + type=float, + default=0.99, + help="decay factor for the learning rate", + ) + parser.add_argument( + "--lr_decay_epoch", + type=int, + 
default=1, + help="frequency in epochs at which the learning rate is decayed", + ) + parser.add_argument( + "--snapshot_dir", + type=str, + default="snapshots", + help="directory under --output_dir where snapshots are stored", + ) + parser.add_argument( + "--snapshot_epoch", + type=int, + default=10, + help="frequency in epochs at which snapshots are stored", + ) + + # Logging Args + add_logging_args(parser, defaults={"--logfile": "train.log"}) + + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.mkdir(args.output_dir) + + args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir) + if not os.path.exists(args.snapshot_dir): + os.mkdir(args.snapshot_dir) + + args.logfile = os.path.join(args.output_dir, args.logfile) + + device = torch.device(args.device) + logger = get_logger_by_args(args) + + model = UNet(in_channels=1, features=args.features) + dataset = MuMapMockDataset(logger=logger) + data_loader_train = torch.utils.data.DataLoader( + dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1 + ) + data_loader_val = torch.utils.data.DataLoader( + dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1 + ) + data_loaders = {"train": data_loader_train, "validation": data_loader_val} + + training = Training( + model=model, + data_loaders=data_loaders, + epochs=args.epochs, + device=device, + lr=args.lr, + lr_decay_factor=args.lr_decay_factor, + lr_decay_epoch=args.lr_decay_epoch, + snapshot_dir=args.snapshot_dir, + snapshot_epoch=args.snapshot_epoch, + logger=logger, + ) + training.run() diff --git a/mu_map/util.py b/mu_map/util.py index 259a7bbb970f2b2aa8132355f8eb367664d8adb4..914793b49953f190912a5456a0e9207945711c6b 100644 --- a/mu_map/util.py +++ b/mu_map/util.py @@ -23,6 +23,9 @@ def to_grayscale( if max_val is None: max_val = img.max() + if (max_val - min_val) == 0: + return np.zeros(img.shape, np.uint8) + _img = (img - min_val) / (max_val - min_val) _img = (_img * 255).astype(np.uint8) return _img @@ -36,4 +39,4 @@ def grayscale_to_rgb(img: np.ndarray): :return: the image in rgb """ assert img.ndim == 2, f"grascale image has more than 2 dimensions {img.shape}" - return img.repeat(3).reshape((3, *img.shape)) + return img.repeat(3).reshape((*img.shape, 3)) diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..336b1ca949066b82662ff2628f64c883b2d973b8 --- /dev/null +++ b/mu_map/vis/loss_curve.py @@ -0,0 +1,66 @@ +import argparse + +import matplotlib.pyplot as plt +import numpy as np + +from mu_map.logging import parse_file + +SIZE_DEFAULT = 12 +plt.rc("font", family="Roboto") # controls default font +plt.rc("font", weight="normal") # controls default font +plt.rc("font", size=SIZE_DEFAULT) # controls default text sizes +plt.rc("axes", titlesize=18) # fontsize of the axes title + +# https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3lk +COLORS = ["#ef8a62", "#67a9cf"] + +parser = argparse.ArgumentParser(description="TODO", formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument("logfile", type=str, help="TODO") +parser.add_argument("--normalize", action="store_true", help="TODO") +args = parser.parse_args() + +logs = parse_file(args.logfile) +logs = list(filter(lambda logline: logline.loglevel == "INFO", logs)) + +def parse_loss(logs, phase): + _logs = map(lambda logline: logline.message, logs) + _logs = filter(lambda log: phase in log, _logs) + _logs = filter(lambda log: "Loss" in log, _logs) + _logs = 
list(_logs) + + losses = map(lambda log: log.split("-")[-1].strip(), _logs) + losses = map(lambda log: log.split(":")[-1].strip(), losses) + losses = map(float, losses) + + epochs = map(lambda log: log.split("-")[0].strip(), _logs) + epochs = list(epochs) + epochs = map(lambda log: log.split(" ")[-1], epochs) + epochs = map(lambda log: log.split("/")[0], epochs) + epochs = map(int, epochs) + + return np.array(list(epochs)), np.array(list(losses)) + +phases = ["TRAIN", "VAL"] +labels = ["Training", "Validation"] + +fig, ax = plt.subplots() +for phase, label, color in zip(phases, labels, COLORS): + epochs, loss = parse_loss(logs, phase) + + if args.normalize: + loss = loss / loss.max() + + ax.plot(epochs, loss, label=label, color=color) + ax.scatter(epochs, loss, s=15, color=color) + +ax.spines["left"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["top"].set_visible(False) + +ax.grid(axis="y", alpha=0.7) +ax.legend() +ax.set_xlabel("Epoch") +ax.set_ylabel("Loss") +plt.tight_layout() +plt.show() +
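
For orientation, a minimal usage sketch (not part of the patch) of how the refactored pieces above compose: `MuMapDataset` now accepts a `split_name` resolved through `split.csv` plus normalization/augmentation `Transform`s, with normalization applied once during pre-loading and augmentation on every `__getitem__` call. The dataset path below is an illustrative assumption, not a value taken from the diff.

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import MaxNormTransform
from mu_map.dataset.transform import SequenceTransform
from mu_map.logging import get_logger

# Normalization runs once while volumes are pre-loaded; transforms passed as
# transform_augmentation would instead be applied on every __getitem__ call.
dataset = MuMapDataset(
    "data/initial/",                    # assumed dataset directory
    split_name="train",                 # select the "train" patient ids from split.csv
    transform_normalization=SequenceTransform([MaxNormTransform()]),
    logger=get_logger(),
)

recon, mu_map = dataset[0]              # each a (1, D, H, W) float32 tensor
print(recon.shape, mu_map.shape)        # recon is max-normalized to [0, 1], mu_map is left unscaled

# For fixed-size training inputs, MuMapPatchDataset (mu_map.dataset.patches) samples
# random patches of size patch_size from these volumes instead of full reconstructions.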