diff --git a/mu_map/data/datasets.py b/mu_map/data/datasets.py
index 5fd89402cdad67a02d81dc1ddaae58639b17f5c8..0f0a321a8de0e30da93b830563ba4ffb789084ce 100644
--- a/mu_map/data/datasets.py
+++ b/mu_map/data/datasets.py
@@ -1,10 +1,13 @@
 import os
+from typing import Optional
 
 import pandas as pd
 import pydicom
 import numpy as np
 from torch.utils.data import Dataset
 
+from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
+
 
 HEADER_DISC_FIRST = "discard_first"
 HEADER_DISC_LAST = "discard_last"
@@ -62,6 +65,7 @@ class MuMapDataset(Dataset):
         dataset_dir: str,
         csv_file: str = "meta.csv",
         images_dir: str = "images",
+        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
         discard_μ_map_slices: bool = True,
     ):
         super().__init__()
@@ -70,8 +74,12 @@ class MuMapDataset(Dataset):
         self.dir_images = os.path.join(dataset_dir, images_dir)
         self.csv_file = os.path.join(dataset_dir, csv_file)
 
+        self.bed_contours_file = os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None 
+        self.bed_contours = load_contours(self.bed_contours_file) if bed_contours_file else None
+
         # read CSV file and from that access DICOM files
         self.table = pd.read_csv(self.csv_file)
+        self.table["id"] = self.table["id"].apply(int)
 
         self.discard_μ_map_slices = discard_μ_map_slices
 
@@ -87,6 +95,11 @@ class MuMapDataset(Dataset):
         if self.discard_μ_map_slices:
             mu_map = discard_slices(row, mu_map)
 
+        if self.bed_contours:
+            bed_contour = self.bed_contours[row["id"]]
+            for i in range(mu_map.shape[0]):
+                mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)
+
         recon = align_images(recon, mu_map)
 
         return recon, mu_map
diff --git a/mu_map/data/remove_bed.py b/mu_map/data/remove_bed.py
index b0dda80820a36c72bfb768576f3a4cd890d094cd..d002f3f49ce48559e59f89af904b8ba51c4428bf 100644
--- a/mu_map/data/remove_bed.py
+++ b/mu_map/data/remove_bed.py
@@ -19,10 +19,8 @@ def load_contours(filename: str) -> Dict[int, np.ndarray]:
     with open(filename, mode="r") as f:
         contours = json.load(f)
 
-        for key, contour in contours.items():
-            del contours[key]
-            contours[int(key)] = np.array(contour).astype(int)
-    return contours
+    _map = map(lambda item: (int(item[0]), np.array(item[1]).astype(int)), contours.items())
+    return dict(_map)
 
 
 if __name__ == "__main__":