Commit d91fd6a7 authored by Tamino Huxohl

document and small refactoring of MuMapDataset

parent b7fad94b
@@ -161,7 +161,7 @@ if __name__ == "__main__":
     ids = args.revise_ids
     for _i, _id in enumerate(ids):
-        _, mu_map = dataset.getitem_by_id(_id)
+        _, mu_map = dataset.get_item_by_id(_id)
         if str(_id) in bed_contours and not args.revise_ids:
             print(f"Skip {_id} because file already contains these contours")
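Side note on the lookup above: the contours are persisted in a JSON file (see `bed_contours_file` below), and JSON object keys are always strings, which is why the code tests `str(_id)` rather than the integer id. A minimal sketch of that load-and-skip pattern; the file name, helper, and ids are invented for illustration:

import json

def load_bed_contours(path: str) -> dict:
    # Return previously stored contours keyed by image id (as string);
    # an absent file simply means nothing has been contoured yet.
    try:
        with open(path) as f:
            return json.load(f)
    except FileNotFoundError:
        return {}

bed_contours = load_bed_contours("bed_contours.json")
for _id in [1, 2, 3]:
    if str(_id) in bed_contours:
        print(f"Skip {_id} because file already contains these contours")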
 import os
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 import cv2 as cv
 import pandas as pd
 import pydicom
 import numpy as np
 import torch
 from torch.utils.data import Dataset
@@ -23,12 +22,20 @@ from mu_map.logging import get_logger
 class MuMapDataset(Dataset):
     """
     A dataset to map reconstructions to attenuation maps (mu maps).
+
+    The dataset is lazy. This means that dataset creation is fast and
+    images are only read into memory when they are first accessed.
+    Thus, after the first iteration, accessing images becomes a lot
+    faster.
     """
 
     def __init__(
         self,
         dataset_dir: str,
         csv_file: str = "meta.csv",
         split_file: str = "split.csv",
-        split_name: str = None,
+        split_name: Optional[str] = None,
         images_dir: str = "images",
         bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
         discard_mu_map_slices: bool = True,
@@ -38,6 +45,35 @@ class MuMapDataset(Dataset):
         transform_augmentation: Transform = Transform(),
         logger=None,
     ):
"""
Create a new mu map dataset.
Parameters
----------
dataset_dir: str
directory of the dataset to be loaded
csv_file: str
name of the csv file in the dataset directory containing meta information (created by mu_map.data.prepare)
split_file: str
csv file defining a split of the dataset in train/validation/test (created by mu_map.data.split)
split_name: str, optional
the name of the split which is loaded
images_dir: str
directory under `dataset_dir` containing the actual images in DICOM format
bed_contours_file: str, optional
json file containing contours around the bed for each mu map (see mu_map.data.remove_bed)
discard_mu_map_slices: bool
remove defective slices from mu maps (have to be labeled by mu_map.data.review_mu_map)
align: bool
center align reconstructions and mu maps
scatter_correction: bool
use scatter corrected reconstructions
transform_normalization: Transform
transform used for image normalization which is applied once when the image is loaded
transform_augmentation: Transform
transform used for augmentation which is applied every time `__getitem__` is called
logger: Logger, optional
"""
         super().__init__()
         self.dir = dataset_dir
@@ -77,7 +113,15 @@ class MuMapDataset(Dataset):
         self.reconstructions = {}
         self.mu_maps = {}
-    def split_copy(self, split_name: str):
+    def copy(self, split_name: str):
+        """
+        Create a copy of the dataset with a modified split.
+
+        Parameters
+        ----------
+        split_name: str
+            the name of the split with which the copy is created
+        """
         return MuMapDataset(
             dataset_dir=self.dir,
             csv_file=os.path.basename(self.csv_file),
@@ -94,6 +138,18 @@
         )
     def load_image(self, _id: int):
+        """
+        Load an image into memory.
+
+        This function also performs all of the pre-processing (discarding
+        slices, removing the bed, alignment, ...). Afterwards, the
+        reconstruction and mu map are available from the local dicts
+        `self.reconstructions` and `self.mu_maps`.
+
+        Parameters
+        ----------
+        _id: int
+            the id of the image to be loaded
+        """
         row = self.table[self.table[headers.id] == _id].iloc[0]
         _id = row[headers.id]
@@ -129,15 +185,42 @@
         self.reconstructions[_id] = recon
 
     def pre_load_images(self):
+        """
+        Load all images into memory.
+        """
         for _id in self.table[headers.id]:
             self.load_image(_id)
 
-    def __getitem__(self, index: int):
+    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get a reconstruction and mu map pair by index.
+
+        This method retrieves the image id at this index and calls
+        `get_item_by_id`.
+
+        Parameters
+        ----------
+        index: int
+        """
         row = self.table.iloc[index]
         _id = row[headers.id]
-        return self.getitem_by_id(_id)
+        return self.get_item_by_id(_id)
 
-    def getitem_by_id(self, _id: int):
+    def get_item_by_id(self, _id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get a reconstruction and mu map pair by their id.
+
+        This method loads the images if they are not yet in memory and
+        applies the augmentation transform before returning them.
+
+        Parameters
+        ----------
+        _id: int
+
+        Returns
+        -------
+        Tuple[torch.Tensor, torch.Tensor]
+            a pair of a reconstruction and the corresponding mu map
+        """
         if _id not in self.reconstructions:
             self.load_image(_id)
@@ -148,14 +231,29 @@
         return recon, mu_map
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Get the number of elements in this dataset.
+        """
         return len(self.table)
 
 
 __all__ = [MuMapDataset.__name__]
-def main(dataset, ids, paused=False):
+def main(dataset: MuMapDataset, ids: Optional[List[int]] = None, paused: bool = False):
+    """
+    Display the reconstructions and mu maps in a dataset.
+
+    Parameters
+    ----------
+    dataset: MuMapDataset
+        the dataset whose elements are displayed
+    ids: list of int, optional
+        if given, only these ids are displayed
+    paused: bool
+        start the display in paused mode
+    """
     from mu_map.util import to_grayscale, COLOR_WHITE
 
     wname = "Dataset"
@@ -163,10 +261,10 @@
     cv.resizeWindow(wname, 1600, 900)
     space = np.full((1024, 10), 239, np.uint8)
 
-    TIMEOUT_PAUSED = 0
-    TIMEOUT_RUNNING = 1000 // 15
+    timeout_paused = 0
+    timeout_running = 1000 // 15
 
-    timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING
+    timeout = timeout_paused if paused else timeout_running
 
     def to_display_image(image, _slice):
         _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
@@ -223,7 +321,7 @@ def main(dataset, ids, paused=False):
         elif key == ord("q"):
             exit(0)
         elif key == ord("p"):
-            timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
+            timeout = timeout_paused if timeout > 0 else timeout_running
         elif key == 82:  # up arrow key
             ir = ir - 1
             continue
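Taken together, the newly documented API of `MuMapDataset` can be exercised as below. This is a hedged usage sketch, not code from the commit: the import path, dataset directory, split names, and ids are assumptions.

from mu_map.dataset.default import MuMapDataset  # module path assumed

dataset = MuMapDataset(dataset_dir="data/initial", split_name="train")

# Lazy loading: the first access reads the DICOM files and runs the
# pre-processing (discard slices, remove bed, align, normalize) ...
recon, mu_map = dataset[0]
# ... while repeated accesses are served from the in-memory dicts.
recon, mu_map = dataset.get_item_by_id(1)

# Optionally warm the cache up front.
dataset.pre_load_images()

# Derive a view of another split of the same dataset directory.
validation = dataset.copy("validation")
print(len(dataset), len(validation))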
@@ -98,7 +98,7 @@ class MuMapPatchDataset(MuMapDataset):
         ps = self.patch_size
         ps_z = self.patch_size_z
 
-        recon, mu_map = super().getitem_by_id(_id)
+        recon, mu_map = super().get_item_by_id(_id)
         recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
         mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)
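The `MuMapPatchDataset` hunk pads both volumes before patches are cut, so that patches overlapping the volume border stay valid. Worth noting is how `torch.nn.functional.pad` reads its padding argument: pairs of (before, after) amounts, starting from the last dimension. A standalone sketch with made-up sizes (the actual patch sizes and `padding` tuple live elsewhere in the file):

import torch

volume = torch.zeros((1, 32, 64, 64))  # (channel, z, y, x)
padding = (8, 8, 8, 8, 4, 4)  # x: 8+8, y: 8+8, z: 4+4 -- last dimension first
padded = torch.nn.functional.pad(volume, padding, mode="constant", value=0)
print(padded.shape)  # torch.Size([1, 40, 80, 80])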
@@ -77,7 +77,7 @@ class AbstractTraining:
             (
                 split_name,
                 torch.utils.data.DataLoader(
-                    dataset.split_copy(split_name),
+                    dataset.copy(split_name),
                     batch_size=self.batch_size,
                     shuffle=True,
                     pin_memory=True,
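The training hunk builds one shuffling `DataLoader` per split through the renamed `copy` method, so every split shares the same dataset directory and pre-processing configuration. A hedged sketch of that pattern in isolation (the function name, split names, and batch size are assumptions):

import torch

def make_loaders(dataset, split_names=("train", "validation"), batch_size=64):
    # One DataLoader per split; `dataset.copy` creates a dataset for a
    # different split name while keeping all other parameters.
    return {
        split_name: torch.utils.data.DataLoader(
            dataset.copy(split_name),
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
        )
        for split_name in split_names
    }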