Skip to content
Snippets Groups Projects
Commit 0daff4e2 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

fix eval measure script

parent 5179229d
No related merge requests found
...@@ -5,24 +5,46 @@ from mu_map.dataset.default import MuMapDataset ...@@ -5,24 +5,46 @@ from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import MeanNormTransform from mu_map.dataset.normalization import MeanNormTransform
from mu_map.models.unet import UNet from mu_map.models.unet import UNet
torch.set_grad_enabled(False)
def mse(prediction: np.array, target: np.array): def mse(prediction: np.array, target: np.array):
pass se = (prediction - target) ** 2
mse = se.sum() / se.size
return mse
def nmae(prediction: np.array, target: np.array): def nmae(prediction: np.array, target: np.array):
mae = np.absolute(prediction - target) / prediction.size mae = np.absolute(prediction - target) / prediction.size
nmae = mae / (target.max() - target.min()) nmae = mae.sum() / (target.max() - target.min())
return nmae return nmae
device = torch.device("cpu") device = torch.device("cpu")
model = UNet() model = UNet()
model.load_state_dict(torch.load("xx.pth", map_location=device)) model.load_state_dict(torch.load("trainings/01_cgan/snapshots/050_generator.pth", map_location=device))
model = model.eval() model = model.eval()
dataset = MuMapDataset("data/initial/", transform_normalization=MeanNormTransform(), split_name="validation") dataset = MuMapDataset("data/initial/", transform_normalization=MeanNormTransform(), split_name="validation")
def crop_or_pad(inputs: torch.Tensor, target_shape=32):
# crop
if inputs.shape[1] > target_shape:
c = inputs.shape[1] // 2
diff = target_shape // 2
inputs = inputs[:, c-diff:c+diff]
elif inputs.shape[1] < target_shape:
diff = target_shape - inputs.shape[1]
padding_front = math.ceil(diff / 2)
padding_back = math.floor(diff / 2)
padding = [0, 0, 0, 0, padding_front, padding_back]
inputs = torch.nn.functional.pad(inputs, padding, mode="constant", value=0)
return inputs
scores_mse = [] scores_mse = []
scores_nmae = [] scores_nmae = []
for recon, mu_map in dataset: for i, (recon, mu_map) in enumerate(dataset):
print(f"{i:02d}/{len(dataset)}", end="\r")
mu_map = crop_or_pad(mu_map)
recon = crop_or_pad(recon)
recon = recon.unsqueeze(dim=0).to(device) recon = recon.unsqueeze(dim=0).to(device)
prediction = model(recon).squeeze().numpy() prediction = model(recon).squeeze().numpy()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment