import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mu_map.polar_map.prepare import headers

# Load the table of per-segment perfusion values.
data = pd.read_csv("data/polar_maps/perfusion.csv")

# Columns reused by every row selection below.
is_ac = data[headers.ac]
scan_type = data[headers.type]

# Reference rows: flagged by the `headers.ac` column and of type "symbia".
baseline = data[is_ac & (scan_type == "symbia")]

# Rows to compare: everything not flagged by `headers.ac`, plus the flagged
# rows of type "synthetic" and "ct".
correction_none = data[~is_ac]
correction_syn = data[is_ac & (scan_type == "synthetic")]
correction_ct = data[is_ac & (scan_type == "ct")]

# Column names "segment_1" ... "segment_17".
keys_segments = [f"segment_{number + 1}" for number in range(17)]

# Plain arrays with one row per table row and one column per segment.
_baseline, _correction_none, _correction_syn, _correction_ct = (
    frame[keys_segments].values
    for frame in (baseline, correction_none, correction_syn, correction_ct)
)


def absolute_percent_error(prediction: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Return the absolute percent error between the axis-0 means of two arrays.

    Both arrays are averaged along axis 0 first; the error is then computed
    element-wise on those means, so 2-D inputs yield one value per column.

    Args:
        prediction: Values under evaluation; averaged along axis 0.
        target: Reference values; averaged along axis 0 and used as the
            denominator of the percentage.

    Returns:
        ``100 * |mean(prediction) - mean(target)| / |mean(target)|``. There is
        no guard against a target mean of zero.
    """
    mean_p = prediction.mean(axis=0)
    mean_t = target.mean(axis=0)
    diff = np.absolute(mean_p - mean_t)
    # Divide by the magnitude of the reference so the result stays
    # non-negative even when the reference mean is negative.
    return 100.0 * diff / np.absolute(mean_t)


def apc(prediction: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Short alias for `absolute_percent_error`.

    Args:
        prediction: Values under evaluation.
        target: Reference values.

    Returns:
        Whatever `absolute_percent_error(prediction, target)` returns.
    """
    return absolute_percent_error(prediction, target)


# Each entry is (label, prediction array, target array).
combinations = [
    ("NoAC to AC", _correction_none, _baseline),
    ("SYN to AC", _correction_syn, _baseline),
    ("CT to AC", _correction_ct, _baseline),
    ("NoAC to CT", _correction_none, _correction_ct),
    ("SYN to CT", _correction_syn, _correction_ct),
]

# Print mean ± standard deviation of the per-segment errors for each pairing.
for label, prediction, target in combinations:
    errors = absolute_percent_error(prediction, target)
    print(f"APE - {label:>12}: {errors.mean():.3f}±{errors.std():.3f}")


# One panel per comparison, all in a single row.
n_plots = 5
plot_width = 4
fig_width = n_plots * plot_width + 1
fig_height = plot_width + 1
fig, axs = plt.subplots(nrows=1, ncols=n_plots, figsize=(fig_width, fig_height))
fig.suptitle("Correlation of Segment Perfusion")


def jitter(array: np.ndarray, scale=1):
    """Return `array` plus uniform random noise centred on zero.

    The noise has the same shape as `array`, is drawn from NumPy's global
    random state, and lies in ``[-0.5 * scale, 0.5 * scale)``.
    """
    half_width = 0.5 * scale
    noise = np.random.random(array.shape) * scale - half_width
    return array + noise


# Marker styling shared by every scatter panel.
scatter_style = dict(c="#67a9cf", s=20, alpha=0.6, edgecolor="black")

# Common axis range; also the end points of the identity line.
axis_range = (30, 100)

# (x values, y values, x label, y label, title) for each panel, left to right.
panels = [
    (_baseline, _correction_none, "Perfusion AC", "Perfusion NoAC", "  NoAC vs. AC"),
    (_baseline, _correction_syn, "Perfusion AC", "Perfusion AC SYN", "AC SYN vs. AC"),
    (_baseline, _correction_ct, "Perfusion AC", "Perfusion AC CT", " AC CT vs. AC"),
    (
        _correction_ct,
        _correction_none,
        "Perfusion AC CT",
        "Perfusion NoAC",
        "  NoAC vs. AC CT",
    ),
    (
        _correction_ct,
        _correction_syn,
        "Perfusion AC CT",
        "Perfusion AC SYN",
        "AC SYN vs. AC CT",
    ),
]

for ax, (x_values, y_values, x_label, y_label, title) in zip(axs, panels):
    # Jitter both coordinates (x first, then y) before plotting.
    ax.scatter(
        jitter(x_values.flatten()),
        jitter(y_values.flatten()),
        **scatter_style,
    )
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)

    # Same fixed range on both axes, plus the y = x reference line.
    ax.set_xlim(axis_range)
    ax.set_ylim(axis_range)
    ax.plot(axis_range, axis_range, color="#ef8a62")

plt.tight_layout()
plt.show()