Commit db4917fa authored by Tamino Huxohl
parents 6d5bbf9c ac263a4e
import torch


def norm_max(tensor: torch.Tensor):
    return (tensor - tensor.min()) / (tensor.max() - tensor.min())


class MaxNorm:
    def __call__(self, tensor: torch.Tensor):
        return norm_max(tensor)


def norm_mean(tensor: torch.Tensor):
    return tensor / tensor.mean()


class MeanNorm:
    def __call__(self, tensor: torch.Tensor):
        return norm_mean(tensor)


def norm_gaussian(tensor: torch.Tensor):
    return (tensor - tensor.mean()) / tensor.std()


class GaussianNorm:
    def __call__(self, tensor: torch.Tensor):
        return norm_gaussian(tensor)


__all__ = [
    norm_max.__name__,
    norm_mean.__name__,
    norm_gaussian.__name__,
    MaxNorm.__name__,
    MeanNorm.__name__,
    GaussianNorm.__name__,
]
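# Sanity-check sketch (illustrative): norm_gaussian standardizes a tensor to
# approximately zero mean and unit standard deviation.
#
#     x = torch.normal(torch.full((10, 10, 10), 5.0), torch.full((10, 10, 10), 10.0))
#     print(f"Before: mean={x.mean():.3f} std={x.std():.3f}")
#     y = norm_gaussian(x)
#     print(f" After: mean={y.mean():.3f} std={y.std():.3f}")  # ~0.000 / ~1.000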
import pandas as pd
from typing import Dict, List


def parse_split_str(_str: str, delimitier: str = "/") -> List[float]:
@@ -23,7 +23,7 @@ def parse_split_str(_str: str, delimitier: str = "/") -> List[float]:
    return split


def split_csv(data: pd.DataFrame, split_csv: str) -> Dict[str, pd.DataFrame]:
    """
    Split a data frame based on a file defining a split.
@@ -40,7 +40,9 @@ def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]:
        lambda patient_ids: data[data["patient_id"].isin(patient_ids)],
        split_patient_ids,
    )
    splits = zip(split_names, splits)
    splits = dict(splits)
    return splits


if __name__ == "__main__":
    ...
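# Usage sketch (illustrative; the exact split.csv layout is defined elsewhere
# in this file, so treat the paths below as placeholders):
#
#     data = pd.read_csv("data/initial/meta.csv")
#     splits = split_csv(data, "data/initial/split.csv")
#     train = splits["train"]            # one DataFrame per split name
#     validation = splits["validation"]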
@@ -5,11 +5,15 @@ import cv2 as cv
import pandas as pd
import pydicom
import numpy as np
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.logging import get_logger
def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
@@ -43,16 +47,26 @@ class MuMapDataset(Dataset):
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: str = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        transform_normalization: Transform = Transform(),
        transform_augmentation: Transform = Transform(),
        logger=None,
    ):
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        self.transform_normalization = transform_normalization
        self.transform_augmentation = transform_augmentation
        self.logger = logger if logger is not None else get_logger()

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
@@ -63,6 +77,8 @@ class MuMapDataset(Dataset):
        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table["id"] = self.table["id"].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
@@ -73,7 +89,7 @@ class MuMapDataset(Dataset):
        self.pre_load_images()
    def pre_load_images(self):
        self.logger.debug("Pre-loading images ...")
        for i in range(len(self.table)):
            row = self.table.iloc[i]
            _id = row["id"]
@@ -86,14 +102,25 @@ class MuMapDataset(Dataset):
                bed_contour = self.bed_contours[row["id"]]
                for i in range(mu_map.shape[0]):
                    mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)

            recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
            recon = pydicom.dcmread(recon_file).pixel_array
            if self.align:
                recon = align_images(recon, mu_map)

            mu_map = mu_map.astype(np.float32)
            mu_map = torch.from_numpy(mu_map)
            mu_map = mu_map.unsqueeze(dim=0)

            recon = recon.astype(np.float32)
            recon = torch.from_numpy(recon)
            recon = recon.unsqueeze(dim=0)

            recon, mu_map = self.transform_normalization(recon, mu_map)

            self.mu_maps[_id] = mu_map
            self.reconstructions[_id] = recon
        self.logger.debug("Pre-loading images done!")
    def __getitem__(self, index: int):
        row = self.table.iloc[index]
@@ -102,22 +129,7 @@ class MuMapDataset(Dataset):
        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon, mu_map = self.transform_augmentation(recon, mu_map)

        return recon, mu_map
@@ -127,46 +139,9 @@ class MuMapDataset(Dataset):
__all__ = [MuMapDataset.__name__]
def main(dataset):
    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)
@@ -194,17 +169,20 @@ if __name__ == "__main__":
        im = 0

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")

        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
        key = cv.waitKey(timeout)

        running = 0
        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)
            key = cv.waitKey(timeout)

            if key == ord("n"):
@@ -218,3 +196,59 @@ if __name__ == "__main__":
            elif key == 81:  # left arrow key
                ir = max(ir - 2, 0)
                im = max(im - 2, 0)
            elif key == ord("s"):
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1
if __name__ == "__main__":
    import argparse

    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
    logger = get_logger_by_args(args)

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        split_name=args.split,
        logger=logger,
    )
    main(dataset)
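# Example invocation (a sketch, assuming the module lives at
# mu_map/dataset/default.py and a prepared dataset exists under data/initial/):
#
#     python -m mu_map.dataset.default data/initial/ --split train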
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import MaxNormTransform


class MuMapMockDataset(MuMapDataset):
    def __init__(self, dataset_dir: str = "data/initial/", num_images: int = 16, logger=None):
        super().__init__(
            dataset_dir=dataset_dir,
            transform_normalization=MaxNormTransform(),
            logger=logger,
        )
        self.len = num_images

    def __getitem__(self, index: int):
        recon, mu_map = super().__getitem__(0)
        recon = recon[:, :32, :, :]
        mu_map = mu_map[:, :32, :, :]
        return recon, mu_map / 40206.0

    def __len__(self):
        return self.len


if __name__ == "__main__":
    import cv2 as cv
    import numpy as np

    from mu_map.dataset.default import main

    dataset = MuMapMockDataset()
    main(dataset)
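# Smoke-test sketch (illustrative, given a prepared data/initial/ directory):
# the mock dataset repeats 32 slices of a single study, so it can drive a
# quick forward pass without iterating over real data.
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(MuMapMockDataset(), batch_size=2, shuffle=True)
#     recon, mu_map = next(iter(loader))  # shapes: (2, 1, 32, H, W)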
from typing import Tuple

from torch import Tensor

from mu_map.dataset.transform import Transform


def norm_max(tensor: Tensor) -> Tensor:
    return (tensor - tensor.min()) / (tensor.max() - tensor.min())


class MaxNormTransform(Transform):
    def __init__(self, max_vals: Tuple[float, float] = None):
        self.max_vals = max_vals

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        if self.max_vals:
            return inputs / self.max_vals[0], outputs_expected / self.max_vals[1]
        return norm_max(inputs), outputs_expected


def norm_mean(tensor: Tensor):
    return tensor / tensor.mean()


class MeanNormTransform(Transform):
    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        return norm_mean(inputs), outputs_expected


def norm_gaussian(tensor: Tensor):
    return (tensor - tensor.mean()) / tensor.std()


class GaussianNormTransform(Transform):
    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        return norm_gaussian(inputs), outputs_expected


__all__ = [
    norm_max.__name__,
    norm_mean.__name__,
    norm_gaussian.__name__,
    MaxNormTransform.__name__,
    MeanNormTransform.__name__,
    GaussianNormTransform.__name__,
]
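# Usage sketch (illustrative): the transforms normalize only the inputs and
# pass the expected outputs through unchanged, unless fixed maxima are given.
#
#     import torch
#     recon = torch.rand(1, 32, 128, 128) * 100
#     mu_map = torch.rand(1, 32, 128, 128)
#     recon_n, mu_map_n = MaxNormTransform()(recon, mu_map)           # recon in [0, 1]
#     recon_s, mu_map_s = MaxNormTransform((100.0, 1.0))(recon, mu_map)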
import math
import random

import numpy as np
import torch

from mu_map.dataset.default import MuMapDataset


class MuMapPatchDataset(MuMapDataset):
@@ -16,44 +19,44 @@ class MuMapPatchDataset(MuMapDataset):
        self.generate_patches()

    def generate_patches(self):
        for _id in self.reconstructions:
            recon = self.reconstructions[_id].squeeze()
            mu_map = self.mu_maps[_id].squeeze()

            assert (
                recon.shape[0] == mu_map.shape[0]
            ), "Reconstruction and MuMap were not aligned for patch dataset"

            z_range = (0, max(recon.shape[0] - self.patch_size, 0))
            # sometimes the mu_maps have fewer than 32 slices
            # in this case the z-axis will be padded to the patch size, but this means we only have a single option for z
            y_range = (20, recon.shape[1] - self.patch_size - 20)
            x_range = (20, recon.shape[2] - self.patch_size - 20)

            padding = [0, 0, 0, 0, 0, 0, 0, 0]
            if recon.shape[0] < self.patch_size:
                diff = self.patch_size - recon.shape[0]
                padding[4] = math.ceil(diff / 2)
                padding[5] = math.floor(diff / 2)

            for j in range(self.patches_per_image):
                z = random.randint(*z_range)
                y = random.randint(*y_range)
                x = random.randint(*x_range)

                self.patches.append((_id, z, y, x, padding))

    def __getitem__(self, index: int):
        _id, z, y, x, padding = self.patches[index]
        s = self.patch_size

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
        mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)

        recon = recon[:, z : z + s, y : y + s, x : x + s]
        mu_map = mu_map[:, z : z + s, y : y + s, x : x + s]

        return recon, mu_map
@@ -70,19 +73,18 @@ if __name__ == "__main__":
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    dataset = MuMapPatchDataset("data/initial/", patches_per_image=1)
    print(f"Images (Patches) in the dataset {len(dataset)}")

    def create_image(recon, mu_map, recon_orig, patch, _slice):
        s = dataset.patch_size
        _id, _, y, x, padding = patch

        _recon_orig = recon_orig[_slice]
        _recon_orig = to_grayscale(_recon_orig)
        _recon_orig = grayscale_to_rgb(_recon_orig)
        _recon_orig = cv.rectangle(
            _recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1
        )
        _recon_orig = cv.resize(_recon_orig, (512, 512), cv.INTER_AREA)

        _recon = recon[_slice]
@@ -95,7 +97,7 @@ if __name__ == "__main__":
        _mu_map = cv.resize(_mu_map, (512, 512), cv.INTER_AREA)
        _mu_map = grayscale_to_rgb(_mu_map)

        space = np.full((512, 10, 3), 239, np.uint8)
        return np.hstack((_recon, space, _mu_map, space, _recon_orig))

    for i in range(len(dataset)):
@@ -104,18 +106,24 @@ if __name__ == "__main__":
        patch = dataset.patches[i]
        _id, z, y, x, padding = patch
        print(
            f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[4]}, {padding[5]}]"
        )

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        recon_orig = dataset.reconstructions[_id]
        recon_orig = torch.nn.functional.pad(recon_orig, padding, mode="constant", value=0)
        recon_orig = recon_orig.squeeze().numpy()

        cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
        key = cv.waitKey(100)

        while True:
            _i = (_i + 1) % recon.shape[0]

            cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
            key = cv.waitKey(100)

            if key == ord("n"):
                break
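# Worked example of the z-padding in generate_patches (an illustrative,
# standalone sketch): a 28-slice volume is padded to a 32-slice patch with two
# zero slices on each side of the z-axis.
import math

import torch

patch_size, n_slices = 32, 28
diff = patch_size - n_slices                      # 4 missing slices
# F.pad order for a (channel, z, y, x) tensor: (x_l, x_r, y_l, y_r, z_l, z_r, c_l, c_r)
padding = [0, 0, 0, 0, math.ceil(diff / 2), math.floor(diff / 2), 0, 0]
volume = torch.zeros(1, n_slices, 128, 128)
padded = torch.nn.functional.pad(volume, padding, mode="constant", value=0)
assert padded.shape == (1, patch_size, 128, 128)  # z-axis grown from 28 to 32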
from typing import List, Tuple

from torch import Tensor


class Transform:
    """
    Interface of a transformer. A transformer can be initialized and then applied to
    an input tensor and expected output tensor as returned by a dataset. It can be
    used for normalization and data augmentation.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply the transformer to a pair of inputs and expected outputs in a dataset.
        """
        return inputs, outputs_expected


class SequenceTransform(Transform):
    """
    A transformer that applies a sequence of transformers sequentially.
    """

    def __init__(self, transforms: List[Transform]):
        self.transforms = transforms

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        for transform in self.transforms:
            inputs, outputs_expected = transform(inputs, outputs_expected)
        return inputs, outputs_expected
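# Composition sketch (illustrative): chain a normalization transform with the
# identity base class; since Transform is a no-op, every stage is optional.
#
#     from mu_map.dataset.normalization import GaussianNormTransform
#     pipeline = SequenceTransform([GaussianNormTransform(), Transform()])
#     inputs, targets = pipeline(inputs, targets)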
import argparse
from dataclasses import dataclass
from datetime import datetime
import logging
from logging import Formatter, getLogger, StreamHandler
from logging.handlers import WatchedFileHandler
import os
import shutil
from typing import Dict, Optional, List

date_format = "%m/%d/%Y %I:%M:%S"
FORMATTER = Formatter(
    fmt="%(asctime)s - %(levelname)7s - %(message)s", datefmt=date_format
)
@@ -35,7 +36,7 @@ def add_logging_args(parser: argparse.ArgumentParser, defaults: Dict[str, str]):
def timestamp_filename(filename: str):
    timestamp = datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    basename, ext = os.path.splitext(filename)
    return f"{basename}_{timestamp}{ext}"
@@ -71,6 +72,34 @@ def get_logger_by_args(args):
    return get_logger(args.logfile, args.loglevel)
@dataclass
class LogLine:
    time: datetime
    loglevel: str
    message: str

    def __repr__(self):
        return f"{self.time.strftime(date_format)} - {self.loglevel:>7} - {self.message}"


def parse_line(logline):
    _split = logline.strip().split("-")
    assert (
        len(_split) >= 3
    ), f"A logged line should consist of at least three elements with the format [TIME - LOGLEVEL - MESSAGE] but got [{logline.strip()}]"
    time_str = _split[0].strip()
    time = datetime.strptime(time_str, date_format)
    loglevel = _split[1].strip()
    message = "-".join(_split[2:]).strip()
    return LogLine(time=time, loglevel=loglevel, message=message)


def parse_file(logfile: str) -> List[LogLine]:
    with open(logfile, mode="r") as f:
        lines = f.readlines()
        lines = map(parse_line, lines)
        return list(lines)
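# Round-trip sketch (illustrative): parse a line in the format produced by
# FORMATTER above.
#
#     line = "10/24/2022 10:15:30 -    INFO - Epoch  1/10 - Loss TRAIN: 0.0123"
#     log = parse_line(line)
#     assert log.loglevel == "INFO"
#     assert "Loss TRAIN" in log.message   # dashes inside the message survive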
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_logging_args(parser, defaults={"--loglevel": "DEBUG", "--logfile": "tmp.log"})
    ...
from typing import Optional, List

import torch
import torch.nn as nn
...
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.mock import MuMapMockDataset
from mu_map.dataset.normalization import norm_max
from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE

torch.set_grad_enabled(False)

dataset = MuMapMockDataset("data/initial/")

model = UNet(in_channels=1, features=[8, 16])
device = torch.device("cpu")
weights = torch.load("tmp/10.pth", map_location=device)
model.load_state_dict(weights)
model = model.eval()

recon, mu_map = dataset[0]
recon = recon.unsqueeze(dim=0)
recon = norm_max(recon)

output = model(recon)
output = output * 40206.0

diff = ((mu_map - output) ** 2).mean()
print(f"Diff: {diff:.3f}")

output = output.squeeze().numpy()
mu_map = mu_map.squeeze().numpy()
assert output.shape[0] == mu_map.shape[0]

wname = "Dataset"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
space = np.full((1024, 10), 239, np.uint8)


def to_display_image(image, _slice):
    _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
    _image = cv.resize(_image, (1024, 1024), cv.INTER_AREA)
    _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
    _image = cv.putText(
        _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
    )
    return _image


def com(image1, image2, _slice):
    image1 = to_display_image(image1, _slice)
    image2 = to_display_image(image2, _slice)
    space = np.full((image1.shape[0], 10), 239, np.uint8)
    return np.hstack((image1, space, image2))


i = 0
while True:
    x = com(output, mu_map, i)
    cv.imshow(wname, x)
    key = cv.waitKey(100)

    if key == ord("q"):
        break

    i = (i + 1) % output.shape[0]
# dataset = MuMapDataset("data/initial")
# # print(" Recon || MuMap")
# # print(" Min | Max | Average || Min | Max | Average")
# r_max = []
# r_avg = []
# r_max_p = []
# r_avg_p = []
# r_avg_x = []
# m_max = []
# for recon, mu_map in dataset:
# r_max.append(recon.max())
# r_avg.append(recon.mean())
# recon_p = recon[:, :, 16:112, 16:112]
# r_max_p.append(recon_p.max())
# r_avg_p.append(recon_p.mean())
# r_avg_x.append(recon.sum() / (recon > 0.0).sum())
# # r_min = f"{recon.min():5.3f}"
# # r_max = f"{recon.max():5.3f}"
# # r_avg = f"{recon.mean():5.3f}"
# # m_min = f"{mu_map.min():5.3f}"
# # m_max = f"{mu_map.max():5.3f}"
# # m_avg = f"{mu_map.mean():5.3f}"
# # print(f"{r_min} | {r_max} | {r_avg} || {m_min} | {m_max} | {m_avg}")
# m_max.append(mu_map.max())
# # print(mu_map.max())
# r_max = np.array(r_max)
# r_avg = np.array(r_avg)
# r_max_p = np.array(r_max_p)
# r_avg_p = np.array(r_avg_p)
# r_avg_x = np.array(r_avg_x)
# m_max = np.array(m_max)
# fig, ax = plt.subplots()
# ax.scatter(r_max, m_max)
# # fig, axs = plt.subplots(4, 3, figsize=(16, 12))
# # axs[0, 0].hist(r_max)
# # axs[0, 0].set_title("Max")
# # axs[1, 0].hist(r_avg)
# # axs[1, 0].set_title("Mean")
# # axs[2, 0].hist(r_max / r_avg)
# # axs[2, 0].set_title("Max / Mean")
# # axs[3, 0].hist(recon.flatten())
# # axs[3, 0].set_title("Example Reconstruction")
# # axs[0, 1].hist(r_max_p)
# # axs[0, 1].set_title("Max")
# # axs[1, 1].hist(r_avg_p)
# # axs[1, 1].set_title("Mean")
# # axs[2, 1].hist(r_max_p / r_avg_p)
# # axs[2, 1].set_title("Max / Mean")
# # axs[3, 1].hist(recon_p.flatten())
# # axs[3, 1].set_title("Example Reconstruction")
# # axs[0, 2].hist(r_max_p)
# # axs[0, 2].set_title("Max")
# # axs[1, 2].hist(r_avg_x)
# # axs[1, 2].set_title("Mean")
# # axs[2, 2].hist(r_max_p / r_avg_x)
# # axs[2, 2].set_title("Max / Mean")
# # axs[3, 2].hist(torch.masked_select(recon, (recon > 0.0)))
# # axs[3, 2].set_title("Example Reconstruction")
# plt.show()
import os
from typing import Dict

import torch

from mu_map.logging import get_logger


class Training:
    def __init__(
        self,
        model: torch.nn.Module,
        data_loaders: Dict[str, torch.utils.data.DataLoader],
        epochs: int,
        device: torch.device,
        lr: float,
        lr_decay_factor: float,
        lr_decay_epoch: int,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger=None,
    ):
        self.model = model
        self.data_loaders = data_loaders
        self.epochs = epochs
        self.device = device
        self.lr = lr
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_epoch = lr_decay_epoch
        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch
        self.logger = logger if logger is not None else get_logger()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
        )
        self.loss_func = torch.nn.MSELoss(reduction="mean")

    def run(self):
        for epoch in range(1, self.epochs + 1):
            self.logger.debug(
                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
            )
            self._run_epoch(self.data_loaders["train"], phase="train")

            loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
            self.logger.info(
                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}"
            )
            loss_validation = self._run_epoch(
                self.data_loaders["validation"], phase="val"
            )
            self.logger.info(
                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}"
            )

            # ToDo: log outputs and time
            _previous = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()
            self.logger.debug(
                f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
            )

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(
                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
            )

    def _run_epoch(self, data_loader, phase):
        self.logger.debug(f"Run epoch in phase {phase}")
        self.model.train() if phase == "train" else self.model.eval()

        epoch_loss = 0
        loss_updates = 0
        for i, (inputs, labels) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(inputs)
                loss = self.loss_func(outputs, labels)

                if phase == "train":
                    loss.backward()
                    self.optimizer.step()

            epoch_loss += loss.item()
            loss_updates += 1
        return epoch_loss / loss_updates

    def store_snapshot(self, epoch):
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        self.logger.debug(f"Store snapshot at {snapshot_file}")
        torch.save(self.model.state_dict(), snapshot_file)
if __name__ == "__main__":
    import argparse

    from mu_map.dataset.mock import MuMapMockDataset
    from mu_map.logging import add_logging_args, get_logger_by_args
    from mu_map.models.unet import UNet

    parser = argparse.ArgumentParser(
        description="Train a UNet model to predict μ-maps from reconstructed scatter images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model Args
    parser.add_argument(
        "--features",
        type=int,
        nargs="+",
        default=[8, 16],
        help="number of features in the layers of the UNet structure",
    )

    # Dataset Args
    # parser.add_argument("--features", type=int, nargs="+", default=[8, 16], help="number of features in the layers of the UNet structure")

    # Training Args
    parser.add_argument(
        "--output_dir",
        type=str,
        default="train_data",
        help="directory in which results (snapshots and logs) of this training are saved",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="the number of epochs for which the model is trained",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="the device (cpu or gpu) with which the training is performed",
    )
    parser.add_argument(
        "--lr", type=float, default=0.1, help="the initial learning rate for training"
    )
    parser.add_argument(
        "--lr_decay_factor",
        type=float,
        default=0.99,
        help="decay factor for the learning rate",
    )
    parser.add_argument(
        "--lr_decay_epoch",
        type=int,
        default=1,
        help="frequency in epochs at which the learning rate is decayed",
    )
    parser.add_argument(
        "--snapshot_dir",
        type=str,
        default="snapshots",
        help="directory under --output_dir where snapshots are stored",
    )
    parser.add_argument(
        "--snapshot_epoch",
        type=int,
        default=10,
        help="frequency in epochs at which snapshots are stored",
    )

    # Logging Args
    add_logging_args(parser, defaults={"--logfile": "train.log"})

    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
    if not os.path.exists(args.snapshot_dir):
        os.mkdir(args.snapshot_dir)

    args.logfile = os.path.join(args.output_dir, args.logfile)

    device = torch.device(args.device)
    logger = get_logger_by_args(args)

    model = UNet(in_channels=1, features=args.features)
    dataset = MuMapMockDataset(logger=logger)

    data_loader_train = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
    )
    data_loaders = {"train": data_loader_train, "validation": data_loader_val}

    training = Training(
        model=model,
        data_loaders=data_loaders,
        epochs=args.epochs,
        device=device,
        lr=args.lr,
        lr_decay_factor=args.lr_decay_factor,
        lr_decay_epoch=args.lr_decay_epoch,
        snapshot_dir=args.snapshot_dir,
        snapshot_epoch=args.snapshot_epoch,
        logger=logger,
    )
    training.run()
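# Restore sketch (illustrative): snapshots written by store_snapshot can be
# reloaded for evaluation; the path assumes the defaults --output_dir
# train_data, --snapshot_dir snapshots, --epochs 10 and --snapshot_epoch 10.
#
#     model = UNet(in_channels=1, features=[8, 16])
#     weights = torch.load("train_data/snapshots/10.pth", map_location="cpu")
#     model.load_state_dict(weights)
#     model.eval()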
@@ -23,6 +23,9 @@ def to_grayscale(
    if max_val is None:
        max_val = img.max()

    if (max_val - min_val) == 0:
        return np.zeros(img.shape, np.uint8)

    _img = (img - min_val) / (max_val - min_val)
    _img = (_img * 255).astype(np.uint8)
    return _img
@@ -36,4 +39,4 @@ def grayscale_to_rgb(img: np.ndarray):
    :return: the image in rgb
    """
    assert img.ndim == 2, f"grayscale image has more than 2 dimensions {img.shape}"
    return img.repeat(3).reshape((*img.shape, 3))
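# Shape sketch (illustrative): repeating each pixel three times and reshaping
# to (H, W, 3) yields the channel-last layout OpenCV expects.
#
#     img = np.zeros((512, 512), np.uint8)
#     rgb = grayscale_to_rgb(img)
#     assert rgb.shape == (512, 512, 3)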
import argparse
import matplotlib.pyplot as plt
import numpy as np
from mu_map.logging import parse_file
SIZE_DEFAULT = 12
plt.rc("font", family="Roboto") # controls default font
plt.rc("font", weight="normal") # controls default font
plt.rc("font", size=SIZE_DEFAULT) # controls default text sizes
plt.rc("axes", titlesize=18) # fontsize of the axes title
# https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3
COLORS = ["#ef8a62", "#67a9cf"]
parser = argparse.ArgumentParser(description="TODO", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("logfile", type=str, help="TODO")
parser.add_argument("--normalize", action="store_true", help="TODO")
args = parser.parse_args()
logs = parse_file(args.logfile)
logs = list(filter(lambda logline: logline.loglevel == "INFO", logs))
def parse_loss(logs, phase):
    _logs = map(lambda logline: logline.message, logs)
    _logs = filter(lambda log: phase in log, _logs)
    _logs = filter(lambda log: "Loss" in log, _logs)
    _logs = list(_logs)

    losses = map(lambda log: log.split("-")[-1].strip(), _logs)
    losses = map(lambda log: log.split(":")[-1].strip(), losses)
    losses = map(float, losses)

    epochs = map(lambda log: log.split("-")[0].strip(), _logs)
    epochs = list(epochs)
    epochs = map(lambda log: log.split(" ")[-1], epochs)
    epochs = map(lambda log: log.split("/")[0], epochs)
    epochs = map(int, epochs)

    return np.array(list(epochs)), np.array(list(losses))
phases = ["TRAIN", "VAL"]
labels = ["Training", "Validation"]
fig, ax = plt.subplots()
for phase, label, color in zip(phases, labels, COLORS):
    epochs, loss = parse_loss(logs, phase)
    if args.normalize:
        loss = loss / loss.max()
    ax.plot(epochs, loss, label=label, color=color)
    ax.scatter(epochs, loss, s=15, color=color)
ax.spines["left"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.grid(axis="y", alpha=0.7)
ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
plt.tight_layout()
plt.show()