    # measures.py -- evaluation measures (MSE, NMAE) for predicted mu maps
  • import numpy as np
    
    import pandas as pd
    
    from mu_map.dataset.default import MuMapDataset
    from mu_map.models.unet import UNet
    
    
    def mse(prediction: np.ndarray, target: np.ndarray) -> float:
        """
        Compute the mean squared error (MSE) between a prediction and
        a target array.

        Parameters
        ----------
        prediction: np.ndarray
            the predicted values
        target: np.ndarray
            the reference values the prediction is compared against

        Returns
        -------
        float
            the sum of the element-wise squared differences divided by the
            number of elements
        """
        squared_error = (prediction - target) ** 2
        return squared_error.sum() / squared_error.size
    
    def nmae(
        prediction: np.ndarray, target: np.ndarray, vmax: float = None, vmin: float = None
    ) -> float:
        """
        Compute the normalized mean absolute error (NMAE) between a prediction
        and a target array.

        The mean absolute error is divided by the value range ``vmax - vmin``.
        If both bounds are equal (e.g. a constant target without explicit
        bounds) the division is by zero and the result is not finite.

        Parameters
        ----------
        prediction: np.ndarray
            the predicted values
        target: np.ndarray
            the reference values the prediction is compared against
        vmax: float, optional
            maximum value for normalization, defaults to the maximal value in the target
        vmin: float, optional
            minimum value for normalization, defaults to the minimal value in the target

        Returns
        -------
        float
            the mean absolute error divided by ``vmax - vmin``
        """
        vmax = target.max() if vmax is None else vmax
        vmin = target.min() if vmin is None else vmin

        absolute_error = np.absolute(prediction - target)
        mae = absolute_error.sum() / absolute_error.size
        # the original computed this value but never returned it
        return mae / (vmax - vmin)
    
    def compute_measures(dataset: MuMapDataset, model: UNet) -> pd.DataFrame:
        """
        Compute measures (MSE, NMAE) for all images in a dataset.

        Each reconstruction is given a leading batch dimension, moved to the
        device of the model parameters and passed through the model. The
        squeezed prediction is then compared against the squeezed mu map.
        Progress is printed to stdout on a single, continuously overwritten line.

        Parameters
        ----------
        dataset: MuMapDataset
            the dataset containing the reconstructions and mu maps for which the scores are computed
        model: UNet
            the UNet model which is used to predict mu maps from reconstructions

        Returns
        -------
        pd.DataFrame
            a dataframe containing the measures for each image in the dataset,
            with one row per image and the columns "NMAE", "MSE" and "id"
        """
        # run inference on whichever device the model parameters live on
        device = next(model.parameters()).device

        measures = {"NMAE": nmae, "MSE": mse}

        # loop-invariant values for the progress output
        total = len(dataset)
        width = len(str(total))

        # collect plain dict rows and build the dataframe once at the end
        # instead of concatenating onto an (initially empty) frame per image
        rows = []
        for i, (recon, mu_map) in enumerate(dataset):
            # 1-based counter so that the last image is reported as total/total
            print(f"Process input {i + 1:>{width}}/{total}", end="\r")

            prediction = model(recon.unsqueeze(dim=0).to(device))

            # detach so that the conversion also works when the caller has
            # not disabled autograd globally
            prediction = prediction.squeeze().detach().cpu().numpy()
            mu_map = mu_map.squeeze().cpu().numpy()

            row = {name: measure(prediction, mu_map) for name, measure in measures.items()}
            row["id"] = dataset.table.iloc[i]["id"]
            rows.append(row)

        # wipe the progress line
        print(" " * 100, end="\r")
        return pd.DataFrame(rows, columns=[*measures, "id"])
    
    
    
    if __name__ == "__main__":
        # Command line entry point: load model weights, evaluate them on one
        # split of the dataset, optionally store the per-image measures as a
        # csv file and print mean and standard deviation of every measure.
        import argparse

        import torch

        from mu_map.dataset.normalization import norm_by_str, norm_choices
        from mu_map.dataset.transform import SequenceTransform, PadCropTranform

        parser = argparse.ArgumentParser(
            description="Compute, print and store measures for a given model",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda" if torch.cuda.is_available() else "cpu",
            choices=["cpu", "cuda"],
            help="the device on which the model is evaluated (cpu or cuda)",
        )
        parser.add_argument(
            "--weights",
            type=str,
            required=True,
            help="the model weights which should be scored",
        )
        parser.add_argument("--out", type=str, help="write results as a csv file")
        parser.add_argument("--scatter_corrected", action="store_true")
        parser.add_argument(
            "--dataset_dir",
            type=str,
            default="data/second/",
            help="directory where the dataset is found",
        )
        parser.add_argument(
            "--split",
            type=str,
            default="validation",
            choices=["train", "test", "validation", "all"],
            help="the split of the dataset to be processed",
        )
        parser.add_argument(
            "--norm",
            type=str,
            choices=["none", *norm_choices],
            default="mean",
            help="type of normalization applied to the reconstructions",
        )
        parser.add_argument(
            "--size",
            type=int,
            default=32,
            help="pad/crop the third tensor dimension to this value",
        )
        args = parser.parse_args()

        # "all" is passed on as `split_name=None`
        # NOTE(review): presumably the dataset then uses every split -- confirm
        if args.split == "all":
            args.split = None

        # evaluation only, gradients are never needed
        torch.set_grad_enabled(False)

        device = torch.device(args.device)

        # the feature list has to match the architecture the weights were
        # trained with; alternative used before: [32, 64, 128, 256, 512]
        model = UNet(features=[64, 128, 256, 512])

        # NOTE: torch.load unpickles the file, only load weights from a trusted source
        model.load_state_dict(torch.load(args.weights, map_location=device))
        model = model.to(device).eval()

        transform_normalization = SequenceTransform(
            transforms=[
                norm_by_str(args.norm),
                PadCropTranform(dim=3, size=args.size),
            ]
        )
        # NOTE(review): --dataset_dir was parsed but never used and the call
        # was left unclosed; the directory is assumed to be the first
        # positional parameter of MuMapDataset -- confirm against its constructor
        dataset = MuMapDataset(
            args.dataset_dir,
            transform_normalization=transform_normalization,
            split_name=args.split,
            scatter_correction=args.scatter_corrected,
        )
        values = compute_measures(dataset, model)

        if args.out:
            values.to_csv(args.out, index=False)

        print("Scores:")
        for measure_name, measure_values in values.items():
            # the id column identifies the image and is not a measure
            if measure_name == "id":
                continue

            mean = measure_values.mean()
            std = np.std(measure_values)
            print(f" - {measure_name:>6}: {mean:.6f}±{std:.6f}")