diff --git a/mu_map/dataset/mock.py b/mu_map/dataset/mock.py deleted file mode 100644 index 9f33b3c900f0a85b053a2e7efabedb04de4e783b..0000000000000000000000000000000000000000 --- a/mu_map/dataset/mock.py +++ /dev/null @@ -1,27 +0,0 @@ -from mu_map.dataset.default import MuMapDataset -from mu_map.dataset.normalization import MaxNormTransform - - -class MuMapMockDataset(MuMapDataset): - def __init__(self, dataset_dir: str = "data/initial/", num_images: int = 16, logger=None): - super().__init__(dataset_dir=dataset_dir, transform_normalization=MaxNormTransform(), logger=logger) - self.len = num_images - - def __getitem__(self, index: int): - recon, mu_map = super().__getitem__(0) - recon = recon[:, :32, :, :] - mu_map = mu_map[:, :32, :, :] - return recon, mu_map - - def __len__(self): - return self.len - - -if __name__ == "__main__": - import cv2 as cv - import numpy as np - - from mu_map.dataset.default import main - - dataset = MuMapMockDataset() - main(dataset) diff --git a/mu_map/test.py b/mu_map/test.py index f761307eb19eb2f460c71e73705bd009a3b6a01d..84dcf684f7ea4c0a3e8403318a37ee2b6dd6c467 100644 --- a/mu_map/test.py +++ b/mu_map/test.py @@ -4,29 +4,31 @@ import numpy as np import torch from mu_map.dataset.default import MuMapDataset -from mu_map.dataset.mock import MuMapMockDataset -from mu_map.dataset.normalization import norm_max, norm_gaussian +from mu_map.dataset.normalization import MeanNormTransform +from mu_map.dataset.transform import PadCropTranform, SequenceTransform from mu_map.models.unet import UNet from mu_map.util import to_grayscale, COLOR_WHITE torch.set_grad_enabled(False) -dataset = MuMapMockDataset("data/second/") +dataset = MuMapDataset( + "data/second/", + transform_normalization=SequenceTransform([ + MeanNormTransform(), + PadCropTranform(dim=3, size=32) + ]), +) -# model = UNet(in_channels=1, features=[8, 16]) model = UNet(in_channels=1) device = torch.device("cpu") -weights = torch.load("train_data/snapshots/val_min_Model.pth", map_location=device) -model.load_state_dict(weights) +# weights = torch.load("train_data/snapshots/val_min_Model.pth", map_location=device) +# model.load_state_dict(weights) model = model.eval() recon, mu_map = dataset[0] recon = recon.unsqueeze(dim=0) -# recon = norm_max(recon) -recon = norm_gaussian(recon) output = model(recon) -# output = output * 40206.0 diff = ((mu_map - output) ** 2).mean() print(f"Diff: {diff:.5f}")