from typing import Tuple

import cv2 as cv
import numpy as np


def get_circular_mask(shape: Tuple[int, int], channels: int = 1) -> np.ndarray:
    """
    Create a mask for the largest possible circle in an image.

    This is used to extract polar maps from rectangular images.

    The circle is centered in the image, shifted one pixel to the left. Its
    radius is two pixels smaller than half of the shorter image side, so the
    circle stays inside the image for non-square shapes as well.

    Parameters
    ----------
    shape: tuple of int
        the shape of the rectangle as (rows, columns); further trailing
        entries (e.g. the channel count of an image shape) are ignored
    channels: int
        number of channels the mask should have (OpenCV drawing functions
        support at most 4)

    Returns
    -------
    np.ndarray
        a mask as an array of booleans, shaped (rows, columns) if
        ``channels == 1`` and (rows, columns, channels) otherwise
    """
    height, width = shape[:2]
    mask = np.zeros((height, width, channels), np.uint8)

    # OpenCV expects the center as (x, y) = (column, row). Deriving it from
    # the width/height explicitly keeps the circle centered for non-square
    # images; for square images this matches the previous behavior exactly.
    center = (width // 2 - 1, height // 2)
    radius = min(height, width) // 2 - 2

    mask = cv.circle(
        mask,
        center=center,
        radius=radius,
        color=(255,) * channels,
        thickness=cv.FILLED,
    )
    mask = mask == 255
    return mask[:, :, 0] if channels == 1 else mask


if __name__ == "__main__":
    import argparse
    import os

    import matplotlib as mlp
    import matplotlib.pyplot as plt
    import pandas as pd

    # Render all figure text with LaTeX (requires a LaTeX installation).
    plt.rcParams.update(
        {
            "text.usetex": True,
        }
    )

    from mu_map.polar_map.prepare import headers

    parser = argparse.ArgumentParser(
        description="Visualize polar maps of different reconstructions",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--polar_map_dir",
        type=str,
        default="data/polar_maps",
        help="directory containing the polar map images",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory under <polar_map_dir> containing the actual image files",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="polar_maps.csv",
        help="file under <polar_map_dir> containing meta information about the polar maps",
    )
    parser.add_argument(
        "--baseline",
        choices=["symbia", "stir"],
        default="stir",
        help="select the polar map treated as the baseline",
    )
    parser.add_argument(
        "--id", type=int, help="select a specific study to show by its id"
    )
    parser.add_argument(
        "--color_map",
        type=str,
        default="data/color_maps/PrismOeyn.cm",
        help="select the color map to visualize the polar maps",
    )
    parser.add_argument(
        "--save",
        type=str,
        help="save the visualization as an image",
    )
    args = parser.parse_args()

    args.images_dir = os.path.join(args.polar_map_dir, args.images_dir)
    args.csv = os.path.join(args.polar_map_dir, args.csv)

    meta = pd.read_csv(args.csv)
    ids = meta[headers.id].unique()

    # Compare against None so that a study with id 0 can be selected too;
    # use parser.error instead of assert so the check survives `python -O`.
    if args.id is not None:
        if args.id not in ids:
            parser.error(f"Id {args.id} is not available. Choose one of {ids}.")
        ids = [args.id]

    if os.path.isfile(args.color_map):
        # NOTE(review): pd.read_csv treats the first line of the color map
        # file as a header - confirm the .cm format has one.
        color_map = pd.read_csv(args.color_map)
        color_map = mlp.colors.ListedColormap(color_map.values / 255.0)
    else:
        color_map = mlp.colormaps["plasma"]

    def first_file(selection: pd.DataFrame, description: str, _id) -> str:
        """Return the file name of the first row of *selection*."""
        if selection.empty:
            raise ValueError(f"No {description} polar map found for id {_id}")
        return selection[headers.file].values[0]

    def load_grayscale(filename: str) -> np.ndarray:
        """Load an image from the images directory as a grayscale array."""
        path = os.path.join(args.images_dir, filename)
        image = cv.imread(path, cv.IMREAD_GRAYSCALE)
        # cv.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read polar map image {path}")
        return image

    for _id in ids:
        print(f"Show id {_id:03d}")
        _meta = meta[(meta[headers.id] == _id) & ~(meta[headers.segments])]

        # The attenuation corrected reconstruction of the selected baseline
        # type is shown as "CTAC".
        file_recon_ctac = first_file(
            _meta[(_meta[headers.type] == args.baseline) & _meta[headers.ac]],
            f"attenuation corrected {args.baseline}",
            _id,
        )
        # NOTE(review): picks the first non-AC row regardless of its type.
        file_recon_noac = first_file(
            _meta[~_meta[headers.ac]], "non attenuation corrected", _id
        )
        file_recon_dlac = first_file(
            _meta[_meta[headers.type] == "dl"], "dl", _id
        )

        recon_ctac = load_grayscale(file_recon_ctac)
        recon_noac = load_grayscale(file_recon_noac)
        recon_dlac = load_grayscale(file_recon_dlac)

        recons = [recon_ctac, recon_dlac, recon_noac]
        labels = ["CTAC", "DLAC", "No AC"]

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        for ax in axs.flatten():
            ax.set_axis_off()

        # One RGBA mask shared by all three reconstructions (assumes they all
        # have the shape of the CTAC image).
        mask = get_circular_mask(recon_ctac.shape, channels=4)
        # Pixels outside the circle become 0 in all RGBA channels (alpha 0).
        black = np.zeros(mask.shape, np.uint8)

        for ax, recon, label in zip(axs, recons, labels):
            # uint8 input is used by the colormap as lookup-table indices.
            polar_map = color_map(recon)
            polar_map = np.where(mask, polar_map, black)

            ax.imshow(polar_map)
            ax.set_title(label)

        plt.tight_layout()

        fig.colorbar(
            mlp.cm.ScalarMappable(
                norm=mlp.colors.Normalize(vmin=0, vmax=100), cmap=color_map
            ),
            fraction=0.05,
            ax=axs,
        )

        if args.save:
            save_path = args.save
            if len(ids) > 1:
                # Several studies are rendered in turn: give each one its own
                # file instead of overwriting the same path repeatedly.
                root, ext = os.path.splitext(args.save)
                save_path = f"{root}_{_id:03d}{ext}"
            fig.savefig(save_path, dpi=300)

        plt.show()
        plt.close(fig)