diff --git a/mu_map/data/patch_dataset.py b/mu_map/data/patch_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..46caba3af8e969e7f5eaa3af7b484064eee6bf11 --- /dev/null +++ b/mu_map/data/patch_dataset.py @@ -0,0 +1,21 @@ +from mu_map.data.datasets import MuMapDataset + +class MuMapPatchDataset(MuMapDataset): + + def __init__(self, dataset_dir, patches_per_image=100, patch_size=32): + super().__init__(dataset_dir) + + self.patches_per_image=patches_per_image + self.patch_size=patch_size + + def __getitem___(self, index:int): + return super()[index] + + def __len__(self): + return super().__len__() * self.patches_per_image + + +if __name__ == "__main__": + dataset = MuMapPatchDataset("data/initial/") + + print(f"Images {len(dataset)}")