diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py
index 248a0a51832880c4cbd0e7a14a067023e7635f5c..b4d0b1fb2e99504698ff800128436601eadc33d3 100644
--- a/mu_map/dataset/patches.py
+++ b/mu_map/dataset/patches.py
@@ -46,6 +46,8 @@ class MuMapPatchDataset(MuMapDataset):
         self.generate_patches()
 
     def copy(self, split_name: str):
+        kwargs = self.kwargs.copy()
+        kwargs["split_name"] = split_name
         return MuMapPatchDataset(
             dataset_dir=self.dir,
             patches_per_image=self.patches_per_image,
@@ -53,7 +55,7 @@ class MuMapPatchDataset(MuMapDataset):
             patch_size_z=self.patch_size_z,
             patch_offset=self.patch_offset,
             shuffle=self.shuffle,
-            **self.kwargs,
+            **kwargs,
         )
 
     def generate_patches(self):