Skip to content
Snippets Groups Projects
Commit bed5ba58 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

add transforms to default dataset

parent b43e8bce
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,8 @@ from torch.utils.data import Dataset
from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices
from mu_map.dataset.transform import Transform
from mu_map.logging import get_logger
def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
......@@ -48,12 +50,18 @@ class MuMapDataset(Dataset):
bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
discard_mu_map_slices: bool = True,
align: bool = True,
transform_normalization: Transform = Transform(),
transform_augmentation: Transform = Transform(),
logger=None,
):
super().__init__()
self.dir = dataset_dir
self.dir_images = os.path.join(dataset_dir, images_dir)
self.csv_file = os.path.join(dataset_dir, csv_file)
self.transform_normalization = transform_normalization
self.transform_augmentation = transform_augmentation
self.logger = logger if logger is not None else get_logger()
self.bed_contours_file = (
os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
......@@ -74,7 +82,7 @@ class MuMapDataset(Dataset):
self.pre_load_images()
def pre_load_images(self):
print("Pre-loading images ...", end="\r")
self.logger.debug("Pre-loading images ...")
for i in range(len(self.table)):
row = self.table.iloc[i]
_id = row["id"]
......@@ -96,13 +104,16 @@ class MuMapDataset(Dataset):
mu_map = mu_map.astype(np.float32)
mu_map = torch.from_numpy(mu_map)
mu_map = mu_map.unsqueeze(dim=0)
self.mu_maps[_id] = mu_map
recon = recon.astype(np.float32)
recon = torch.from_numpy(recon)
recon = recon.unsqueeze(dim=0)
recon, mu_map = self.transform_normalization(recon, mu_map)
self.mu_maps[_id] = mu_map
self.reconstructions[_id] = recon
print("Pre-loading images done!")
self.logger.debug("Pre-loading images done!")
def __getitem__(self, index: int):
row = self.table.iloc[index]
......@@ -111,22 +122,7 @@ class MuMapDataset(Dataset):
recon = self.reconstructions[_id]
mu_map = self.mu_maps[_id]
# recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
# mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
# recon = pydicom.dcmread(recon_file).pixel_array
# mu_map = pydicom.dcmread(mu_map_file).pixel_array
# if self.discard_mu_map_slices:
# mu_map = discard_slices(row, mu_map)
# if self.bed_contours:
# bed_contour = self.bed_contours[row["id"]]
# for i in range(mu_map.shape[0]):
# mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)
# if self.align:
# recon = align_images(recon, mu_map)
recon, mu_map = self.transform_augmentation(recon, mu_map)
return recon, mu_map
......@@ -139,6 +135,12 @@ __all__ = [MuMapDataset.__name__]
if __name__ == "__main__":
import argparse
from mu_map.dataset.normalization import (
GaussianNormTransform,
MaxNormTransform,
MeanNormTransform,
)
from mu_map.logging import add_logging_args, get_logger_by_args
from mu_map.util import to_grayscale, COLOR_WHITE
parser = argparse.ArgumentParser(
......@@ -163,17 +165,20 @@ if __name__ == "__main__":
action="store_true",
help="do not remove broken slices of the mu map",
)
add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
args = parser.parse_args()
align = not args.unaligned
discard_mu_map_slices = not args.full_mu_map
bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
logger = get_logger_by_args(args)
dataset = MuMapDataset(
args.dataset_dir,
align=align,
discard_mu_map_slices=discard_mu_map_slices,
bed_contours_file=bed_contours_file,
logger=logger,
)
wname = "Dataset"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment