"""
Measures (MSE, NMAE) for comparing predicted mu maps with target mu maps.

Run as a script, this module loads UNet weights, evaluates the model on a
split of a MuMapDataset and prints (optionally stores) the resulting scores.
"""
from typing import Optional

import numpy as np
import pandas as pd

from mu_map.dataset.default import MuMapDataset
from mu_map.models.unet import UNet


def mse(prediction: np.ndarray, target: np.ndarray) -> float:
    """
    Compute the mean squared error (MSE) between a prediction and a target array.

    Parameters
    ----------
    prediction: np.ndarray
        the predicted values
    target: np.ndarray
        the reference values, must be broadcastable against the prediction

    Returns
    -------
    float
        the sum of squared element-wise differences divided by the number of elements
    """
    squared_error = (prediction - target) ** 2
    return squared_error.sum() / squared_error.size


def nmae(
    prediction: np.ndarray,
    target: np.ndarray,
    vmax: Optional[float] = None,
    vmin: Optional[float] = None,
) -> float:
    """
    Compute the normalized mean absolute error (NMAE) between a prediction and a target array.

    The mean absolute error is divided by the value range `vmax - vmin`.
    No guard is applied for `vmax == vmin`; the division by a zero range
    then yields a non-finite result.

    Parameters
    ----------
    prediction: np.ndarray
        the predicted values
    target: np.ndarray
        the reference values, must be broadcastable against the prediction
    vmax: float, optional
        maximum value for normalization, defaults to the maximal value in the target
    vmin: float, optional
        minimum value for normalization, defaults to the minimal value in the target

    Returns
    -------
    float
        the mean absolute error divided by `vmax - vmin`
    """
    vmax = target.max() if vmax is None else vmax
    vmin = target.min() if vmin is None else vmin

    absolute_error = np.absolute(prediction - target)
    mae = absolute_error.sum() / absolute_error.size
    return mae / (vmax - vmin)


def compute_measures(dataset: MuMapDataset, model: UNet) -> pd.DataFrame:
    """
    Compute measures (MSE, NMAE) for all images in a dataset.

    Parameters
    ----------
    dataset: MuMapDataset
        the dataset containing the reconstructions and mu maps for which the scores are computed
    model: UNet
        the UNet model which is used to predict mu maps from reconstructions

    Returns
    -------
    pd.DataFrame
        a dataframe containing the measures for each image in the dataset,
        with one row per image and the columns "NMAE", "MSE" and "id";
        for an empty dataset the frame has no rows and no "id" column
    """
    # Imported locally because the module itself only needs torch for inference.
    import torch

    # Run the inputs on whatever device the model parameters live on.
    device = next(model.parameters()).device
    measures = {"NMAE": nmae, "MSE": mse}

    total = len(dataset)
    width = len(str(total))

    # Collect one single-row frame per image and concatenate once at the end
    # instead of re-copying a growing frame in every iteration.
    rows = []
    for i, (recon, mu_map) in enumerate(dataset):
        _id = dataset.table.iloc[i]["id"]
        print(f"Process input {str(i):>{width}}/{total}", end="\r")

        # Gradients are never needed for scoring; disabling them here also
        # keeps `.numpy()` valid when the caller has not done so globally.
        with torch.no_grad():
            prediction = model(recon.unsqueeze(dim=0).to(device))
        prediction = prediction.squeeze().cpu().numpy()
        mu_map = mu_map.squeeze().cpu().numpy()

        row = {
            name: [measure(prediction, mu_map)] for name, measure in measures.items()
        }
        row["id"] = _id
        rows.append(pd.DataFrame(row))
    # Overwrite the progress line with blanks.
    print(" " * 100, end="\r")

    if not rows:
        return pd.DataFrame({name: [] for name in measures})
    return pd.concat(rows, ignore_index=True)


if __name__ == "__main__":
    import argparse

    import torch

    from mu_map.dataset.normalization import norm_by_str, norm_choices
    from mu_map.dataset.transform import SequenceTransform, PadCropTranform

    parser = argparse.ArgumentParser(
        description="Compute, print and store measures for a given model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cpu", "cuda"],
        help="the device on which the model is evaluated (cpu or cuda)",
    )
    parser.add_argument(
        "--weights",
        type=str,
        required=True,
        help="the model weights which should be scored",
    )
    parser.add_argument("--out", type=str, help="write results as a csv file")
    parser.add_argument("--scatter_corrected", action="store_true")
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/second/",
        help="directory where the dataset is found",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="validation",
        choices=["train", "test", "validation", "all"],
        help="the split of the dataset to be processed",
    )
    parser.add_argument(
        "--norm",
        type=str,
        choices=["none", *norm_choices],
        default="mean",
        help="type of normalization applied to the reconstructions",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=32,
        help="pad/crop the third tensor dimension to this value",
    )
    args = parser.parse_args()

    # "all" is passed on to the dataset as "no split name".
    if args.split == "all":
        args.split = None

    torch.set_grad_enabled(False)

    device = torch.device(args.device)
    # NOTE: the feature configuration is hard-coded; the state dict passed via
    # --weights has to match it (alternative seen here: [32, 64, 128, 256, 512]).
    model = UNet(features=[64, 128, 256, 512])
    model.load_state_dict(torch.load(args.weights, map_location=device))
    model = model.to(device).eval()

    transform_normalization = SequenceTransform(
        transforms=[
            norm_by_str(args.norm),
            PadCropTranform(dim=3, size=args.size),
        ]
    )
    dataset = MuMapDataset(
        args.dataset_dir,
        transform_normalization=transform_normalization,
        split_name=args.split,
        scatter_correction=args.scatter_corrected,
    )
    values = compute_measures(dataset, model)

    if args.out:
        values.to_csv(args.out, index=False)

    print("Scores:")
    for measure_name, measure_values in values.items():
        # The id column identifies images and is not a score.
        if measure_name == "id":
            continue

        mean = measure_values.mean()
        std = np.std(measure_values)
        print(f" - {measure_name:>6}: {mean:.6f}±{std:.6f}")