diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index 5b98ff3c388a93b644f29002fdd89963e62a9814..8291615b2802dc2c23685b18a8bdefac2a170b27 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -11,6 +11,7 @@ from torch.utils.data import Dataset
 from mu_map.data.prepare import headers
 from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
 from mu_map.data.review_mu_map import discard_slices
+from mu_map.data.split import split_csv
 from mu_map.dataset.transform import Transform
 from mu_map.logging import get_logger
 
@@ -46,6 +47,8 @@ class MuMapDataset(Dataset):
         self,
         dataset_dir: str,
         csv_file: str = "meta.csv",
+        split_file: str = "split.csv",
+        split_name: str = None,
         images_dir: str = "images",
         bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
         discard_mu_map_slices: bool = True,
@@ -59,6 +62,8 @@ class MuMapDataset(Dataset):
         self.dir = dataset_dir
         self.dir_images = os.path.join(dataset_dir, images_dir)
         self.csv_file = os.path.join(dataset_dir, csv_file)
+        self.split_file = os.path.join(dataset_dir, split_file)
+
         self.transform_normalization = transform_normalization
         self.transform_augmentation = transform_augmentation
         self.logger = logger if logger is not None else get_logger()
@@ -72,6 +77,8 @@ class MuMapDataset(Dataset):
 
         # read CSV file and from that access DICOM files
         self.table = pd.read_csv(self.csv_file)
+        if split_name:
+            self.table = split_csv(self.table, self.split_file)[split_name]
         self.table["id"] = self.table["id"].apply(int)
 
         self.discard_mu_map_slices = discard_mu_map_slices
@@ -206,6 +213,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "dataset_dir", type=str, help="the directory from which the dataset is loaded"
     )
+    parser.add_argument(
+        "--split",
+        type=str,
+        choices=["train", "validation", "test"],
+        help="choose the split of the data for the dataset",
+    )
     parser.add_argument(
         "--unaligned",
         action="store_true",
@@ -234,6 +247,7 @@ if __name__ == "__main__":
         align=align,
         discard_mu_map_slices=discard_mu_map_slices,
         bed_contours_file=bed_contours_file,
+        split_name=args.split,
         logger=logger,
     )
     main(dataset)
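
For context, a minimal sketch of how the new split handling could fit together end to end. The body of `split_csv` below is an assumption (the real helper lives in `mu_map/data/split.py` and is not part of this diff); only its call shape, `split_csv(table, split_file)[split_name]`, is taken from the change. The `"name"` column of `split.csv` and the placeholder dataset path in the usage comments are likewise hypothetical.

```python
# Hypothetical sketch of a split helper compatible with the call site in the diff.
# Assumes split.csv maps a split name ("train"/"validation"/"test") to ids from
# meta.csv; the actual mu_map/data/split.py implementation may differ.
from typing import Dict

import pandas as pd


def split_csv(table: pd.DataFrame, split_file: str) -> Dict[str, pd.DataFrame]:
    """Return the rows of `table` grouped by the split they are assigned to."""
    splits = pd.read_csv(split_file)  # assumed columns: "name", "id"
    return {
        name: table[table["id"].isin(rows["id"])].reset_index(drop=True)
        for name, rows in splits.groupby("name")
    }


# Usage mirroring the new parameters (dataset path is a placeholder):
#   dataset = MuMapDataset("<dataset_dir>", split_name="train")
# or via the new CLI flag added in this diff:
#   python -m mu_map.dataset.default <dataset_dir> --split validation
```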