diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index 5b98ff3c388a93b644f29002fdd89963e62a9814..8291615b2802dc2c23685b18a8bdefac2a170b27 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -11,6 +11,7 @@ from torch.utils.data import Dataset
 from mu_map.data.prepare import headers
 from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
 from mu_map.data.review_mu_map import discard_slices
+from mu_map.data.split import split_csv
 from mu_map.dataset.transform import Transform
 from mu_map.logging import get_logger
 
@@ -46,6 +47,8 @@ class MuMapDataset(Dataset):
         self,
         dataset_dir: str,
         csv_file: str = "meta.csv",
+        split_file: str = "split.csv",
+        split_name: Optional[str] = None,
         images_dir: str = "images",
         bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
         discard_mu_map_slices: bool = True,
@@ -59,6 +62,8 @@ class MuMapDataset(Dataset):
         self.dir = dataset_dir
         self.dir_images = os.path.join(dataset_dir, images_dir)
         self.csv_file = os.path.join(dataset_dir, csv_file)
+        self.split_file = os.path.join(dataset_dir, split_file)
+
         self.transform_normalization = transform_normalization
         self.transform_augmentation = transform_augmentation
         self.logger = logger if logger is not None else get_logger()
@@ -72,6 +77,8 @@ class MuMapDataset(Dataset):
 
         # read CSV file and from that access DICOM files
         self.table = pd.read_csv(self.csv_file)
+        if split_name:
+            self.table = split_csv(self.table, self.split_file)[split_name]
         self.table["id"] = self.table["id"].apply(int)
 
         self.discard_mu_map_slices = discard_mu_map_slices
@@ -206,6 +213,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "dataset_dir", type=str, help="the directory from which the dataset is loaded"
     )
+    parser.add_argument(
+        "--split",
+        type=str,
+        choices=["train", "validation", "test"],
+        help="choose the split of the data for the dataset",
+    )
     parser.add_argument(
         "--unaligned",
         action="store_true",
@@ -234,6 +247,7 @@ if __name__ == "__main__":
         align=align,
         discard_mu_map_slices=discard_mu_map_slices,
         bed_contours_file=bed_contours_file,
+        split_name=args.split,
         logger=logger,
     )
     main(dataset)
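
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new split_name parameter is expected to be used, assuming split_csv(table, split_file) returns a dict mapping the split names "train", "validation", and "test" to pandas DataFrames, as implied by the lookup split_csv(self.table, self.split_file)[split_name] in the constructor above. The dataset directory is a placeholder.

    from mu_map.dataset.default import MuMapDataset

    # "path/to/dataset" is a placeholder; the directory is expected to contain
    # meta.csv, split.csv and the images/ sub-directory used by the constructor.
    dataset = MuMapDataset(
        dataset_dir="path/to/dataset",
        split_file="split.csv",   # default added by this change
        split_name="train",       # keep only the rows of the training split
    )

    # dataset.table is the metadata table shown in the diff; after the split
    # lookup it contains only the rows belonging to the requested split.
    print(f"training split contains {len(dataset.table)} rows")

The same selection is exposed on the command line through the new --split argument, e.g. python -m mu_map.dataset.default <dataset_dir> --split train (module-style invocation assumed here so the absolute mu_map imports resolve).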