Skip to content
Snippets Groups Projects
Commit 27c9e209 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

include slice removal from mu_map in the dataset

parent 3178da67
No related branches found
No related tags found
No related merge requests found
......@@ -6,9 +6,63 @@ import numpy as np
from torch.utils.data import Dataset
HEADER_DISC_FIRST = "discard_first"
HEADER_DISC_LAST = "discard_last"
def discard_slices(row, μ_map):
"""
Discard slices based on the flags in the row of th according table.
The row is expected to contain the flags 'discard_first' and 'discard_last'.
:param row: the row of meta configuration file of a dataset
:param μ_map: the μ_map
:return: the μ_map with according slices removed
"""
_res = μ_map
if row[HEADER_DISC_FIRST]:
_res = _res[1:]
if row[HEADER_DISC_LAST]:
_res = _res[:-1]
return _res
def align_images(image_1: np.ndarray, image_2: np.ndarray):
"""
Align one image to another on the first axis (z-axis).
It is assumed that the second image has less slices than the first.
Then, the first image is shortened in a way that the centers of both images lie on top of each other.
:param image_1: the image to be aligned
:param image_2: the image to which image_1 is aligned
:return: the aligned image_1
"""
assert (
image_1.shape[0] > image_2.shape[0]
), f"Alignment is based on the fact that image 1 has more slices {image_1.shape[0]} than image_2 {image_.shape[0]}"
# central slice of image 2
c_2 = image_2.shape[0] // 2
# image to the left and right of the center
left = c_2
right = image_2.shape[0] - c_2
# central slice of image 1
c_1 = image_1.shape[0] // 2
# select center and same amount to the left/right as image_2
return image_1[(c_1 - left) : (c_1 + right)]
class MuMapDataset(Dataset):
def __init__(
self, dataset_dir: str, csv_file: str = "meta.csv", images_dir: str = "images"
self,
dataset_dir: str,
csv_file: str = "meta.csv",
images_dir: str = "images",
discard_μ_map_slices: bool = True,
):
super().__init__()
......@@ -16,8 +70,10 @@ class MuMapDataset(Dataset):
self.dir_images = os.path.join(dataset_dir, images_dir)
self.csv_file = os.path.join(dataset_dir, csv_file)
# read CSV file and from that access DICOM files
self.table = pd.read_csv(self.csv_file)
# read csv file and from that access dicom files
self.discard_μ_map_slices = discard_μ_map_slices
def __getitem__(self, index: int):
row = self.table.iloc[index]
......@@ -27,25 +83,17 @@ class MuMapDataset(Dataset):
recon = pydicom.dcmread(recon_file).pixel_array
mu_map = pydicom.dcmread(mu_map_file).pixel_array
recon, mu_map = self.align(recon, mu_map)
if self.discard_μ_map_slices:
mu_map = discard_slices(row, mu_map)
recon = align_images(recon, mu_map)
return recon, mu_map
def __len__(self):
return len(self.table)
def align(self, recon, mu_map):
assert recon.shape[0] > mu_map.shape[0], f"Alignment is based on the fact that the NoAC Recon has more slices {recon.shape[0]} than the attenuation map {mu_map.shape[0]}"
cm = mu_map.shape[0] // 2
left = cm
right = mu_map.shape[0] - cm
cr = recon.shape[0] // 2
recon = recon[(cr - left):(cr + right)]
return recon, mu_map
__all__ = [MuMapDataset.__name__]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment