From 5179229d7aece80f8224ec487e004fd9e5781df4 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Mon, 10 Oct 2022 15:13:51 +0200 Subject: [PATCH] add script to measure performance after training --- mu_map/eval/__init__.py | 0 mu_map/eval/measures.py | 45 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 mu_map/eval/__init__.py create mode 100644 mu_map/eval/measures.py diff --git a/mu_map/eval/__init__.py b/mu_map/eval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mu_map/eval/measures.py b/mu_map/eval/measures.py new file mode 100644 index 0000000..9881b30 --- /dev/null +++ b/mu_map/eval/measures.py @@ -0,0 +1,45 @@ +import numpy as np +import torch + +from mu_map.dataset.default import MuMapDataset +from mu_map.dataset.normalization import MeanNormTransform +from mu_map.models.unet import UNet + +def mse(prediction: np.array, target: np.array): + pass + +def nmae(prediction: np.array, target: np.array): + mae = np.absolute(prediction - target) / prediction.size + nmae = mae / (target.max() - target.min()) + return nmae + +device = torch.device("cpu") +model = UNet() +model.load_state_dict(torch.load("xx.pth", map_location=device)) +model = model.eval() + +dataset = MuMapDataset("data/initial/", transform_normalization=MeanNormTransform(), split_name="validation") + +scores_mse = [] +scores_nmae = [] +for recon, mu_map in dataset: + recon = recon.unsqueeze(dim=0).to(device) + + prediction = model(recon).squeeze().numpy() + mu_map = mu_map.squeeze().numpy() + + scores_nmae.append(nmae(prediction, mu_map)) + scores_mse.append(mse(prediction, mu_map)) +scores_mse = np.array(scores_mse) +scores_nmae = np.array(scores_nmae) + +mse_avg = scores_mse.mean() +mse_std = np.std(scores_mse) + +nmae_avg = scores_nmae.mean() +nmae_std = np.std(scores_nmae) + +print("Scores:") +print(f" - NMAE: {nmae_avg:.6f}±{nmae_std:.6f}") +print(f" - MSE: {mse_avg:.6f}±{mse_std:.6f}") + -- GitLab