Skip to content
Snippets Groups Projects
Commit 4527c73c authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

update test script and remove usage of mock dataset

parent bf09641f
No related branches found
No related tags found
No related merge requests found
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import MaxNormTransform
class MuMapMockDataset(MuMapDataset):
def __init__(self, dataset_dir: str = "data/initial/", num_images: int = 16, logger=None):
super().__init__(dataset_dir=dataset_dir, transform_normalization=MaxNormTransform(), logger=logger)
self.len = num_images
def __getitem__(self, index: int):
recon, mu_map = super().__getitem__(0)
recon = recon[:, :32, :, :]
mu_map = mu_map[:, :32, :, :]
return recon, mu_map
def __len__(self):
return self.len
if __name__ == "__main__":
import cv2 as cv
import numpy as np
from mu_map.dataset.default import main
dataset = MuMapMockDataset()
main(dataset)
...@@ -4,29 +4,31 @@ import numpy as np ...@@ -4,29 +4,31 @@ import numpy as np
import torch import torch
from mu_map.dataset.default import MuMapDataset from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.mock import MuMapMockDataset from mu_map.dataset.normalization import MeanNormTransform
from mu_map.dataset.normalization import norm_max, norm_gaussian from mu_map.dataset.transform import PadCropTranform, SequenceTransform
from mu_map.models.unet import UNet from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE from mu_map.util import to_grayscale, COLOR_WHITE
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
dataset = MuMapMockDataset("data/second/") dataset = MuMapDataset(
"data/second/",
transform_normalization=SequenceTransform([
MeanNormTransform(),
PadCropTranform(dim=3, size=32)
]),
)
# model = UNet(in_channels=1, features=[8, 16])
model = UNet(in_channels=1) model = UNet(in_channels=1)
device = torch.device("cpu") device = torch.device("cpu")
weights = torch.load("train_data/snapshots/val_min_Model.pth", map_location=device) # weights = torch.load("train_data/snapshots/val_min_Model.pth", map_location=device)
model.load_state_dict(weights) # model.load_state_dict(weights)
model = model.eval() model = model.eval()
recon, mu_map = dataset[0] recon, mu_map = dataset[0]
recon = recon.unsqueeze(dim=0) recon = recon.unsqueeze(dim=0)
# recon = norm_max(recon)
recon = norm_gaussian(recon)
output = model(recon) output = model(recon)
# output = output * 40206.0
diff = ((mu_map - output) ** 2).mean() diff = ((mu_map - output) ** 2).mean()
print(f"Diff: {diff:.5f}") print(f"Diff: {diff:.5f}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment