Skip to content
Snippets Groups Projects
visualize.py 4.58 KiB
Newer Older
  • Learn to ignore specific revisions
  • Tamino Huxohl's avatar
    Tamino Huxohl committed
    from typing import Tuple
    
    
    import cv2 as cv
    import numpy as np
    
    
    def get_circular_mask(shape: Tuple[int, int], channels: int = 1) -> np.ndarray:
        """
        Create a mask for the largest possible circle in an image.

        This is used to extract polar maps from rectangular images.

        Parameters
        ----------
        shape: tuple of int
            the shape of the rectangle as (height, width), e.g. the ``.shape``
            of a grayscale image; entries beyond the first two are ignored
        channels: int
            number of channels the mask should have

        Returns
        -------
        np.ndarray
            a mask as an array of booleans, of shape (height, width) if
            ``channels == 1`` and (height, width, channels) otherwise
        """
        height, width = shape[:2]
        mask = np.full((height, width, channels), 0, np.uint8)

        # cv.circle takes its center as (x, y) = (column, row): the horizontal
        # center comes from the width and the vertical center from the height.
        # Bounding the radius by the shorter side keeps the circle inside
        # non-square images (for square images this is identical to using
        # either side).
        # NOTE(review): the -1 x-offset and the 2px radius margin look tuned to
        # the polar map images this is applied to -- confirm before changing.
        cx, cy = width // 2, height // 2
        mask = cv.circle(
            mask,
            center=(cx - 1, cy),
            radius=min(cx, cy) - 2,
            color=(255,) * channels,
            thickness=cv.FILLED,
        )
        mask = mask == 255
        return mask[:, :, 0] if channels == 1 else mask
    
    
    
    if __name__ == "__main__":
        import argparse
        import os

        import matplotlib as mlp
        import matplotlib.pyplot as plt
        import pandas as pd

        # Render figure text (titles, colorbar ticks) with LaTeX; this requires a
        # working LaTeX installation on the machine running the script.
        plt.rcParams.update(
            {
                "text.usetex": True,
            }
        )

        from mu_map.polar_map.prepare import headers

        parser = argparse.ArgumentParser(
            description="Visualize polar maps of different reconstructions",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "--polar_map_dir",
            type=str,
            default="data/polar_maps",
            help="directory containing the polar map images",
        )
        parser.add_argument(
            "--images_dir",
            type=str,
            default="images",
            help="sub-directory under <polar_map_dir> containing the actual image files",
        )
        parser.add_argument(
            "--csv",
            type=str,
            default="polar_maps.csv",
            help="file under <polar_map_dir> containing meta information about the polar maps",
        )
        # NOTE(review): --baseline is parsed but never used below; the CTAC
        # reference is always the "symbia" reconstruction. Confirm whether the
        # CTAC selection should follow args.baseline instead.
        parser.add_argument(
            "--baseline",
            choices=["symbia", "stir"],
            default="stir",
            help="select the polar map treated as the baseline",
        )
        parser.add_argument(
            "--id", type=int, help="select a specific study to show by its id"
        )
        parser.add_argument(
            "--color_map",
            type=str,
            default="data/color_maps/PrismOeyn.cm",
            help="select the color map to visualize the polar maps",
        )
        parser.add_argument(
            "--save",
            type=str,
            help="save the visualization as an image",
        )
        args = parser.parse_args()

        args.images_dir = os.path.join(args.polar_map_dir, args.images_dir)
        args.csv = os.path.join(args.polar_map_dir, args.csv)

        meta = pd.read_csv(args.csv)
        ids = meta[headers.id].unique()

        # Compare against None (not truthiness) so that an id of 0 is selectable;
        # use parser.error instead of assert so the check survives `python -O`.
        if args.id is not None:
            if args.id not in ids:
                parser.error(f"Id {args.id} is not available. Choose one of {ids}.")
            ids = [args.id]

        # A custom color map file is read as a table of 0-255 color values;
        # fall back to matplotlib's "plasma" if the file does not exist.
        if os.path.isfile(args.color_map):
            color_map = pd.read_csv(args.color_map)
            color_map = mlp.colors.ListedColormap(color_map.values / 255.0)
        else:
            color_map = mlp.colormaps["plasma"]

        def load_grayscale(filename: str) -> np.ndarray:
            """Load ``filename`` from <images_dir> as a single-channel image."""
            path = os.path.join(args.images_dir, filename)
            image = cv.imread(path, cv.IMREAD_GRAYSCALE)
            # cv.imread reports a missing/unreadable file by returning None
            # rather than raising, which would otherwise fail much later.
            if image is None:
                raise FileNotFoundError(f"Could not read image {path}")
            return image

        # With several studies, saving each figure to the same path would
        # overwrite it on every iteration, so the id is appended to the name.
        save_per_id = len(ids) > 1

        for _id in ids:
            print(f"Show id {_id:03d}")
            _meta = meta[(meta[headers.id] == _id) & ~(meta[headers.segments])]

            file_recon_ctac = _meta[(_meta[headers.type] == "symbia") & _meta[headers.ac]][
                headers.file
            ].values[0]
            file_recon_noac = _meta[~_meta[headers.ac]][headers.file].values[0]
            file_recon_dlac = _meta[_meta[headers.type] == "dl"][headers.file].values[0]

            recon_ctac = load_grayscale(file_recon_ctac)
            recon_noac = load_grayscale(file_recon_noac)
            recon_dlac = load_grayscale(file_recon_dlac)

            recons = [recon_ctac, recon_dlac, recon_noac]
            labels = ["CTAC", "DLAC", "No AC"]

            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
            for ax in axs.flatten():
                ax.set_axis_off()

            # NOTE(review): one mask (sized from the CTAC image) is applied to
            # all three reconstructions, which assumes they share a shape.
            mask = get_circular_mask(recon_ctac.shape, channels=4)
            black = np.zeros(mask.shape, np.uint8)

            for ax, recon, label in zip(axs, recons, labels):
                polar_map = color_map(recon)
                # Blank everything outside the circular polar map.
                polar_map = np.where(mask, polar_map, black)

                ax.imshow(polar_map)
                ax.set_title(label)

            plt.tight_layout()

            fig.colorbar(
                mlp.cm.ScalarMappable(
                    norm=mlp.colors.Normalize(vmin=0, vmax=100), cmap=color_map
                ),
                fraction=0.05,
                ax=axs,
            )

            if args.save:
                if save_per_id:
                    base, ext = os.path.splitext(args.save)
                    save_path = f"{base}_{_id:03d}{ext}"
                else:
                    save_path = args.save
                fig.savefig(save_path, dpi=300)

            plt.show()
            # Release the figure so they do not pile up across studies
            # (relevant for non-interactive backends where show() returns at once).
            plt.close(fig)