From c93101a086fde8e57168453e13e2c8d4b8223b1b Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 9 Dec 2022 13:36:47 +0100
Subject: [PATCH] add script for pixel-wise polarmap evaluation

---
 mu_map/polar_map/eval_pixelwise.py | 115 +++++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)
 create mode 100644 mu_map/polar_map/eval_pixelwise.py

diff --git a/mu_map/polar_map/eval_pixelwise.py b/mu_map/polar_map/eval_pixelwise.py
new file mode 100644
index 0000000..1d9d9d0
--- /dev/null
+++ b/mu_map/polar_map/eval_pixelwise.py
@@ -0,0 +1,115 @@
+import argparse
+import os
+
+import cv2 as cv
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+
+from mu_map.polar_map.prepare import headers
+
+
+def get_circular_mark(shape: np.ndarray, channels:int=1):
+    mask = np.full((*shape, channels), 0, np.uint8)
+    cx, cy = np.array(mask.shape[:2]) // 2
+    mask = cv.circle(
+        mask,
+        center=(cx-1, cy),
+        radius=cx - 2,
+        color=(255,) * channels,
+        thickness=cv.FILLED,
+    )
+    mask = mask == 255
+    return mask[:, :, 0] if channels == 1 else mask
+
+parser = argparse.ArgumentParser(description="Visualize polar maps of different reconstructions", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument("--polar_map_dir", type=str, default="data/polar_maps", help="directory containing the polar map images")
+parser.add_argument("--images_dir", type=str, default="images", help="sub-directory under <polar_map_dir> containing the actual image files")
+parser.add_argument("--csv", type=str, default="polar_maps.csv", help="file under <polar_map_dir> containing meta information about the polar maps")
+parser.add_argument("--baseline", choices=["symbia", "stir"], default="symbia", help="select the polar map treated as the baseline")
+parser.add_argument("--relative_difference", action="store_true", help="show the difference as a percentage difference to the baseline")
+args = parser.parse_args()
+
+args.images_dir = os.path.join(args.polar_map_dir, args.images_dir)
+args.csv = os.path.join(args.polar_map_dir, args.csv)
+
+meta = pd.read_csv(args.csv)
+ids = meta[headers.id].unique()
+
+dists_abs = {"No AC": [], "AC SYN": [], "AC CT": []}
+dists_per = {"No AC": [], "AC SYN": [], "AC CT": []}
+recons_tot = {"No AC": [], "AC SYN": [], "AC CT": []}
+baseline_tot = []
+ 
+for _id in ids:
+    print(_id)
+    _meta = meta[(meta[headers.id] == _id) & ~(meta[headers.segments])]
+
+    file_recon_ac = _meta[(_meta[headers.type] == "symbia") & _meta[headers.ac]][headers.file].values[0]
+    file_recon_nac = _meta[~_meta[headers.ac]][headers.file].values[0]
+    file_recon_syn = _meta[_meta[headers.type] == "synthetic"][headers.file].values[0]
+    file_recon_ct = _meta[_meta[headers.type] == "ct"][headers.file].values[0]
+
+    recon_ac = cv.imread(os.path.join(args.images_dir, file_recon_ac), cv.IMREAD_GRAYSCALE)
+    recon_nac = cv.imread(os.path.join(args.images_dir, file_recon_nac), cv.IMREAD_GRAYSCALE)
+    recon_syn = cv.imread(os.path.join(args.images_dir, file_recon_syn), cv.IMREAD_GRAYSCALE)
+    recon_ct = cv.imread(os.path.join(args.images_dir, file_recon_ct), cv.IMREAD_GRAYSCALE)
+
+    baseline = recon_ac.copy() if args.baseline == "symbia" else recon_ct.copy()
+    recons = [recon_nac, recon_syn, recon_ct]
+    labels = ["No AC", "AC SYN", "AC CT"]
+
+    baseline = baseline * (100.0 / 255.0)
+    recons = list(map(lambda recon: recon * (100.0 / 255.0), recons))
+
+    mask = get_circular_mark(baseline.shape)
+    _baseline = baseline[mask]
+    mask2 = _baseline != 0
+    _baseline = _baseline[mask2]
+    baseline_tot.append(_baseline)
+    for recon, label in zip(recons, labels):
+        _recon = recon[mask]
+        assert _recon[np.logical_not(mask2)].sum() == 0
+        _recon = _recon[mask2]
+        recons_tot[label].append(_recon)
+
+        dist_abs = np.absolute(_recon - _baseline)
+        dist_percentage = 100 * dist_abs / _baseline
+        # print(f" - {label:>6} -   Absolute: {dist_abs.mean():.3f}+-{dist_abs.std():.3f}")
+        # print(f" - {label:>6} - Percentage: {dist_percentage.mean():.3f}+-{dist_percentage.std():.3f}")
+
+        dists_abs[label].append(dist_abs.mean())
+        dists_per[label].append(dist_percentage.mean())
+
+print("Total:")
+for label in dists_abs.keys():
+    dist_abs = np.array(dists_abs[label])
+    dist_percentage = np.array(dists_per[label])
+    print(f" - {label:>6} -   Absolute: {dist_abs.mean():.3f}+-{dist_abs.std():.3f}")
+    print(f" - {label:>6} - Percentage: {dist_percentage.mean():.3f}+-{dist_percentage.std():.3f}")
+
+
+baseline_tot = np.concatenate(baseline_tot)
+
+bins_x = np.arange(0, 100.1, 2)
+bins_y = np.arange(0, 100.1, 2)
+fig, axs = plt.subplots(2, 3, figsize=(15, 10))
+for i, (label, recon) in enumerate(recons_tot.items()):
+    recon = np.concatenate(recon)
+    h, xedges, yedges, image = axs[0, i].hist2d(baseline_tot, recon, bins=(bins_x, bins_y))
+
+    h = (h - h.min()) / (h.max() - h.min())
+    img = mpl.colormaps["viridis"](h)
+    img = cv.line(img, (0, 0), (50, 50), (1.0, 0.0, 0.0, 0.3), 1, lineType=cv.LINE_AA)
+    # img = img[::-1, ::-1]
+    # axs[1, i].imshow(img, origin="lower")
+    img = np.rot90(img)
+    axs[1, i].imshow(img)
+    axs[1, i].set_axis_off()
+
+    axs[0, i].set_title(label)
+    axs[0, i].set_xlabel("Baseline")
+    axs[0, i].set_ylabel("Prediction")
+
+plt.show()
-- 
GitLab