"""Visualize polar maps of different reconstructions and their differences.

For each study id found in the polar map CSV (or only the study selected with
``--id``) a 2x4 figure is shown:

* top row: the baseline polar map followed by the polar maps labelled
  "No AC", "AC SYN" and "AC CT",
* bottom row (columns 2-4): the difference of each of those three polar maps
  to the baseline, drawn with a diverging colormap whose value range is
  symmetric around zero.
"""
import argparse
import os

import cv2 as cv
import matplotlib as mlp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mu_map.polar_map.prepare import headers


def get_circular_mark(
    shape: tuple, channels: int = 1, center_offset: tuple = (-1, 0)
) -> np.ndarray:
    """Create a boolean mask that is ``True`` inside a filled circle.

    The circle is placed at the image centre shifted by ``center_offset``.
    Its radius is two pixels smaller than half of the shorter image side.

    Args:
        shape: image shape; only the first two entries (rows, columns) are used.
        channels: number of channels of the returned mask.
        center_offset: ``(dx, dy)`` pixel offset added to the image centre.
            The default ``(-1, 0)`` keeps the placement of earlier versions.

    Returns:
        A ``(rows, cols)`` boolean array if ``channels == 1``, otherwise a
        ``(rows, cols, channels)`` boolean array.
    """
    rows, cols = int(shape[0]), int(shape[1])
    mask = np.zeros((rows, cols, channels), np.uint8)
    # cv.circle expects the centre as (x, y) == (column, row); using the
    # column count for x keeps the circle centred for non-square images too.
    center = (
        int(cols // 2 + center_offset[0]),
        int(rows // 2 + center_offset[1]),
    )
    mask = cv.circle(
        mask,
        center=center,
        radius=min(rows, cols) // 2 - 2,
        color=(255,) * channels,
        thickness=cv.FILLED,
    )
    mask = mask == 255
    return mask[:, :, 0] if channels == 1 else mask


def _build_parser() -> argparse.ArgumentParser:
    """Create the command line parser of this script."""
    parser = argparse.ArgumentParser(
        description="Visualize polar maps of different reconstructions",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--polar_map_dir",
        type=str,
        default="data/polar_maps",
        help="directory containing the polar map images",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory under <polar_map_dir> containing the actual image files",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="polar_maps.csv",
        help="file under <polar_map_dir> containing meta information about the polar maps",
    )
    parser.add_argument(
        "--baseline",
        choices=["symbia", "stir"],
        default="symbia",
        help="select the polar map treated as the baseline",
    )
    parser.add_argument(
        "--id", type=int, help="select a specific study to show by its id"
    )
    parser.add_argument(
        "--relative_difference",
        action="store_true",
        help="show the difference as a percentage difference to the baseline",
    )
    return parser


def _select_file(
    study_meta: pd.DataFrame, selection: pd.Series, description: str, study_id
) -> str:
    """Return the file name of the first row of ``study_meta`` matched by ``selection``.

    Raises:
        ValueError: if no row matches, instead of a bare ``IndexError``.
    """
    files = study_meta[selection][headers.file].values
    if len(files) == 0:
        raise ValueError(f"No {description} polar map found for id {study_id}")
    return files[0]


def _load_polar_map(images_dir: str, filename: str) -> np.ndarray:
    """Read a polar map image from ``images_dir`` as a grayscale array.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv.imread``
            signals this by returning ``None`` rather than raising).
    """
    path = os.path.join(images_dir, filename)
    image = cv.imread(path, cv.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f"Could not read polar map image {path}")
    return image


def _compute_differences(recons: list, baseline: np.ndarray, relative: bool) -> tuple:
    """Compute the differences of ``recons`` to ``baseline``, normalized for display.

    Without ``relative`` a difference is expressed in percent of the value
    255 (assumes 8-bit polar map images). With ``relative`` it is expressed in
    percent of the baseline value; pixels where the baseline is 0 are divided
    by 1 instead to avoid a division by zero.

    Returns:
        ``(normalized, limit)`` where ``normalized`` is a list of arrays mapped
        from ``[-limit, limit]`` to ``[0, 1]`` (0 difference -> 0.5) and
        ``limit`` is the largest absolute difference over all ``recons``.
    """
    raw = [recon.astype(float) - baseline for recon in recons]
    if relative:
        divider = np.where(baseline > 0, baseline, 1)
        diffs = [100.0 * diff / divider for diff in raw]
    else:
        diffs = [diff * 100.0 / 255.0 for diff in raw]

    # Use one symmetric range for all differences so that 0 maps to the
    # centre of the diverging colormap and all panels share one colorbar.
    limit = max(float(np.abs(diff).max()) for diff in diffs)
    if limit == 0:
        # All reconstructions equal the baseline: avoid 0 / 0 (NaN images).
        limit = 1.0
    normalized = [(diff + limit) / (2.0 * limit) for diff in diffs]
    return normalized, limit


def _colormap_and_mask(img: np.ndarray, cmap, mask: np.ndarray) -> np.ndarray:
    """Apply ``cmap`` to ``img`` and set all RGBA values outside ``mask`` to 0."""
    return np.where(mask, cmap(img), 0.0)


def _plot_study(
    baseline: np.ndarray,
    baseline_title: str,
    recons: list,
    labels: list,
    diffs: list,
    limit: float,
) -> None:
    """Show the baseline, the reconstructions and their differences in one figure."""
    fig, axs = plt.subplots(2, 4, figsize=(16, 8))
    for ax in axs.flatten():
        ax.set_axis_off()

    mask = get_circular_mark(baseline.shape, channels=4, center_offset=(-1, 1))
    cm_plasma = mlp.colormaps["plasma"]
    cm_divergent = mlp.colormaps["RdYlBu"].reversed()

    axs[0, 0].imshow(_colormap_and_mask(baseline, cm_plasma, mask))
    axs[0, 0].set_title(baseline_title)
    for column, (recon, diff, label) in enumerate(zip(recons, diffs, labels), start=1):
        axs[0, column].imshow(_colormap_and_mask(recon, cm_plasma, mask))
        axs[0, column].set_title(label)
        axs[1, column].imshow(_colormap_and_mask(diff, cm_divergent, mask))

    fig.colorbar(
        mlp.cm.ScalarMappable(
            norm=mlp.colors.Normalize(vmin=-limit, vmax=limit), cmap=cm_divergent
        ),
        ax=axs[1, 1:4],
    )
    # The top row colorbar is labelled 0-100 for the 8-bit intensity range
    # (assumes 8-bit polar map images).
    fig.colorbar(
        mlp.cm.ScalarMappable(
            norm=mlp.colors.Normalize(vmin=0, vmax=100), cmap=cm_plasma
        ),
        ax=axs[0, 1:4],
    )
    plt.show()


def _show_study(
    meta: pd.DataFrame,
    study_id,
    images_dir: str,
    baseline_name: str,
    relative_difference: bool,
) -> None:
    """Load the polar maps of one study and display them against the baseline."""
    study_meta = meta[(meta[headers.id] == study_id) & ~(meta[headers.segments])]

    file_ac = _select_file(
        study_meta,
        (study_meta[headers.type] == "symbia") & study_meta[headers.ac],
        "symbia AC",
        study_id,
    )
    file_nac = _select_file(study_meta, ~study_meta[headers.ac], "non-AC", study_id)
    file_syn = _select_file(
        study_meta, study_meta[headers.type] == "synthetic", "synthetic", study_id
    )
    file_ct = _select_file(study_meta, study_meta[headers.type] == "ct", "ct", study_id)

    recon_ac = _load_polar_map(images_dir, file_ac)
    recon_nac = _load_polar_map(images_dir, file_nac)
    recon_syn = _load_polar_map(images_dir, file_syn)
    recon_ct = _load_polar_map(images_dir, file_ct)

    baseline = recon_ac if baseline_name == "symbia" else recon_ct
    recons = [recon_nac, recon_syn, recon_ct]
    labels = ["No AC", "AC SYN", "AC CT"]

    diffs, limit = _compute_differences(recons, baseline, relative_difference)
    _plot_study(
        baseline,
        "AC" if baseline_name == "symbia" else "AC - CT",
        recons,
        labels,
        diffs,
        limit,
    )


def main() -> None:
    """Parse the command line and show one figure per selected study."""
    parser = _build_parser()
    args = parser.parse_args()
    images_dir = os.path.join(args.polar_map_dir, args.images_dir)
    csv_file = os.path.join(args.polar_map_dir, args.csv)

    meta = pd.read_csv(csv_file)
    ids = meta[headers.id].unique()
    # Compare against None explicitly: an id of 0 is valid but falsy.
    if args.id is not None:
        if args.id not in ids:
            parser.error(f"Id {args.id} is not available. Choose one of {ids}.")
        ids = [args.id]

    for study_id in ids:
        print(f"Show id {study_id:03d}")
        _show_study(
            meta, study_id, images_dir, args.baseline, args.relative_difference
        )


if __name__ == "__main__":
    main()