from typing import Optional

import numpy as np
import pandas as pd

from mu_map.dataset.default import MuMapDataset
from mu_map.models.unet import UNet


def mse(prediction: np.ndarray, target: np.ndarray) -> float:
    """
    Compute the mean squared error (MSE) between a prediction and
    a target array.

    Parameters
    ----------
    prediction: np.ndarray
        the predicted values (anything accepted by ``np.asarray``)
    target: np.ndarray
        the reference values, must have the same shape as ``prediction``

    Returns
    -------
    float
        the mean of the element-wise squared differences

    Raises
    ------
    ValueError
        if ``prediction`` and ``target`` have different shapes
    """
    prediction = np.asarray(prediction)
    target = np.asarray(target)
    # Refuse to broadcast: mismatched shapes would silently yield a
    # meaningless score instead of an error.
    if prediction.shape != target.shape:
        raise ValueError(
            f"prediction shape {prediction.shape} does not match "
            f"target shape {target.shape}"
        )

    squared_error = (prediction - target) ** 2
    return float(squared_error.sum() / squared_error.size)


def nmae(
    prediction: np.ndarray,
    target: np.ndarray,
    vmax: Optional[float] = None,
    vmin: Optional[float] = None,
) -> float:
    """
    Compute the normalized mean absolute error (NMAE) between a prediction
    and a target array.

    The mean absolute error is divided by the value range ``vmax - vmin``.

    Parameters
    ----------
    prediction: np.ndarray
        the predicted values (anything accepted by ``np.asarray``)
    target: np.ndarray
        the reference values, must have the same shape as ``prediction``
    vmax: float, optional
        maximum value for normalization, defaults to the maximal value in the target
    vmin: float, optional
        minimum value for normalization, defaults to the minimal value in the target

    Returns
    -------
    float
        the normalized mean absolute error. If ``vmax == vmin`` the
        normalization divides by zero and the result is ``inf`` (or ``nan``
        if the error is zero as well); NumPy emits a RuntimeWarning.

    Raises
    ------
    ValueError
        if ``prediction`` and ``target`` have different shapes
    """
    prediction = np.asarray(prediction)
    target = np.asarray(target)
    # Refuse to broadcast: mismatched shapes would silently yield a
    # meaningless score instead of an error.
    if prediction.shape != target.shape:
        raise ValueError(
            f"prediction shape {prediction.shape} does not match "
            f"target shape {target.shape}"
        )

    vmax = target.max() if vmax is None else vmax
    vmin = target.min() if vmin is None else vmin

    absolute_error = np.absolute(prediction - target)
    mean_absolute_error = absolute_error.sum() / absolute_error.size
    return float(mean_absolute_error / (vmax - vmin))


def compute_measures(dataset: MuMapDataset, model: UNet) -> pd.DataFrame:
    """
    Compute measures (MSE, NMAE) for all images in a dataset.

    Each reconstruction is passed through ``model`` (with autograd disabled)
    and the prediction is scored against the corresponding mu map. The
    model's train/eval mode is not changed here; set it before calling.

    Parameters
    ----------
    dataset: MuMapDataset
        the dataset containing the reconstructions and mu maps for which the scores are computed
    model: UNet
        the UNet model which is used to predict mu maps from reconstructions

    Returns
    -------
    pd.DataFrame
        a dataframe with the columns ``NMAE``, ``MSE`` and ``id`` containing
        one row of measures per image in the dataset
    """
    # torch is only imported in the CLI section of this module, so import
    # it locally for the no_grad context.
    import torch

    device = next(model.parameters()).device
    measures = {"NMAE": nmae, "MSE": mse}

    n_inputs = len(dataset)
    width = len(str(n_inputs))
    rows = []
    # Scoring never needs gradients; without no_grad, calling ``.numpy()``
    # on the model output raises whenever grad mode is enabled.
    with torch.no_grad():
        for i, (recon, mu_map) in enumerate(dataset):
            # 1-based counter so the last step reads N/N
            print(f"Process input {i + 1:>{width}}/{n_inputs}", end="\r")

            prediction = model(recon.unsqueeze(dim=0).to(device))
            prediction = prediction.squeeze().cpu().numpy()
            mu_map = mu_map.squeeze().cpu().numpy()

            row = {name: measure(prediction, mu_map) for name, measure in measures.items()}
            # assumes iteration order of the dataset matches the row order of dataset.table
            row["id"] = dataset.table.iloc[i]["id"]
            rows.append(row)
    # clear the progress line
    print(" " * 100, end="\r")

    # Build the frame once (linear) instead of concatenating per row.
    return pd.DataFrame(rows, columns=[*measures, "id"])


if __name__ == "__main__":
    import argparse

    import torch

    from mu_map.dataset.normalization import norm_by_str, norm_choices
    from mu_map.dataset.transform import SequenceTransform, PadCropTranform

    parser = argparse.ArgumentParser(
        description="Compute, print and store measures for a given model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cpu", "cuda"],
        help="the device on which the model is evaluated (cpu or cuda)",
    )
    parser.add_argument(
        "--weights",
        type=str,
        required=True,
        help="the model weights which should be scored",
    )
    parser.add_argument(
        "--features",
        type=int,
        nargs="+",
        default=[64, 128, 256, 512],
        help="the features passed to the UNet constructor; "
        "must match the architecture the weights were trained with",
    )
    parser.add_argument("--out", type=str, help="write results as a csv file")
    parser.add_argument(
        "--scatter_corrected",
        action="store_true",
        help="enable the scatter_correction option of the dataset",
    )

    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/second/",
        help="directory where the dataset is found",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="validation",
        choices=["train", "test", "validation", "all"],
        help="the split of the dataset to be processed",
    )
    parser.add_argument(
        "--norm",
        type=str,
        choices=["none", *norm_choices],
        default="mean",
        help="type of normalization applied to the reconstructions",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=32,
        help="pad/crop the third tensor dimension to this value",
    )
    args = parser.parse_args()

    # "all" is mapped to split_name=None for the dataset
    split_name = None if args.split == "all" else args.split

    # Evaluation only: disable autograd for the whole script.
    torch.set_grad_enabled(False)

    device = torch.device(args.device)
    model = UNet(features=args.features)
    # NOTE(review): torch.load unpickles the file, so only score weights
    # from trusted sources.
    model.load_state_dict(torch.load(args.weights, map_location=device))
    model = model.to(device).eval()

    transform_normalization = SequenceTransform(
        transforms=[
            norm_by_str(args.norm),
            PadCropTranform(dim=3, size=args.size),
        ]
    )
    dataset = MuMapDataset(
        args.dataset_dir,
        transform_normalization=transform_normalization,
        split_name=split_name,
        scatter_correction=args.scatter_corrected,
    )

    values = compute_measures(dataset, model)

    if args.out:
        values.to_csv(args.out, index=False)

    print("Scores:")
    for measure_name, measure_values in values.items():
        if measure_name == "id":
            continue

        mean = measure_values.mean()
        # population standard deviation (ddof=0), identical to np.std
        std = measure_values.std(ddof=0)
        print(f" - {measure_name:>6}: {mean:.6f}±{std:.6f}")