Skip to content
Snippets Groups Projects
show_predictions.py 5.18 KiB
Newer Older
  • Learn to ignore specific revisions
    from typing import Callable

    import cv2 as cv
    import numpy as np
    import torch

    torch.set_grad_enabled(False)


    from mu_map.data.prepare import headers

    from mu_map.dataset.default import MuMapDataset
    from mu_map.dataset.normalization import (
        GaussianNormTransform,
        MeanNormTransform,
        MaxNormTransform,
    )
    from mu_map.dataset.transform import PadCropTranform, SequenceTransform
    from mu_map.eval.measures import nmae, mse
    from mu_map.models.unet import UNet
    from mu_map.util import to_grayscale, COLOR_WHITE
    from mu_map.vis.slices import join_images
    from mu_map.random_search.cgan import load_params
    
    
    
    def main(
        model: torch.nn.Module,
        dataset: MuMapDataset,
        wname: str = "Prediction",
        action: Callable[[int], bool] = None,
        _print: bool = True,
    ):
        """
        Visualize the predictions of a model for all reconstructions
        in a dataset.

        For every reconstruction, the prediction, the target attenuation map
        and their absolute difference are displayed side by side, slice by
        slice, in an OpenCV window.

        Key bindings
        ------------
        q: exit the process with status 0
        n: continue with the next reconstruction
        p: toggle between automatic playback and waiting for a key press
        t: toggle the slice counter drawn onto the prediction
        s: save the displayed slice as `prediction.png`, `mu_map.png` and
           `difference.png` and write `info.txt`
        key code 81: go back one slice (presumably the left arrow key on
           some platforms - TODO confirm)

        Parameters
        ----------
        model: torch.nn.Module
            the model with which predictions are computed
        dataset: MuMapDataset
            the dataset containing reconstructions for which images are computed and target attenuation maps for comparison
        wname: str
            the name of the display window
        action: Callable[[int], bool]
            Add control behaviour by providing a callable reacting to key presses.
            It is called with the key code before the built-in key bindings
            are evaluated. If it returns true, the display is stopped.
        _print: bool
            if measures for predictions should be printed
        """
        # disable print if not wanted
        print_func = print if _print else (lambda *_args, **_kwargs: None)

        # print header
        print_func(" Id |     NMAE |      MSE")
        print_func("----|----------|---------")

        # Display state that persists across reconstructions.
        # NOTE(review): initial values are assumptions; 100 is the value the
        # "p" binding switches back to when playback is resumed.
        timeout = 100
        show_text = True

        # The "s" binding records the command line arguments. They only exist
        # as a module-level `args` when this file is run as a script, so look
        # them up defensively instead of relying on an implicit global.
        cli_args = globals().get("args")

        for i, (recon, mu_map) in enumerate(dataset):
            _id = dataset.table.iloc[i][headers.id]

            prediction = model(recon.unsqueeze(dim=0)).squeeze().numpy()
            mu_map = mu_map.squeeze().numpy()

            _nmae = nmae(prediction, mu_map)
            _mse = mse(prediction, mu_map)
            print_func(f"{_id:03d} | {_nmae:.6f} | {_mse:.6f}")

            # measures are computed on the raw prediction; clipping to the
            # value range of the target is only applied for the display
            min_val = 0
            max_val = mu_map.max()
            prediction = np.clip(prediction, 0, max_val)
            diff = np.abs(prediction - mu_map)

            volumes = [prediction, mu_map, diff]
            n_slices = mu_map.shape[0]
            _slice = 0

            while True:
                images = [
                    cv.resize(
                        to_grayscale(volume[_slice], min_val=min_val, max_val=max_val),
                        (512, 512),
                    )
                    for volume in volumes
                ]

                if show_text:
                    txt = f"{str(_slice):{len(str(n_slices))}}/{n_slices}"
                    cv.putText(images[0], txt, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, 255, 3)

                # Advance before handling keys: the bindings below step back
                # by one to stay on the displayed slice, or by two to show
                # the previous one.
                _slice = (_slice + 1) % n_slices

                cv.imshow(wname, join_images(images))
                # a timeout of 0 blocks until a key is pressed
                key = cv.waitKey(timeout)

                if action is not None and action(key):
                    # stop the display entirely, not just this reconstruction
                    return

                if key == ord("q"):
                    raise SystemExit(0)
                elif key == ord("n"):
                    break
                elif key == 81:
                    _slice = (_slice - 2) % n_slices
                elif key == ord("p"):
                    timeout = 0 if timeout > 0 else 100
                elif key == ord("t"):
                    show_text = not show_text
                    _slice = (_slice - 1) % n_slices
                elif key == ord("s"):
                    # step back first so that the slice written to info.txt
                    # is the one that was saved
                    _slice = (_slice - 1) % n_slices
                    cv.imwrite("prediction.png", images[0])
                    cv.imwrite("mu_map.png", images[1])
                    cv.imwrite("difference.png", images[2])
                    with open("info.txt", mode="w") as f:
                        f.write(f"Arguments: {cli_args}\n")
                        f.write("\n")
                        f.write(f"Id: {_id}\n")
                        f.write(f"Slice: {_slice}\n")
    
    
    if __name__ == "__main__":
        # Script entry point: load the generator of a random search run and
        # display its predictions for one split of the dataset.
        import argparse
        import os

        parser = argparse.ArgumentParser(
            description="visualize the results of a random search run",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "dir", type=str, help="directory containing the data of the random search run"
        )
        parser.add_argument(
            "--split",
            choices=["train", "validation", "test"],
            default="validation",
            help="the split of the dataset used",
        )
        args = parser.parse_args()

        # the parameters of the run determine the normalization of the dataset
        # and the feature configuration of the generator
        params = load_params(os.path.join(args.dir, "params.json"))
        dataset = MuMapDataset(
            "data/second/",
            transform_normalization=SequenceTransform(
                [params["normalization"], PadCropTranform(dim=3, size=32)]
            ),
            split_name=args.split,
            scatter_correction=False,
        )

        # predictions are computed on the CPU
        device = torch.device("cpu")
        model = UNet(features=params["generator_features"])
        weights = torch.load(
            os.path.join(args.dir, "snapshots", "val_min_generator.pth"),
            map_location=device,
        )
        model.load_state_dict(weights)
        model = model.to(device).eval()

        # The window name must be defined before the window is created; it is
        # also handed to `main` so that images are shown in this window.
        wname = "Prediction"
        cv.namedWindow(wname, cv.WINDOW_NORMAL)
        cv.resizeWindow(wname, 1600, 900)

        main(model, dataset, wname=wname)