import numpy as np
import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import MeanNormTransform
from mu_map.models.unet import UNet

def mse(prediction: np.ndarray, target: np.ndarray) -> float:
    """Return the mean squared error between *prediction* and *target*.

    The squared element-wise difference is averaged over all elements.

    Args:
        prediction: Predicted values.
        target: Reference values; combined element-wise with ``prediction``
            (NumPy broadcasting rules apply).

    Returns:
        The mean of ``(prediction - target) ** 2`` as a Python float.
    """
    difference = prediction - target
    return float(np.mean(np.square(difference)))

def nmae(prediction: np.ndarray, target: np.ndarray) -> float:
    """Return the normalized mean absolute error of *prediction*.

    The mean absolute error over all elements is divided by the value range
    of ``target`` (``target.max() - target.min()``).

    Args:
        prediction: Predicted values.
        target: Reference values; combined element-wise with ``prediction``.

    Returns:
        The mean absolute error divided by the target's value range, as a
        Python float. If ``target`` is constant the range is zero and the
        result follows NumPy's division-by-zero semantics (``inf`` or
        ``nan``).
    """
    # Average the absolute error to a single scalar; dividing the per-element
    # errors by the element count without summing would yield an array.
    mae = np.absolute(prediction - target).mean()
    value_range = target.max() - target.min()
    return float(mae / value_range)

# Evaluation script: run the model over the validation split and report the
# mean and standard deviation of the NMAE and MSE scores across samples.
device = torch.device("cpu")
model = UNet()
# NOTE(review): "xx.pth" looks like a placeholder checkpoint path - confirm.
model.load_state_dict(torch.load("xx.pth", map_location=device))
# Keep the model on the same device the inputs are moved to below.
model = model.to(device).eval()

dataset = MuMapDataset("data/initial/", transform_normalization=MeanNormTransform(), split_name="validation")

scores_mse = []
scores_nmae = []
# Inference only: without no_grad() the model output tracks gradients and
# calling .numpy() on it raises a RuntimeError.
with torch.no_grad():
    for recon, mu_map in dataset:
        # Add a leading dimension before feeding the sample to the model
        # (presumably the batch dimension - confirm against UNet's input).
        recon = recon.unsqueeze(dim=0).to(device)

        # Move the result back to host memory before converting to NumPy.
        prediction = model(recon).squeeze().cpu().numpy()
        mu_map = mu_map.squeeze().numpy()

        scores_nmae.append(nmae(prediction, mu_map))
        scores_mse.append(mse(prediction, mu_map))
scores_mse = np.array(scores_mse)
scores_nmae = np.array(scores_nmae)

mse_avg = scores_mse.mean()
mse_std = np.std(scores_mse)

nmae_avg = scores_nmae.mean()
nmae_std = np.std(scores_nmae)

print("Scores:")
print(f" - NMAE: {nmae_avg:.6f}±{nmae_std:.6f}")
print(f" -  MSE: {mse_avg:.6f}±{mse_std:.6f}")