Skip to content
Snippets Groups Projects
measures.py 1.22 KiB
Newer Older
  • Learn to ignore specific revisions
  • import numpy as np
    import torch
    
    from mu_map.dataset.default import MuMapDataset
    from mu_map.dataset.normalization import MeanNormTransform
    from mu_map.models.unet import UNet
    
    def mse(prediction: np.ndarray, target: np.ndarray) -> float:
        """Return the mean squared error between *prediction* and *target*.

        Both inputs are converted to float64 before subtracting. This keeps
        integer inputs (e.g. uint8) from wrapping around when squared.

        Args:
            prediction: Predicted values, any shape.
            target: Reference values, same shape as *prediction*.

        Returns:
            The mean of the squared element-wise differences, as a Python float.

        Raises:
            ValueError: If the two inputs do not have the same shape.
        """
        prediction = np.asarray(prediction, dtype=np.float64)
        target = np.asarray(target, dtype=np.float64)
        # Refuse silent broadcasting: a metric over mismatched volumes is meaningless.
        if prediction.shape != target.shape:
            raise ValueError(
                f"shape mismatch: prediction {prediction.shape} vs target {target.shape}"
            )
        return float(np.mean((prediction - target) ** 2))
    
    def nmae(prediction: np.ndarray, target: np.ndarray) -> float:
        """Return the mean absolute error normalised by the range of *target*.

        NMAE = mean(|prediction - target|) / (max(target) - min(target))

        Args:
            prediction: Predicted values, any shape.
            target: Reference values, same shape as *prediction*.

        Returns:
            The normalised mean absolute error, as a Python float.

        Raises:
            ValueError: If the shapes differ or if *target* is constant (zero
                range), in which case the normalisation is undefined.
        """
        prediction = np.asarray(prediction, dtype=np.float64)
        target = np.asarray(target, dtype=np.float64)
        if prediction.shape != target.shape:
            raise ValueError(
                f"shape mismatch: prediction {prediction.shape} vs target {target.shape}"
            )
        # Reduce to a scalar: the absolute error must be averaged over all
        # elements, not merely divided element-wise by the element count.
        mae = np.abs(prediction - target).mean()
        value_range = target.max() - target.min()
        if value_range == 0:
            raise ValueError("target has zero range; NMAE is undefined")
        return float(mae / value_range)
    
    # Evaluate a trained UNet on the "validation" split and report the NMAE and
    # MSE over all samples as mean ± standard deviation.
    device = torch.device("cpu")
    model = UNet()
    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # checkpoints from a trusted source.
    model.load_state_dict(torch.load("xx.pth", map_location=device))
    model = model.eval()

    dataset = MuMapDataset("data/initial/", transform_normalization=MeanNormTransform(), split_name="validation")

    scores_mse = []
    scores_nmae = []
    # Inference only: without no_grad() the model output requires grad and
    # Tensor.numpy() raises a RuntimeError; it also avoids building a graph.
    with torch.no_grad():
        for recon, mu_map in dataset:
            # Add a leading axis of size 1 (presumably the batch axis the model expects).
            recon = recon.unsqueeze(dim=0).to(device)

            # .cpu() makes the numpy conversion valid for any device choice.
            prediction = model(recon).squeeze().cpu().numpy()
            mu_map = mu_map.squeeze().cpu().numpy()

            scores_nmae.append(nmae(prediction, mu_map))
            scores_mse.append(mse(prediction, mu_map))
    scores_mse = np.array(scores_mse)
    scores_nmae = np.array(scores_nmae)

    mse_avg = scores_mse.mean()
    mse_std = np.std(scores_mse)

    nmae_avg = scores_nmae.mean()
    nmae_std = np.std(scores_nmae)

    print("Scores:")
    print(f" - NMAE: {nmae_avg:.6f}±{nmae_std:.6f}")
    print(f" -  MSE: {mse_avg:.6f}±{mse_std:.6f}")