"""Compare polar maps of different reconstructions against a baseline.

For every id listed in the meta CSV this script loads the "No AC", "AC SYN"
and "AC CT" polar map images, compares them pixel-wise against a baseline
polar map inside a circular mask, prints mean absolute and percentage
differences, and shows 2D histograms of baseline vs. prediction values.
"""
import argparse
import os
from typing import Tuple

import cv2 as cv
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mu_map.polar_map.prepare import headers

# Reconstructions compared against the baseline, in display order.
LABELS = ["No AC", "AC SYN", "AC CT"]

# Maps 8-bit gray values in [0, 255] linearly onto [0, 100].
_INTENSITY_SCALE = 100.0 / 255.0


def get_circular_mark(shape: Tuple[int, ...], channels: int = 1) -> np.ndarray:
    """Return a boolean mask of a filled circle centred in an image of ``shape``.

    Args:
        shape: spatial shape ``(rows, cols)`` of the image to be masked,
            e.g. ``image.shape`` of a grayscale image.
        channels: number of channels of the mask.

    Returns:
        A boolean array of shape ``shape`` if ``channels == 1``, otherwise of
        shape ``(*shape, channels)``; ``True`` inside the circle.
    """
    mask = np.full((*shape, channels), 0, np.uint8)
    # NOTE(review): cx/cy come from (rows, cols) but are passed to OpenCV as an
    # (x, y) point and the radius only uses cx, so this assumes a square image.
    # The -1 / -2 offsets look hand-tuned to the polar map rendering - confirm.
    cx, cy = np.array(mask.shape[:2]) // 2
    mask = cv.circle(
        mask,
        center=(cx - 1, cy),
        radius=cx - 2,
        color=(255,) * channels,
        thickness=cv.FILLED,
    )
    mask = mask == 255
    return mask[:, :, 0] if channels == 1 else mask


def _build_parser() -> argparse.ArgumentParser:
    """Create the command line parser of this script."""
    parser = argparse.ArgumentParser(
        description="Visualize polar maps of different reconstructions",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--polar_map_dir",
        type=str,
        default="data/polar_maps",
        help="directory containing the polar map images",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory under <polar_map_dir> containing the actual image files",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="polar_maps.csv",
        help="file under <polar_map_dir> containing meta information about the polar maps",
    )
    parser.add_argument(
        "--baseline",
        choices=["symbia", "stir"],
        default="symbia",
        help="select the polar map treated as the baseline",
    )
    # NOTE(review): this flag is parsed but not read anywhere in this script;
    # it is kept so existing invocations keep working.
    parser.add_argument(
        "--relative_difference",
        action="store_true",
        help="show the difference as a percentage difference to the baseline",
    )
    return parser


def _first_file(rows: pd.DataFrame, what: str, _id) -> str:
    """Return the file name stored in the first row of ``rows``.

    Raises:
        ValueError: if ``rows`` is empty, naming the missing polar map.
    """
    files = rows[headers.file].values
    if len(files) == 0:
        raise ValueError(f"No {what} polar map found for id {_id}")
    return files[0]


def _load_grayscale(images_dir: str, filename: str) -> np.ndarray:
    """Load ``filename`` from ``images_dir`` as a grayscale image.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv.imread``
            signals failure by returning ``None`` instead of raising).
    """
    path = os.path.join(images_dir, filename)
    image = cv.imread(path, cv.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f"Could not read polar map image {path}")
    return image


def _min_max_normalize(values: np.ndarray) -> np.ndarray:
    """Scale ``values`` linearly onto [0, 1]; a constant array maps to zeros."""
    value_range = values.max() - values.min()
    if value_range == 0:
        # Avoid 0 / 0, which would turn the whole image into NaN.
        return np.zeros_like(values, dtype=float)
    return (values - values.min()) / value_range


def main() -> None:
    """Compute, print and plot the differences of all polar maps in the CSV."""
    args = _build_parser().parse_args()
    args.images_dir = os.path.join(args.polar_map_dir, args.images_dir)
    args.csv = os.path.join(args.polar_map_dir, args.csv)

    meta = pd.read_csv(args.csv)
    ids = meta[headers.id].unique()
    if len(ids) == 0:
        raise SystemExit(f"No polar maps listed in {args.csv}")

    # Per-id mean distances and all compared pixel values, keyed by label.
    dists_abs = {label: [] for label in LABELS}
    dists_per = {label: [] for label in LABELS}
    recons_tot = {label: [] for label in LABELS}
    baseline_tot = []

    for _id in ids:
        print(_id)
        _meta = meta[(meta[headers.id] == _id) & ~(meta[headers.segments])]

        file_recon_ac = _first_file(
            _meta[(_meta[headers.type] == "symbia") & _meta[headers.ac]], "symbia AC", _id
        )
        file_recon_nac = _first_file(_meta[~_meta[headers.ac]], "non-AC", _id)
        file_recon_syn = _first_file(_meta[_meta[headers.type] == "synthetic"], "synthetic", _id)
        file_recon_ct = _first_file(_meta[_meta[headers.type] == "ct"], "ct", _id)

        recon_ac = _load_grayscale(args.images_dir, file_recon_ac)
        recon_nac = _load_grayscale(args.images_dir, file_recon_nac)
        recon_syn = _load_grayscale(args.images_dir, file_recon_syn)
        recon_ct = _load_grayscale(args.images_dir, file_recon_ct)

        # Multiplying by a float yields new float arrays, so the loaded images
        # are never modified in place.
        baseline = (recon_ac if args.baseline == "symbia" else recon_ct) * _INTENSITY_SCALE
        recons = [recon * _INTENSITY_SCALE for recon in (recon_nac, recon_syn, recon_ct)]

        mask = get_circular_mark(baseline.shape)
        _baseline = baseline[mask]
        # Only compare pixels where the baseline is non-zero; this also keeps
        # the percentage difference below finite.
        nonzero = _baseline != 0
        _baseline = _baseline[nonzero]
        baseline_tot.append(_baseline)

        for recon, label in zip(recons, LABELS):
            _recon = recon[mask]
            # Check that the reconstruction is zero wherever the baseline is.
            if _recon[~nonzero].sum() != 0:
                raise ValueError(
                    f"{label} polar map of id {_id} is non-zero where the baseline is zero"
                )
            _recon = _recon[nonzero]
            recons_tot[label].append(_recon)

            dist_abs = np.absolute(_recon - _baseline)
            dist_percentage = 100 * dist_abs / _baseline
            dists_abs[label].append(dist_abs.mean())
            dists_per[label].append(dist_percentage.mean())

    print("Total:")
    for label in dists_abs:
        dist_abs = np.array(dists_abs[label])
        dist_percentage = np.array(dists_per[label])
        print(f" - {label:>6} - Absolute: {dist_abs.mean():.3f}+-{dist_abs.std():.3f}")
        print(f" - {label:>6} - Percentage: {dist_percentage.mean():.3f}+-{dist_percentage.std():.3f}")

    baseline_tot = np.concatenate(baseline_tot)
    bins_x = np.arange(0, 100.1, 2)
    bins_y = np.arange(0, 100.1, 2)
    _, axs = plt.subplots(2, 3, figsize=(15, 10))
    for i, (label, recon) in enumerate(recons_tot.items()):
        recon = np.concatenate(recon)

        # Top row: 2D histogram of baseline (x) vs. prediction (y) values.
        h, *_ = axs[0, i].hist2d(baseline_tot, recon, bins=(bins_x, bins_y))

        # Bottom row: the same histogram rendered as an RGBA image with a
        # diagonal (x_bin == y_bin) line drawn on it.
        img = mpl.colormaps["viridis"](_min_max_normalize(h))
        img = cv.line(img, (0, 0), (50, 50), (1.0, 0.0, 0.0, 0.3), 1, lineType=cv.LINE_AA)
        # hist2d indexes counts as h[x_bin, y_bin]; rotating by 90 degrees puts
        # x on the horizontal axis and increasing y upwards in imshow.
        img = np.rot90(img)
        axs[1, i].imshow(img)
        axs[1, i].set_axis_off()

        axs[0, i].set_title(label)
        axs[0, i].set_xlabel("Baseline")
        axs[0, i].set_ylabel("Prediction")
    plt.show()


if __name__ == "__main__":
    main()