From c93101a086fde8e57168453e13e2c8d4b8223b1b Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Fri, 9 Dec 2022 13:36:47 +0100 Subject: [PATCH] add script for pixel-wise polarmap evaluation --- mu_map/polar_map/eval_pixelwise.py | 115 +++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 mu_map/polar_map/eval_pixelwise.py diff --git a/mu_map/polar_map/eval_pixelwise.py b/mu_map/polar_map/eval_pixelwise.py new file mode 100644 index 0000000..1d9d9d0 --- /dev/null +++ b/mu_map/polar_map/eval_pixelwise.py @@ -0,0 +1,115 @@ +import argparse +import os + +import cv2 as cv +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from mu_map.polar_map.prepare import headers + + +def get_circular_mark(shape: np.ndarray, channels:int=1): + mask = np.full((*shape, channels), 0, np.uint8) + cx, cy = np.array(mask.shape[:2]) // 2 + mask = cv.circle( + mask, + center=(cx-1, cy), + radius=cx - 2, + color=(255,) * channels, + thickness=cv.FILLED, + ) + mask = mask == 255 + return mask[:, :, 0] if channels == 1 else mask + +parser = argparse.ArgumentParser(description="Visualize polar maps of different reconstructions", formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument("--polar_map_dir", type=str, default="data/polar_maps", help="directory containing the polar map images") +parser.add_argument("--images_dir", type=str, default="images", help="sub-directory under <polar_map_dir> containing the actual image files") +parser.add_argument("--csv", type=str, default="polar_maps.csv", help="file under <polar_map_dir> containing meta information about the polar maps") +parser.add_argument("--baseline", choices=["symbia", "stir"], default="symbia", help="select the polar map treated as the baseline") +parser.add_argument("--relative_difference", action="store_true", help="show the difference as a percentage difference to the baseline") +args = parser.parse_args() + +args.images_dir = os.path.join(args.polar_map_dir, args.images_dir) +args.csv = os.path.join(args.polar_map_dir, args.csv) + +meta = pd.read_csv(args.csv) +ids = meta[headers.id].unique() + +dists_abs = {"No AC": [], "AC SYN": [], "AC CT": []} +dists_per = {"No AC": [], "AC SYN": [], "AC CT": []} +recons_tot = {"No AC": [], "AC SYN": [], "AC CT": []} +baseline_tot = [] + +for _id in ids: + print(_id) + _meta = meta[(meta[headers.id] == _id) & ~(meta[headers.segments])] + + file_recon_ac = _meta[(_meta[headers.type] == "symbia") & _meta[headers.ac]][headers.file].values[0] + file_recon_nac = _meta[~_meta[headers.ac]][headers.file].values[0] + file_recon_syn = _meta[_meta[headers.type] == "synthetic"][headers.file].values[0] + file_recon_ct = _meta[_meta[headers.type] == "ct"][headers.file].values[0] + + recon_ac = cv.imread(os.path.join(args.images_dir, file_recon_ac), cv.IMREAD_GRAYSCALE) + recon_nac = cv.imread(os.path.join(args.images_dir, file_recon_nac), cv.IMREAD_GRAYSCALE) + recon_syn = cv.imread(os.path.join(args.images_dir, file_recon_syn), cv.IMREAD_GRAYSCALE) + recon_ct = cv.imread(os.path.join(args.images_dir, file_recon_ct), cv.IMREAD_GRAYSCALE) + + baseline = recon_ac.copy() if args.baseline == "symbia" else recon_ct.copy() + recons = [recon_nac, recon_syn, recon_ct] + labels = ["No AC", "AC SYN", "AC CT"] + + baseline = baseline * (100.0 / 255.0) + recons = list(map(lambda recon: recon * (100.0 / 255.0), recons)) + + mask = get_circular_mark(baseline.shape) + _baseline = baseline[mask] + mask2 = _baseline != 0 + _baseline = _baseline[mask2] + baseline_tot.append(_baseline) + for recon, label in zip(recons, labels): + _recon = recon[mask] + assert _recon[np.logical_not(mask2)].sum() == 0 + _recon = _recon[mask2] + recons_tot[label].append(_recon) + + dist_abs = np.absolute(_recon - _baseline) + dist_percentage = 100 * dist_abs / _baseline + # print(f" - {label:>6} - Absolute: {dist_abs.mean():.3f}+-{dist_abs.std():.3f}") + # print(f" - {label:>6} - Percentage: {dist_percentage.mean():.3f}+-{dist_percentage.std():.3f}") + + dists_abs[label].append(dist_abs.mean()) + dists_per[label].append(dist_percentage.mean()) + +print("Total:") +for label in dists_abs.keys(): + dist_abs = np.array(dists_abs[label]) + dist_percentage = np.array(dists_per[label]) + print(f" - {label:>6} - Absolute: {dist_abs.mean():.3f}+-{dist_abs.std():.3f}") + print(f" - {label:>6} - Percentage: {dist_percentage.mean():.3f}+-{dist_percentage.std():.3f}") + + +baseline_tot = np.concatenate(baseline_tot) + +bins_x = np.arange(0, 100.1, 2) +bins_y = np.arange(0, 100.1, 2) +fig, axs = plt.subplots(2, 3, figsize=(15, 10)) +for i, (label, recon) in enumerate(recons_tot.items()): + recon = np.concatenate(recon) + h, xedges, yedges, image = axs[0, i].hist2d(baseline_tot, recon, bins=(bins_x, bins_y)) + + h = (h - h.min()) / (h.max() - h.min()) + img = mpl.colormaps["viridis"](h) + img = cv.line(img, (0, 0), (50, 50), (1.0, 0.0, 0.0, 0.3), 1, lineType=cv.LINE_AA) + # img = img[::-1, ::-1] + # axs[1, i].imshow(img, origin="lower") + img = np.rot90(img) + axs[1, i].imshow(img) + axs[1, i].set_axis_off() + + axs[0, i].set_title(label) + axs[0, i].set_xlabel("Baseline") + axs[0, i].set_ylabel("Prediction") + +plt.show() -- GitLab