diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index a63317cf270a16cf68c8300d23076b3cec508bd4..8d4a83c835e0bf9e32765c0e5ee09ec1b9cde05a 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -62,6 +62,7 @@ class MuMapDataset(Dataset):
         bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
         discard_mu_map_slices: bool = True,
         align: bool = True,
+        scatter_correction: bool = False,
         transform_normalization: Transform = Transform(),
         transform_augmentation: Transform = Transform(),
         logger=None,
@@ -92,6 +93,8 @@ class MuMapDataset(Dataset):
 
         self.discard_mu_map_slices = discard_mu_map_slices
         self.align = align
+        self.scatter_correction = scatter_correction
+        self.header_recon = headers.file_recon_nac_sc if self.scatter_correction else headers.file_recon_nac_nsc
 
         self.reconstructions = {}
         self.mu_maps = {}
@@ -113,7 +116,7 @@ class MuMapDataset(Dataset):
                 for i in range(mu_map.shape[0]):
                     mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)
 
-            recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
+            recon_file = os.path.join(self.dir_images, row[self.header_recon])
             recon = pydicom.dcmread(recon_file)
             recon = recon.pixel_array / recon[DCM_TAG_PIXEL_SCALE_FACTOR].value
             if self.align: