Skip to content
Snippets Groups Projects
visualize.py 4.66 KiB
Newer Older
  • Learn to ignore specific revisions
  • import argparse
    import os
    
    import cv2 as cv
    import matplotlib as mlp
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    
    from mu_map.polar_map.prepare import headers
    
    
    def get_circular_mark(shape: tuple, channels: int = 1):
        """Create a boolean mask of a filled circle centred in an image.

        Despite its name, the result is a mask: ``True`` inside the circle and
        ``False`` outside.

        Args:
            shape: ``(rows, cols)`` of the image the mask is meant for.
            channels: number of channels of the mask.

        Returns:
            A boolean array of shape ``(rows, cols)`` if ``channels == 1``,
            otherwise ``(rows, cols, channels)``.
        """
        rows, cols = int(shape[0]), int(shape[1])
        mask = np.zeros((rows, cols, channels), np.uint8)
        # OpenCV expects the centre as (x, y) = (column, row). Plain ints are
        # passed so that OpenCV's argument parsing accepts them.
        center = (cols // 2 - 1, rows // 2)
        # Keep the circle inside the image even if it is not square.
        radius = min(rows, cols) // 2 - 2
        mask = cv.circle(
            mask,
            center=center,
            radius=radius,
            color=(255,) * channels,
            thickness=cv.FILLED,
        )
        mask = mask == 255
        return mask[:, :, 0] if channels == 1 else mask
    
    
    # Command line interface; every option shows its default value in --help.
    parser = argparse.ArgumentParser(
        description="Visualize polar maps of different reconstructions",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Input locations. The image directory and the CSV file are given relative
    # to the polar map directory.
    parser.add_argument(
        "--polar_map_dir",
        type=str,
        default="data/polar_maps",
        help="directory containing the polar map images",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory under <polar_map_dir> containing the actual image files",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="polar_maps.csv",
        help="file under <polar_map_dir> containing meta information about the polar maps",
    )
    # Display options.
    parser.add_argument(
        "--baseline",
        choices=["symbia", "stir"],
        default="symbia",
        help="select the polar map treated as the baseline",
    )
    parser.add_argument(
        "--id",
        type=int,
        help="select a specific study to show by its id",
    )
    parser.add_argument(
        "--relative_difference",
        action="store_true",
        help="show the difference as a percentage difference to the baseline",
    )
    args = parser.parse_args()
    
    # Resolve the image directory and the CSV file relative to the polar map directory.
    args.images_dir = os.path.join(args.polar_map_dir, args.images_dir)
    args.csv = os.path.join(args.polar_map_dir, args.csv)

    meta = pd.read_csv(args.csv)
    ids = meta[headers.id].unique()

    # Compare against None (the argparse default) so that id 0 can be selected too.
    if args.id is not None:
        # parser.error is used instead of assert: asserts are stripped under
        # `python -O`, while parser.error always exits with a usage message.
        if args.id not in ids:
            parser.error(f"Id {args.id} is not available. Choose one of {ids}.")
        ids = [args.id]
    
    
    def _read_polar_map(images_dir, filename):
        """Read the polar map image *filename* from *images_dir* as 8-bit grayscale.

        ``cv.imread`` signals failure by returning ``None`` instead of raising.
        The check below turns a missing or unreadable file into a clear error
        rather than a confusing failure later on.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        path = os.path.join(images_dir, filename)
        image = cv.imread(path, cv.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Could not read polar map image: {path}")
        return image


    def _first_file(rows, description, study_id):
        """Return the file name stored in the first row of the DataFrame *rows*.

        Raises:
            ValueError: if *rows* is empty, naming the missing polar map.
        """
        if rows.empty:
            raise ValueError(f"No {description} polar map found for id {study_id}")
        return rows[headers.file].values[0]


    def _colormap_and_mask(img, cmap, mask):
        """Colour *img* with *cmap* and zero every RGBA channel outside *mask*.

        The alpha channel is zeroed as well, so pixels outside the mask are
        fully transparent.
        """
        return np.where(mask, cmap(img), 0.0)


    # Colormaps are the same for every study.
    cm_plasma = mlp.colormaps["plasma"]
    cm_divergent = mlp.colormaps["RdYlBu"].reversed()

    for _id in ids:
        print(f"Show id {_id:03d}")
        # Rows of this study, excluding entries flagged in the segments column.
        _meta = meta[(meta[headers.id] == _id) & ~(meta[headers.segments])]

        file_recon_ac = _first_file(
            _meta[(_meta[headers.type] == "symbia") & _meta[headers.ac]], "symbia AC", _id
        )
        # NOTE(review): the non-AC row is not filtered by type; the first non-AC row is used.
        file_recon_nac = _first_file(_meta[~_meta[headers.ac]], "non-AC", _id)
        file_recon_syn = _first_file(_meta[_meta[headers.type] == "synthetic"], "synthetic", _id)
        file_recon_ct = _first_file(_meta[_meta[headers.type] == "ct"], "ct", _id)

        recon_ac = _read_polar_map(args.images_dir, file_recon_ac)
        recon_nac = _read_polar_map(args.images_dir, file_recon_nac)
        recon_syn = _read_polar_map(args.images_dir, file_recon_syn)
        recon_ct = _read_polar_map(args.images_dir, file_recon_ct)

        # The baseline is only read, never modified, so no copy is needed.
        baseline = recon_ac if args.baseline == "symbia" else recon_ct
        recons = [recon_nac, recon_syn, recon_ct]
        labels = ["No AC", "AC SYN", "AC CT"]

        # Signed differences to the baseline. Converting to float avoids uint8 wrap-around.
        raw_diffs = [recon.astype(float) - baseline for recon in recons]
        if args.relative_difference:
            # Percentage of the baseline value: 100 * (recon - baseline) / baseline.
            # Where the baseline is 0 the code divides by 1 to avoid a division by zero.
            divider = np.where(baseline > 0, baseline, 1)
            diffs = [100.0 * diff / divider for diff in raw_diffs]
        else:
            # Percentage of the full 8-bit intensity range (0-255).
            diffs = [diff * 100.0 / 255.0 for diff in raw_diffs]

        # Use a symmetric range around 0 so that "no difference" sits at the
        # centre of the diverging colormap.
        limit = max(float(np.abs(diff).max()) for diff in diffs)
        if limit == 0.0:
            # All differences are zero; avoid 0/0 in the normalisation below.
            limit = 1.0
        diff_min, diff_max = -limit, limit

        # Normalise to [0, 1], the range a colormap expects for float input.
        diffs = [(diff - diff_min) / (diff_max - diff_min) for diff in diffs]

        fig, axs = plt.subplots(2, 4, figsize=(16, 8))
        for ax in axs.flatten():
            ax.set_axis_off()

        # RGBA boolean mask of a filled circle around the image centre.
        # NOTE(review): cx/cy are (rows // 2, cols // 2), while OpenCV expects the
        # centre as (x, y). get_circular_mark draws a similar circle one row higher.
        # Both points only matter for non-square images or if the masks should
        # coincide; confirm which is intended.
        mask = np.zeros((*baseline.shape, 4), np.uint8)
        cx, cy = np.array(mask.shape[:2]) // 2
        mask = cv.circle(
            mask,
            center=(cx - 1, cy + 1),
            radius=cx - 2,
            color=(255, 255, 255, 255),
            thickness=cv.FILLED,
        )
        mask = mask == 255

        axs[0, 0].imshow(_colormap_and_mask(baseline, cm_plasma, mask))
        axs[0, 0].set_title("AC" if args.baseline == "symbia" else "AC - CT")
        for i, (recon, diff, label) in enumerate(zip(recons, diffs, labels), start=1):
            axs[0, i].imshow(_colormap_and_mask(recon, cm_plasma, mask))
            axs[0, i].set_title(label)
            axs[1, i].imshow(_colormap_and_mask(diff, cm_divergent, mask))

        # Colour bar for the difference row, in the units computed above.
        fig.colorbar(
            mlp.cm.ScalarMappable(
                norm=mlp.colors.Normalize(vmin=diff_min, vmax=diff_max), cmap=cm_divergent
            ),
            ax=axs[1, 1:4],
        )
        # Colour bar for the reconstructions, labelled 0-100
        # (presumably percent of the 8-bit range).
        fig.colorbar(
            mlp.cm.ScalarMappable(
                norm=mlp.colors.Normalize(vmin=0, vmax=100), cmap=cm_plasma
            ),
            ax=axs[0, 1:4],
        )
        plt.show()