diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py index df9f2c971b991c2278a7e4af9a6376c601af6d71..530e61abff3d658608e2a0557b2d34abf80e0036 100644 --- a/mu_map/dataset/patches.py +++ b/mu_map/dataset/patches.py @@ -8,8 +8,15 @@ from mu_map.dataset.default import MuMapDataset class MuMapPatchDataset(MuMapDataset): - def __init__(self, dataset_dir, patches_per_image=100, patch_size=32, shuffle=True): - super().__init__(dataset_dir) + def __init__( + self, + dataset_dir: str, + patches_per_image: int = 100, + patch_size: int = 32, + shuffle: bool = True, + **kwargs, + ): + super().__init__(dataset_dir, **kwargs) self.patches_per_image = patches_per_image self.patch_size = patch_size @@ -84,7 +91,9 @@ if __name__ == "__main__": _recon_orig = recon_orig[_slice] _recon_orig = to_grayscale(_recon_orig) _recon_orig = grayscale_to_rgb(_recon_orig) - _recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1) + _recon_orig = cv.rectangle( + _recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1 + ) _recon_orig = cv.resize(_recon_orig, (512, 512), cv.INTER_AREA) _recon = recon[_slice] @@ -113,7 +122,9 @@ if __name__ == "__main__": mu_map = mu_map.squeeze().numpy() recon_orig = dataset.reconstructions[_id] - recon_orig = torch.nn.functional.pad(recon_orig, padding, mode="constant", value=0) + recon_orig = torch.nn.functional.pad( + recon_orig, padding, mode="constant", value=0 + ) recon_orig = recon_orig.squeeze().numpy() cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))