diff --git a/mu_map/training/default.py b/mu_map/training/default.py index 8a8e4e42683a484cd471f7b279d178b4ea999883..754a878cc679af2f2a27d8a041f245a80decc136 100644 --- a/mu_map/training/default.py +++ b/mu_map/training/default.py @@ -265,7 +265,8 @@ if __name__ == "__main__": dataset = MuMapPatchDataset( args.dataset_dir, split_name=split, - transform_normalization=MeanNormTransform(), + transform_normalization=transform_normalization, + transform_augmentation=transform_augmentation, logger=logger, ) data_loader = torch.utils.data.DataLoader(