show_predictions.py
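
"""Show a model's predicted mu maps next to the ground truth.

For every reconstruction in the chosen dataset split, the prediction, the
target mu map and their absolute difference are displayed slice by slice in
an OpenCV window, while NMAE and MSE per scan are printed as a table.
"""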
import cv2 as cv
import numpy as np
import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.transform import PadCropTranform, SequenceTransform
from mu_map.eval.measures import mse, nmae
from mu_map.models.unet import UNet
from mu_map.random_search.cgan import load_params
from mu_map.util import to_grayscale
from mu_map.vis.slices import join_images

# The model is only used for inference, so gradients are never needed.
torch.set_grad_enabled(False)
    
    
def main(model: torch.nn.Module, dataset: MuMapDataset, wname: str = "Dataset"):
    """
    Visualize the predictions of a model for all reconstructions
    in a dataset.

    :param model: the model mapping a reconstruction to a mu map
    :param dataset: the dataset of reconstructions and target mu maps
    :param wname: name of the OpenCV window used for display
    """
    timeout = 100  # delay in ms between slices; 0 pauses the playback

    print(" Id |     NMAE |      MSE")
    print("----|----------|---------")
    
    for i, (recon, mu_map) in enumerate(dataset):
        _id = dataset.table.iloc[i]["id"]

        prediction = model(recon.unsqueeze(dim=0)).squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        _nmae = nmae(prediction, mu_map)
        _mse = mse(prediction, mu_map)
        print(f"{_id:03d} | {_nmae:.6f} | {_mse:.6f}")
    
        # Clip negative predictions: attenuation coefficients cannot be negative.
        prediction = np.clip(prediction, 0, prediction.max())
        diff = np.abs(prediction - mu_map)

        # Show prediction, target and absolute difference side by side, all
        # scaled to the value range of the target mu map.
        volumes = [prediction, mu_map, diff]
        min_val = 0
        max_val = mu_map.max()
        n_slices = mu_map.shape[0]
    
        # Loop over the slices of the current scan until the user quits ("q")
        # or skips to the next scan ("n").
        _slice = 0
        while True:
            images = map(lambda v: v[_slice], volumes)
            images = map(
                lambda img: to_grayscale(img, min_val=min_val, max_val=max_val), images
            )
            images = map(lambda img: cv.resize(img, (512, 512)), images)
            images = list(images)

            # Draw the current slice index onto the leftmost image.
            txt = f"{str(_slice):{len(str(n_slices))}}/{n_slices}"
            cv.putText(images[0], txt, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, 255, 3)

            _slice = (_slice + 1) % n_slices

            cv.imshow(wname, join_images(images))
            key = cv.waitKey(timeout)
            if key == ord("q"):
                exit(0)
            elif key == ord("n"):
                break
            elif key == 81:  # left arrow: step back (the index was already advanced)
                _slice = (_slice - 2) % n_slices
            elif key == ord("p"):  # pause/resume: waitKey(0) blocks for a key press
                timeout = 0 if timeout > 0 else 100
    
    
if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="visualize the results of a random search run",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dir", type=str, help="directory containing the data of the random search run"
    )
    parser.add_argument(
        "--split",
        choices=["train", "validation", "test"],
        default="validation",
        help="the split of the dataset used",
    )
    args = parser.parse_args()
    
    # Restore the pre-processing used during training from the stored
    # hyperparameters of the random search run.
    params = load_params(os.path.join(args.dir, "params.json"))
    dataset = MuMapDataset(
        "data/second/",  # note: the dataset location is hard-coded
        transform_normalization=SequenceTransform(
            [params["normalization"], PadCropTranform(dim=3, size=32)]
        ),
        split_name=args.split,
        scatter_correction=False,
    )

    # Load the generator snapshot with the minimal validation loss.
    device = torch.device("cpu")
    model = UNet(features=params["generator_features"])
    weights = torch.load(
        os.path.join(args.dir, "snapshots", "val_min_generator.pth"),
        map_location=device,
    )
    model.load_state_dict(weights)
    model = model.to(device).eval()
    
    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    main(model, dataset, wname)
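
# Example invocation (the run directory name is hypothetical):
#   python show_predictions.py random_search/run_01 --split test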