diff --git a/mu_map/training/default.py b/mu_map/training/default.py
index 8a8e4e42683a484cd471f7b279d178b4ea999883..754a878cc679af2f2a27d8a041f245a80decc136 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/default.py
@@ -265,7 +265,8 @@ if __name__ == "__main__":
         dataset = MuMapPatchDataset(
             args.dataset_dir,
             split_name=split,
-            transform_normalization=MeanNormTransform(),
+            transform_normalization=transform_normalization,
+            transform_augmentation=transform_augmentation,
             logger=logger,
         )
         data_loader = torch.utils.data.DataLoader(