Skip to content
Snippets Groups Projects
Commit 88657e0a authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

add usage of new transforms to eval script

parent 9daaf276
No related branches found
No related tags found
No related merge requests found
import numpy as np
import pandas as pd
import torch
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import MeanNormTransform
from mu_map.dataset.transform import SequenceTransform, PadCropTranform
from mu_map.models.unet import UNet
torch.set_grad_enabled(False)
......@@ -22,22 +24,8 @@ model = UNet()
model.load_state_dict(torch.load("trainings/01_cgan/snapshots/050_generator.pth", map_location=device))
model = model.eval()
dataset = MuMapDataset("data/initial/", transform_normalization=MeanNormTransform(), split_name="validation")
def crop_or_pad(inputs: torch.Tensor, target_shape=32):
# crop
if inputs.shape[1] > target_shape:
c = inputs.shape[1] // 2
diff = target_shape // 2
inputs = inputs[:, c-diff:c+diff]
elif inputs.shape[1] < target_shape:
diff = target_shape - inputs.shape[1]
padding_front = math.ceil(diff / 2)
padding_back = math.floor(diff / 2)
padding = [0, 0, 0, 0, padding_front, padding_back]
inputs = torch.nn.functional.pad(inputs, padding, mode="constant", value=0)
return inputs
transform_normalization = SequenceTransform(transforms=[MeanNormTransform(), PadCropTranform(dim=3, size=32)])
dataset = MuMapDataset("data/initial/", transform_normalization=transform_normalization, split_name="validation")
scores_mse = []
scores_nmae = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment