Skip to content
Snippets Groups Projects
Commit 5179229d authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

add script to measure performance after training

parent 4b3bf472
No related branches found
No related tags found
No related merge requests found
import numpy as np
import torch
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import MeanNormTransform
from mu_map.models.unet import UNet
def mse(prediction: np.array, target: np.array):
pass
def nmae(prediction: np.array, target: np.array):
mae = np.absolute(prediction - target) / prediction.size
nmae = mae / (target.max() - target.min())
return nmae
device = torch.device("cpu")
model = UNet()
model.load_state_dict(torch.load("xx.pth", map_location=device))
model = model.eval()
dataset = MuMapDataset("data/initial/", transform_normalization=MeanNormTransform(), split_name="validation")
scores_mse = []
scores_nmae = []
for recon, mu_map in dataset:
recon = recon.unsqueeze(dim=0).to(device)
prediction = model(recon).squeeze().numpy()
mu_map = mu_map.squeeze().numpy()
scores_nmae.append(nmae(prediction, mu_map))
scores_mse.append(mse(prediction, mu_map))
scores_mse = np.array(scores_mse)
scores_nmae = np.array(scores_nmae)
mse_avg = scores_mse.mean()
mse_std = np.std(scores_mse)
nmae_avg = scores_nmae.mean()
nmae_std = np.std(scores_nmae)
print("Scores:")
print(f" - NMAE: {nmae_avg:.6f}±{nmae_std:.6f}")
print(f" - MSE: {mse_avg:.6f}±{mse_std:.6f}")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment