From 5179229d7aece80f8224ec487e004fd9e5781df4 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Mon, 10 Oct 2022 15:13:51 +0200
Subject: [PATCH] add script to measure performance after training

---
 mu_map/eval/__init__.py |  0
 mu_map/eval/measures.py | 45 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+)
 create mode 100644 mu_map/eval/__init__.py
 create mode 100644 mu_map/eval/measures.py

diff --git a/mu_map/eval/__init__.py b/mu_map/eval/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mu_map/eval/measures.py b/mu_map/eval/measures.py
new file mode 100644
index 0000000..9881b30
--- /dev/null
+++ b/mu_map/eval/measures.py
@@ -0,0 +1,45 @@
+import numpy as np
+import torch
+
+from mu_map.dataset.default import MuMapDataset
+from mu_map.dataset.normalization import MeanNormTransform
+from mu_map.models.unet import UNet
+
+def mse(prediction: np.array, target: np.array):
+    pass
+
+def nmae(prediction: np.array, target: np.array):
+    mae = np.absolute(prediction - target) / prediction.size
+    nmae = mae / (target.max() - target.min())
+    return nmae
+
+device = torch.device("cpu")
+model = UNet()
+model.load_state_dict(torch.load("xx.pth", map_location=device))
+model = model.eval()
+
+dataset = MuMapDataset("data/initial/", transform_normalization=MeanNormTransform(), split_name="validation")
+
+scores_mse = []
+scores_nmae = []
+for recon, mu_map in dataset:
+    recon = recon.unsqueeze(dim=0).to(device)
+
+    prediction = model(recon).squeeze().numpy()
+    mu_map = mu_map.squeeze().numpy()
+
+    scores_nmae.append(nmae(prediction, mu_map))
+    scores_mse.append(mse(prediction, mu_map))
+scores_mse = np.array(scores_mse)
+scores_nmae = np.array(scores_nmae)
+
+mse_avg = scores_mse.mean()
+mse_std = np.std(scores_mse)
+
+nmae_avg = scores_nmae.mean()
+nmae_std = np.std(scores_nmae)
+
+print("Scores:")
+print(f" - NMAE: {nmae_avg:.6f}±{nmae_std:.6f}")
+print(f" -  MSE: {mse_avg:.6f}±{mse_std:.6f}")
+
-- 
GitLab