diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py index 248a0a51832880c4cbd0e7a14a067023e7635f5c..b4d0b1fb2e99504698ff800128436601eadc33d3 100644 --- a/mu_map/dataset/patches.py +++ b/mu_map/dataset/patches.py @@ -46,6 +46,8 @@ class MuMapPatchDataset(MuMapDataset): self.generate_patches() def copy(self, split_name: str): + kwargs = self.kwargs.copy() + kwargs["split_name"] = split_name return MuMapPatchDataset( dataset_dir=self.dir, patches_per_image=self.patches_per_image, @@ -53,7 +55,7 @@ class MuMapPatchDataset(MuMapDataset): patch_size_z=self.patch_size_z, patch_offset=self.patch_offset, shuffle=self.shuffle, - **self.kwargs, + **kwargs, ) def generate_patches(self):