import numpy as np


def mse(prediction: np.ndarray, target: np.ndarray):
    """Compute the mean squared error between a prediction and its target.

    Args:
        prediction: predicted values (any array-like accepted by ``np.asarray``).
        target: ground-truth values, broadcast-compatible with ``prediction``.

    Returns:
        The mean of the element-wise squared differences as a NumPy scalar.
    """
    squared_error = (np.asarray(prediction) - np.asarray(target)) ** 2
    # Same reduction as ``sum() / size``, but integer inputs are accumulated
    # in float64 instead of the input integer type.
    return squared_error.mean()


def nmae(prediction: np.ndarray, target: np.ndarray):
    """Compute the mean absolute error normalized by the value range of the target.

    Args:
        prediction: predicted values (any array-like accepted by ``np.asarray``).
        target: ground-truth values, broadcast-compatible with ``prediction``.

    Returns:
        ``mean(|prediction - target|) / (max(target) - min(target))`` as a
        NumPy scalar. If ``target`` is constant the range is zero and NumPy
        division yields ``inf`` (or ``nan`` if the error is zero too) together
        with a RuntimeWarning.
    """
    prediction = np.asarray(prediction)
    target = np.asarray(target)
    # Average over the (broadcast) difference itself rather than dividing
    # every element by ``prediction.size``, which miscounts when the inputs
    # broadcast and allocates an extra full-size array.
    mae = np.absolute(prediction - target).mean()
    return mae / (target.max() - target.min())


if __name__ == "__main__":
    # Command-line entry point: evaluate a trained UNet on a dataset split,
    # print mean±std of each measure and optionally store per-sample values.
    import argparse

    import pandas as pd
    import torch

    from mu_map.dataset.default import MuMapDataset
    from mu_map.dataset.normalization import norm_by_str, norm_choices
    from mu_map.dataset.transform import SequenceTransform, PadCropTranform
    from mu_map.models.unet import UNet

    parser = argparse.ArgumentParser(
        description="Compute, print and store measures for a given model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="the device on which the model is evaluated (cpu or cuda)",
    )
    parser.add_argument(
        "--weights",
        type=str,
        required=True,
        help="the model weights which should be scored",
    )
    parser.add_argument("--out", type=str, help="write results as a csv file")

    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/initial/",
        help="directory where the dataset is found",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="validation",
        choices=["train", "test", "validation", "all"],
        help="the split of the dataset to be processed",
    )
    parser.add_argument(
        "--norm",
        type=str,
        choices=["none", *norm_choices],
        default="mean",
        help="type of normalization applied to the reconstructions",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=32,
        help="pad/crop the third tensor dimension to this value",
    )
    args = parser.parse_args()

    # "all" is passed on as ``split_name=None``; presumably MuMapDataset then
    # applies no split filter — confirm against MuMapDataset.
    if args.split == "all":
        args.split = None

    # Evaluation only: disable autograd bookkeeping globally.
    torch.set_grad_enabled(False)

    device = torch.device(args.device)
    model = UNet()
    model.load_state_dict(torch.load(args.weights, map_location=device))
    model = model.to(device).eval()

    transform_normalization = SequenceTransform(
        transforms=[
            norm_by_str(args.norm),
            PadCropTranform(dim=3, size=args.size),
        ]
    )
    # Honour --dataset_dir instead of a hard-coded path.
    dataset = MuMapDataset(
        args.dataset_dir,
        transform_normalization=transform_normalization,
        split_name=args.split,
    )

    measures = {"NMAE": nmae, "MSE": mse}
    # Collect one dict per sample and build the DataFrame once at the end
    # (``DataFrame.append`` was removed in pandas 2.0 and is quadratic).
    rows = []
    n_inputs = len(dataset)
    width = len(str(n_inputs))
    for i, (recon, mu_map) in enumerate(dataset):
        print(f"Process input {str(i + 1):>{width}}/{n_inputs}", end="\r")
        # Add a batch dimension for the model, then drop it again.
        prediction = model(recon.unsqueeze(dim=0).to(device))

        prediction = prediction.squeeze().cpu().numpy()
        mu_map = mu_map.squeeze().cpu().numpy()

        rows.append(
            {name: measure(prediction, mu_map) for name, measure in measures.items()}
        )
    # Clear the carriage-return progress line.
    print(" " * 100, end="\r")

    # Explicit columns keep the frame well-formed even for an empty dataset.
    values = pd.DataFrame(rows, columns=list(measures))

    if args.out:
        values.to_csv(args.out, index=False)

    print("Scores:")
    for measure_name, measure_values in values.items():
        mean = measure_values.mean()
        std = np.std(measure_values)
        print(f" - {measure_name:>6}: {mean:.6f}±{std:.6f}")