Skip to content
Snippets Groups Projects
Commit ac263a4e authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

mu map dataset can now load splits

parent 6442ed68
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ from torch.utils.data import Dataset ...@@ -11,6 +11,7 @@ from torch.utils.data import Dataset
from mu_map.data.prepare import headers from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform from mu_map.dataset.transform import Transform
from mu_map.logging import get_logger from mu_map.logging import get_logger
...@@ -46,6 +47,8 @@ class MuMapDataset(Dataset): ...@@ -46,6 +47,8 @@ class MuMapDataset(Dataset):
self, self,
dataset_dir: str, dataset_dir: str,
csv_file: str = "meta.csv", csv_file: str = "meta.csv",
split_file: str = "split.csv",
split_name: str = None,
images_dir: str = "images", images_dir: str = "images",
bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME, bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
discard_mu_map_slices: bool = True, discard_mu_map_slices: bool = True,
...@@ -59,6 +62,8 @@ class MuMapDataset(Dataset): ...@@ -59,6 +62,8 @@ class MuMapDataset(Dataset):
self.dir = dataset_dir self.dir = dataset_dir
self.dir_images = os.path.join(dataset_dir, images_dir) self.dir_images = os.path.join(dataset_dir, images_dir)
self.csv_file = os.path.join(dataset_dir, csv_file) self.csv_file = os.path.join(dataset_dir, csv_file)
self.split_file = os.path.join(dataset_dir, split_file)
self.transform_normalization = transform_normalization self.transform_normalization = transform_normalization
self.transform_augmentation = transform_augmentation self.transform_augmentation = transform_augmentation
self.logger = logger if logger is not None else get_logger() self.logger = logger if logger is not None else get_logger()
...@@ -72,6 +77,8 @@ class MuMapDataset(Dataset): ...@@ -72,6 +77,8 @@ class MuMapDataset(Dataset):
# read CSV file and from that access DICOM files # read CSV file and from that access DICOM files
self.table = pd.read_csv(self.csv_file) self.table = pd.read_csv(self.csv_file)
if split_name:
self.table = split_csv(self.table, self.split_file)[split_name]
self.table["id"] = self.table["id"].apply(int) self.table["id"] = self.table["id"].apply(int)
self.discard_mu_map_slices = discard_mu_map_slices self.discard_mu_map_slices = discard_mu_map_slices
...@@ -206,6 +213,12 @@ if __name__ == "__main__": ...@@ -206,6 +213,12 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"dataset_dir", type=str, help="the directory from which the dataset is loaded" "dataset_dir", type=str, help="the directory from which the dataset is loaded"
) )
parser.add_argument(
"--split",
type=str,
choices=["train", "validation", "test"],
help="choose the split of the data for the dataset",
)
parser.add_argument( parser.add_argument(
"--unaligned", "--unaligned",
action="store_true", action="store_true",
...@@ -234,6 +247,7 @@ if __name__ == "__main__": ...@@ -234,6 +247,7 @@ if __name__ == "__main__":
align=align, align=align,
discard_mu_map_slices=discard_mu_map_slices, discard_mu_map_slices=discard_mu_map_slices,
bed_contours_file=bed_contours_file, bed_contours_file=bed_contours_file,
split_name=args.split,
logger=logger, logger=logger,
) )
main(dataset) main(dataset)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment