import argparse
import os

import cv2 as cv
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mu_map.polar_map.prepare import headers


def get_circular_mark(shape: "tuple[int, int]", channels: int = 1) -> np.ndarray:
    """Return a boolean mask selecting a filled disc inside an image.

    The disc is centred one pixel left of the image centre. Its radius is two
    pixels less than half the shorter image side, so the image border is
    excluded.

    Args:
        shape: ``(height, width)`` of the image the mask will be applied to.
        channels: number of channels of the mask. The colour is passed to
            OpenCV as a scalar, which presumably limits this to 1..4.

    Returns:
        A ``bool`` array that is ``True`` inside the disc. Its shape is
        ``(height, width)`` when ``channels == 1``, otherwise
        ``(height, width, channels)``.
    """
    height, width = shape[:2]
    mask = np.zeros((height, width, channels), np.uint8)
    # OpenCV points are (x, y) = (column, row): x derives from the width and
    # y from the height. For square images this matches the historic layout.
    center = (int(width // 2) - 1, int(height // 2))
    radius = int(min(height, width) // 2) - 2
    mask = cv.circle(
        mask,
        center=center,
        radius=radius,
        color=(255,) * channels,
        thickness=cv.FILLED,
    )
    mask = mask == 255
    return mask[:, :, 0] if channels == 1 else mask

# Command-line interface. Sub-paths are given relative to --polar_map_dir and
# are resolved to full paths right after parsing.
parser = argparse.ArgumentParser(
    description="Visualize polar maps of different reconstructions",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--polar_map_dir",
    type=str,
    default="data/polar_maps",
    help="directory containing the polar map images",
)
parser.add_argument(
    "--images_dir",
    type=str,
    default="images",
    help="sub-directory under <polar_map_dir> containing the actual image files",
)
parser.add_argument(
    "--csv",
    type=str,
    default="polar_maps.csv",
    help="file under <polar_map_dir> containing meta information about the polar maps",
)
parser.add_argument(
    "--baseline",
    choices=["symbia", "stir"],
    default="symbia",
    help="select the polar map treated as the baseline",
)
# NOTE(review): --relative_difference is parsed but never read in this file.
parser.add_argument(
    "--relative_difference",
    action="store_true",
    help="show the difference as a percentage difference to the baseline",
)
args = parser.parse_args()

# Prefix both sub-paths with the polar map directory. The right-hand side is
# evaluated before either attribute is reassigned.
args.images_dir, args.csv = (
    os.path.join(args.polar_map_dir, args.images_dir),
    os.path.join(args.polar_map_dir, args.csv),
)

meta = pd.read_csv(args.csv)
ids = meta[headers.id].unique()

# Reconstructions compared against the baseline, in the order they are
# reported and plotted. All accumulators below share these keys.
labels = ["No AC", "AC SYN", "AC CT"]
dists_abs = {label: [] for label in labels}  # mean absolute difference per id
dists_per = {label: [] for label in labels}  # mean percentage difference per id
recons_tot = {label: [] for label in labels}  # compared pixel values per id
baseline_tot = []  # compared baseline pixel values per id

# Images are read as 8-bit grayscale; this maps [0, 255] onto [0, 100].
PIXEL_SCALE = 100.0 / 255.0


def _first_file(rows, description, polar_map_id):
    """Return the file name stored in the first row of ``rows``.

    Raises:
        IndexError: if ``rows`` is empty. The message names the missing polar
            map instead of the bare error raised by ``.values[0]``.
    """
    files = rows[headers.file].values
    if len(files) == 0:
        raise IndexError(f"no {description} polar map found for id {polar_map_id}")
    return files[0]


def _read_polar_map(filename):
    """Load ``filename`` from ``args.images_dir`` as a grayscale image.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv.imread``
            signals failure by returning ``None`` rather than raising.
    """
    path = os.path.join(args.images_dir, filename)
    image = cv.imread(path, cv.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f"could not read polar map image {path}")
    return image


for _id in ids:
    print(_id)
    # Rows of this id whose `segments` flag is False (presumably polar maps
    # without a segment overlay; TODO confirm against the prepare step).
    _meta = meta[(meta[headers.id] == _id) & ~(meta[headers.segments])]

    file_recon_ac = _first_file(_meta[(_meta[headers.type] == "symbia") & _meta[headers.ac]], "symbia AC", _id)
    file_recon_nac = _first_file(_meta[~_meta[headers.ac]], "non-AC", _id)
    file_recon_syn = _first_file(_meta[_meta[headers.type] == "synthetic"], "synthetic", _id)
    file_recon_ct = _first_file(_meta[_meta[headers.type] == "ct"], "CT", _id)

    recon_ac = _read_polar_map(file_recon_ac)
    recon_nac = _read_polar_map(file_recon_nac)
    recon_syn = _read_polar_map(file_recon_syn)
    recon_ct = _read_polar_map(file_recon_ct)

    # Scaling creates new float arrays. Nothing aliases the loaded images, and
    # the uint8 wrap-around is avoided in the subtraction below.
    # NOTE(review): with --baseline stir, the CT reconstruction is both the
    # baseline and the "AC CT" entry, so that entry compares it with itself.
    baseline = (recon_ac if args.baseline == "symbia" else recon_ct) * PIXEL_SCALE
    recons = [recon * PIXEL_SCALE for recon in (recon_nac, recon_syn, recon_ct)]

    # Compare only pixels inside the circular polar map where the baseline is
    # non-zero. The percentage difference divides by the baseline.
    mask = get_circular_mark(baseline.shape)
    _baseline = baseline[mask]
    mask2 = _baseline != 0
    _baseline = _baseline[mask2]
    baseline_tot.append(_baseline)
    for recon, label in zip(recons, labels):
        _recon = recon[mask]
        # Pixels dropped via mask2 must be empty in the reconstruction too;
        # otherwise part of its difference to the baseline would be ignored.
        # This is an explicit check rather than `assert`, which `python -O` strips.
        if np.any(_recon[~mask2] != 0):
            raise ValueError(
                f"{label} polar map of id {_id} is non-zero where the baseline is zero"
            )
        _recon = _recon[mask2]
        recons_tot[label].append(_recon)

        dist_abs = np.absolute(_recon - _baseline)
        dist_percentage = 100 * dist_abs / _baseline

        dists_abs[label].append(dist_abs.mean())
        dists_per[label].append(dist_percentage.mean())

# For each reconstruction, report the mean and standard deviation (over all
# ids) of the per-id mean differences gathered in the comparison loop.
print("Total:")
for label, per_id_abs in dists_abs.items():
    dist_abs, dist_percentage = np.asarray(per_id_abs), np.asarray(dists_per[label])
    print(f" - {label:>6} -   Absolute: {dist_abs.mean():.3f}+-{dist_abs.std():.3f}")
    print(f" - {label:>6} - Percentage: {dist_percentage.mean():.3f}+-{dist_percentage.std():.3f}")


baseline_tot = np.concatenate(baseline_tot)

# Bins of width 2 covering [0, 100], the value range of the rescaled images.
bins_x = np.arange(0, 100.1, 2)
bins_y = np.arange(0, 100.1, 2)
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
for i, (label, recon) in enumerate(recons_tot.items()):
    recon = np.concatenate(recon)
    # Top row: matplotlib's 2D histogram of baseline vs. reconstruction.
    # Only the counts are needed; they are indexed [baseline_bin, prediction_bin].
    h = axs[0, i].hist2d(baseline_tot, recon, bins=(bins_x, bins_y))[0]

    # Bottom row: the same counts rendered manually so that an identity line
    # can be drawn on top. Min-max normalise to [0, 1] for the colormap. A
    # constant histogram would divide by zero, so it maps to all zeros instead.
    h_min, h_max = h.min(), h.max()
    h = (h - h_min) / (h_max - h_min) if h_max > h_min else np.zeros_like(h)
    img = mpl.colormaps["viridis"](h)
    # With equal bins on both axes, the diagonal from (0, 0) to the far corner
    # is where baseline == prediction. The endpoint follows the image size
    # instead of a hard-coded bin count.
    img = cv.line(
        img,
        (0, 0),
        (img.shape[1], img.shape[0]),
        (1.0, 0.0, 0.0, 0.3),
        1,
        lineType=cv.LINE_AA,
    )
    # Rotate counter-clockwise so the baseline runs left-to-right and the
    # prediction bottom-to-top, matching the hist2d plot above.
    img = np.rot90(img)
    axs[1, i].imshow(img)
    axs[1, i].set_axis_off()

    axs[0, i].set_title(label)
    axs[0, i].set_xlabel("Baseline")
    axs[0, i].set_ylabel("Prediction")

plt.show()