Skip to content
Snippets Groups Projects
Commit ac263a4e authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

mu map dataset can now load splits

parent 6442ed68
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ from torch.utils.data import Dataset
from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.logging import get_logger
......@@ -46,6 +47,8 @@ class MuMapDataset(Dataset):
self,
dataset_dir: str,
csv_file: str = "meta.csv",
split_file: str = "split.csv",
split_name: str = None,
images_dir: str = "images",
bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
discard_mu_map_slices: bool = True,
......@@ -59,6 +62,8 @@ class MuMapDataset(Dataset):
self.dir = dataset_dir
self.dir_images = os.path.join(dataset_dir, images_dir)
self.csv_file = os.path.join(dataset_dir, csv_file)
self.split_file = os.path.join(dataset_dir, split_file)
self.transform_normalization = transform_normalization
self.transform_augmentation = transform_augmentation
self.logger = logger if logger is not None else get_logger()
......@@ -72,6 +77,8 @@ class MuMapDataset(Dataset):
# read CSV file and from that access DICOM files
self.table = pd.read_csv(self.csv_file)
if split_name:
self.table = split_csv(self.table, self.split_file)[split_name]
self.table["id"] = self.table["id"].apply(int)
self.discard_mu_map_slices = discard_mu_map_slices
......@@ -206,6 +213,12 @@ if __name__ == "__main__":
parser.add_argument(
"dataset_dir", type=str, help="the directory from which the dataset is loaded"
)
parser.add_argument(
"--split",
type=str,
choices=["train", "validation", "test"],
help="choose the split of the data for the dataset",
)
parser.add_argument(
"--unaligned",
action="store_true",
......@@ -234,6 +247,7 @@ if __name__ == "__main__":
align=align,
discard_mu_map_slices=discard_mu_map_slices,
bed_contours_file=bed_contours_file,
split_name=args.split,
logger=logger,
)
main(dataset)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment