diff --git a/mu_map/eval/measures.py b/mu_map/eval/measures.py index bf1d401f9c2546e3bf812deb2508659d5524928f..c20629560ea5d30dc971765f55454e76287c3072 100644 --- a/mu_map/eval/measures.py +++ b/mu_map/eval/measures.py @@ -1,8 +1,10 @@ import numpy as np +import pandas as pd import torch from mu_map.dataset.default import MuMapDataset from mu_map.dataset.normalization import MeanNormTransform +from mu_map.dataset.transform import SequenceTransform, PadCropTranform from mu_map.models.unet import UNet torch.set_grad_enabled(False) @@ -22,22 +24,8 @@ model = UNet() model.load_state_dict(torch.load("trainings/01_cgan/snapshots/050_generator.pth", map_location=device)) model = model.eval() -dataset = MuMapDataset("data/initial/", transform_normalization=MeanNormTransform(), split_name="validation") - -def crop_or_pad(inputs: torch.Tensor, target_shape=32): - # crop - if inputs.shape[1] > target_shape: - c = inputs.shape[1] // 2 - diff = target_shape // 2 - inputs = inputs[:, c-diff:c+diff] - elif inputs.shape[1] < target_shape: - diff = target_shape - inputs.shape[1] - padding_front = math.ceil(diff / 2) - padding_back = math.floor(diff / 2) - padding = [0, 0, 0, 0, padding_front, padding_back] - inputs = torch.nn.functional.pad(inputs, padding, mode="constant", value=0) - return inputs - +transform_normalization = SequenceTransform(transforms=[MeanNormTransform(), PadCropTranform(dim=3, size=32)]) +dataset = MuMapDataset("data/initial/", transform_normalization=transform_normalization, split_name="validation") scores_mse = [] scores_nmae = []