import argparse
import os

import cv2 as cv
import matplotlib as mlp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mu_map.polar_map.prepare import headers


def get_circular_mark(shape: tuple, channels: int = 1) -> np.ndarray:
    """Create a boolean mask of a filled circle inscribed in an image of ``shape``.

    Despite the name, this returns a *mask*: ``True`` inside the circle and
    ``False`` outside of it.

    Args:
        shape: Image size as ``(height, width)``, e.g. ``img.shape[:2]``.
        channels: Number of channels of the mask; every channel holds the
            same circle.

    Returns:
        A ``(height, width)`` bool array if ``channels == 1``, otherwise a
        ``(height, width, channels)`` bool array.
    """
    mask = np.zeros((*shape, channels), np.uint8)
    # NumPy shapes are (rows, cols) whereas OpenCV points are (x, y) = (col, row).
    rows, cols = mask.shape[:2]
    cy, cx = rows // 2, cols // 2
    mask = cv.circle(
        mask,
        center=(cx - 1, cy),
        # Fit the circle into the shorter side so it stays inside the image for
        # non-square shapes; for square shapes this is half the side minus 2.
        radius=min(cx, cy) - 2,
        color=(255,) * channels,
        thickness=cv.FILLED,
    )
    mask = mask == 255
    return mask[:, :, 0] if channels == 1 else mask


parser = argparse.ArgumentParser(
    description="Visualize polar maps of different reconstructions",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Where to find the polar map images and their meta information.
parser.add_argument(
    "--polar_map_dir",
    type=str,
    default="data/polar_maps",
    help="directory containing the polar map images",
)
parser.add_argument(
    "--images_dir",
    type=str,
    default="images",
    help="sub-directory under <polar_map_dir> containing the actual image files",
)
parser.add_argument(
    "--csv",
    type=str,
    default="polar_maps.csv",
    help="file under <polar_map_dir> containing meta information about the polar maps",
)

# What to show and how to compare it.
parser.add_argument(
    "--baseline",
    choices=["symbia", "stir"],
    default="symbia",
    help="select the polar map treated as the baseline",
)
parser.add_argument(
    "--id",
    type=int,
    help="select a specific study to show by its id",
)
parser.add_argument(
    "--relative_difference",
    action="store_true",
    help="show the difference as a percentage difference to the baseline",
)

args = parser.parse_args()

# Both paths are given relative to the polar map directory.
args.images_dir = os.path.join(args.polar_map_dir, args.images_dir)
args.csv = os.path.join(args.polar_map_dir, args.csv)

# Meta information whose rows reference the polar map image files.
meta = pd.read_csv(args.csv)
ids = meta[headers.id].unique()

# Compare against None so that an explicit ``--id 0`` is honoured as well.
if args.id is not None:
    # parser.error instead of assert: asserts are stripped under ``python -O``
    # and would let an invalid id through.
    if args.id not in ids:
        parser.error(f"Id {args.id} is not available. Choose one of {ids}.")
    ids = [args.id]


def _read_polar_map(filename: str) -> np.ndarray:
    """Load a polar map image from ``args.images_dir`` as an 8-bit grayscale array.

    ``cv.imread`` signals a missing or unreadable file by returning ``None``
    instead of raising, so this checks the result to fail with a clear message.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    path = os.path.join(args.images_dir, filename)
    img = cv.imread(path, cv.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not read polar map image: {path}")
    return img


def _colormap_and_mask(img: np.ndarray, cm: "mlp.colors.Colormap", mask: np.ndarray) -> np.ndarray:
    """Apply colormap ``cm`` to ``img`` and zero all RGBA values where ``mask`` is False.

    Integer images index the colormap's lookup table directly, float images are
    expected in [0, 1]. Zeroed pixels have alpha 0, i.e. they are transparent.
    ``mask`` must be an ``(H, W, 4)`` boolean array matching the RGBA output.
    """
    outside = np.zeros(mask.shape, np.uint8)
    return np.where(mask, cm(img), outside)


cm_plasma = mlp.colormaps["plasma"]
cm_divergent = mlp.colormaps["RdYlBu"].reversed()

for _id in ids:
    print(f"Show id {_id:03d}")
    # Polar maps of this study, excluding rows flagged in the segments column.
    _meta = meta[(meta[headers.id] == _id) & ~(meta[headers.segments])]

    file_recon_ac = _meta[(_meta[headers.type] == "symbia") & _meta[headers.ac]][headers.file].values[0]
    file_recon_nac = _meta[~_meta[headers.ac]][headers.file].values[0]
    file_recon_syn = _meta[_meta[headers.type] == "synthetic"][headers.file].values[0]
    file_recon_ct = _meta[_meta[headers.type] == "ct"][headers.file].values[0]

    recon_ac = _read_polar_map(file_recon_ac)
    recon_nac = _read_polar_map(file_recon_nac)
    recon_syn = _read_polar_map(file_recon_syn)
    recon_ct = _read_polar_map(file_recon_ct)

    baseline = recon_ac.copy() if args.baseline == "symbia" else recon_ct.copy()
    recons = [recon_nac, recon_syn, recon_ct]
    labels = ["No AC", "AC SYN", "AC CT"]

    # Circular RGBA mask of the polar map disc. Pixels outside of it are hidden
    # in the plots and ignored for the colour scale. NumPy shapes are
    # (rows, cols) whereas OpenCV points are (x, y) = (col, row).
    mask = np.zeros((*baseline.shape, 4), np.uint8)
    rows, cols = mask.shape[:2]
    cy, cx = rows // 2, cols // 2
    mask = cv.circle(
        mask,
        center=(cx - 1, cy + 1),
        radius=min(cx, cy) - 2,
        color=(255, 255, 255, 255),
        thickness=cv.FILLED,
    )
    mask = mask == 255
    inside = mask[:, :, 0]

    # Differences in percent of the 8-bit intensity range.
    diffs = [(recon.astype(float) - baseline) * 100.0 / 255.0 for recon in recons]

    if args.relative_difference:
        # ``diffs`` are in percent of the 8-bit range, so the baseline has to be
        # expressed in the same unit before dividing; otherwise the result would
        # be scaled by 100/255. Zero baseline pixels count as one 8-bit step to
        # avoid a division by zero.
        divider = np.where(baseline > 0, baseline, 1) * 100.0 / 255.0
        diffs = [100.0 * diff / divider for diff in diffs]

    # Colour range symmetric around zero, computed over the visible pixels only.
    # If all differences vanish, fall back to 1 so the normalisation is defined.
    bound = max(np.abs(diff[inside]).max(initial=0.0) for diff in diffs)
    if bound == 0:
        bound = 1.0
    diff_min, diff_max = -bound, bound
    # Map [diff_min, diff_max] onto [0, 1] for the divergent colormap.
    diffs = [(diff - diff_min) / (diff_max - diff_min) for diff in diffs]

    fig, axs = plt.subplots(2, 4, figsize=(16, 8))
    for ax in axs.flatten():
        ax.set_axis_off()

    axs[0, 0].imshow(_colormap_and_mask(baseline, cm_plasma, mask))
    axs[0, 0].set_title("AC" if args.baseline == "symbia" else "AC - CT")
    for i, (recon, diff, label) in enumerate(zip(recons, diffs, labels), start=1):
        axs[0, i].imshow(_colormap_and_mask(recon, cm_plasma, mask))
        axs[0, i].set_title(label)
        axs[1, i].imshow(_colormap_and_mask(diff, cm_divergent, mask))

    fig.colorbar(
        mlp.cm.ScalarMappable(
            norm=mlp.colors.Normalize(vmin=diff_min, vmax=diff_max), cmap=cm_divergent
        ),
        ax=axs[1, 1:4],
        label="Relative difference [%]" if args.relative_difference else "Difference [%]",
    )
    fig.colorbar(
        mlp.cm.ScalarMappable(
            norm=mlp.colors.Normalize(vmin=0, vmax=100), cmap=cm_plasma
        ),
        ax=axs[0, 1:4],
    )
    plt.show()