import numpy as np
import pandas as pd
import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import MeanNormTransform
from mu_map.dataset.transform import SequenceTransform, PadCropTranform
from mu_map.models.unet import UNet

# This script only runs inference, so gradient tracking is switched off
# globally for everything that follows.
torch.set_grad_enabled(False)

def mse(prediction: np.ndarray, target: np.ndarray) -> np.floating:
    """Return the mean squared error between *prediction* and *target*.

    Integer (and boolean) inputs are promoted to a floating point type before
    the difference is taken, so unsigned values cannot wrap around and the
    squares cannot overflow. Floating point inputs keep their own precision.

    Args:
        prediction: Predicted values.
        target: Reference values, broadcast-compatible with ``prediction``.

    Returns:
        The average of the element-wise squared differences. For empty input
        the result is ``nan`` (NumPy emits a ``RuntimeWarning``).
    """
    prediction = np.asarray(prediction)
    target = np.asarray(target)

    # Promote to at least float32: integers become floats, floats are kept.
    dtype = np.result_type(prediction.dtype, target.dtype, np.float32)
    difference = prediction.astype(dtype, copy=False) - target.astype(dtype, copy=False)

    return np.square(difference).mean()

def nmae(prediction: np.ndarray, target: np.ndarray) -> np.floating:
    """Return the mean absolute error normalized by the value range of *target*.

    The mean absolute error is divided by ``target.max() - target.min()``.
    Integer (and boolean) inputs are promoted to a floating point type first,
    so neither the difference nor the range can wrap around or overflow.
    Floating point inputs keep their own precision.

    Args:
        prediction: Predicted values.
        target: Reference values, broadcast-compatible with ``prediction``.

    Returns:
        The normalized mean absolute error. If ``target`` is constant its
        range is zero and the result is ``inf`` or ``nan`` (NumPy emits a
        ``RuntimeWarning``).

    Raises:
        ValueError: If ``target`` is empty (raised by ``max``/``min``).
    """
    prediction = np.asarray(prediction)
    target = np.asarray(target)

    # Promote to at least float32: integers become floats, floats are kept.
    dtype = np.result_type(prediction.dtype, target.dtype, np.float32)
    prediction = prediction.astype(dtype, copy=False)
    target = target.astype(dtype, copy=False)

    # Average over the error array itself (as mse does) rather than dividing
    # by prediction.size, so both metrics agree on the element count.
    mae = np.absolute(prediction - target).mean()
    value_range = target.max() - target.min()
    return mae / value_range

# Prefer the first GPU, but fall back to the CPU so the evaluation also runs
# on machines without CUDA instead of failing when the model is moved.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet()
model = model.to(device)
# map_location remaps the stored tensors onto the selected device, so a
# snapshot saved on a GPU can be loaded on a CPU-only machine as well.
# NOTE(review): torch.load unpickles the file; only load trusted snapshots.
model.load_state_dict(torch.load("trainings/03_cgan/snapshots/50_generator.pth", map_location=device))
model = model.eval()

# Input preprocessing applied by the dataset; evaluated on the validation split.
transform_normalization = SequenceTransform(transforms=[MeanNormTransform(), PadCropTranform(dim=3, size=32)])
dataset = MuMapDataset("data/initial/", transform_normalization=transform_normalization, split_name="validation")

# One score per dataset item for each metric.
scores_mse = []
scores_nmae = []
for i, (recon, mu_map) in enumerate(dataset):
    # "\r" keeps the progress indicator on a single, continuously updated line.
    print(f"{i:02d}/{len(dataset)}", end="\r")
    # Add a leading batch dimension before feeding the model.
    recon = recon.unsqueeze(dim=0).to(device)
    # Drop singleton dimensions and move the results to NumPy for the metrics.
    prediction = model(recon).squeeze().cpu().numpy()
    mu_map = mu_map.squeeze().cpu().numpy()

    scores_nmae.append(nmae(prediction, mu_map))
    scores_mse.append(mse(prediction, mu_map))
scores_mse = np.array(scores_mse)
scores_nmae = np.array(scores_nmae)

# Summarize each metric as mean and (population) standard deviation.
mse_avg = scores_mse.mean()
mse_std = np.std(scores_mse)

nmae_avg = scores_nmae.mean()
nmae_std = np.std(scores_nmae)

print("Scores:")
print(f" - NMAE: {nmae_avg:.6f}±{nmae_std:.6f}")
print(f" -  MSE: {mse_avg:.6f}±{mse_std:.6f}")