Skip to content
Snippets Groups Projects
recon_ac.py 4.6 KiB
Newer Older
  • Learn to ignore specific revisions
  • from mu_map.eval.measures import nmae, mse
    
    if __name__ == "__main__":
        # Evaluate a trained UNet (weights: <dir_train>/snapshots/val_min_generator.pth,
        # preprocessing params: <dir_train>/params.json) by reconstructing each sample
        # with the predicted ("synthetic") mu map and with the CT mu map, and comparing
        # these reconstructions with the NAC input and the AC reconstruction read from
        # the dataset, using the `nmae` measure. Per-sample scores are optionally written
        # to a CSV file (--out); mean ± standard deviation per measure is printed.
        import argparse
        import json
        import os

        import numpy as np
        import pandas as pd
        import torch

        from mu_map.data.prepare import headers
        from mu_map.data.remove_bed import add_bed
        from mu_map.dataset.default import MuMapDataset
        from mu_map.dataset.util import load_dcm_img, align_images
        from mu_map.dataset.transform import SequenceTransform, PadCropTranform
        from mu_map.models.unet import UNet
        from mu_map.training.random_search import normalization_by_params, scatter_correction_by_params
        from mu_map.util import reconstruct

        parser = argparse.ArgumentParser(
            description="Compute, print and store measures for a given model based on the resulting reconstructions",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
            choices=["cpu", "cuda"],
            help="the device on which the model is evaluated (cpu or cuda)",
        )
        parser.add_argument(
            "--dir_train",
            type=str,
            required=True,
            help="directory where training results (snapshots, params) are stored",
        )
        parser.add_argument("--out", type=str, help="write results as a csv file")

        parser.add_argument(
            "--dataset_dir",
            type=str,
            default="data/second/",
            help="directory where the dataset is found",
        )
        parser.add_argument(
            "--split",
            type=str,
            default="validation",
            choices=["train", "test", "validation", "all"],
            help="the split of the dataset to be processed",
        )
        args = parser.parse_args()

        # MuMapDataset receives split_name=None for "all" (i.e. no split filtering).
        if args.split == "all":
            args.split = None

        torch.set_grad_enabled(False)
        device = torch.device(args.device)
        use_gpu = args.device == "cuda"

        with open(os.path.join(args.dir_train, "params.json"), mode="r") as f:
            params = json.load(f)
        weights = os.path.join(args.dir_train, "snapshots", "val_min_generator.pth")

        model = UNet()
        model.load_state_dict(torch.load(weights, map_location=device))
        model = model.to(device).eval()

        transform_pad_crop = PadCropTranform(dim=3, size=32)
        transform_normalization = SequenceTransform(
            transforms=[
                normalization_by_params(params),
                transform_pad_crop,
            ]
        )

        # Model input: normalized (and, per params, scatter-corrected) reconstructions.
        dataset = MuMapDataset(
            args.dataset_dir,
            transform_normalization=transform_normalization,
            split_name=args.split,
            scatter_correction=scatter_correction_by_params(params),
        )
        # Same samples, only padded/cropped and without bed removal
        # (bed_contours_file=None): provides the raw NAC recon and the CT mu map.
        dataset_with_bed = MuMapDataset(
            args.dataset_dir,
            transform_normalization=transform_pad_crop,
            split_name=args.split,
            bed_contours_file=None,
        )

        columns = [
            "NMAE_NAC_TO_AC",
            "NMAE_SYN_TO_AC",
            "NMAE_CT_TO_AC",
            "NMAE_NAC_TO_CT",
            "NMAE_SYN_TO_CT",
        ]
        # Collect one dict per sample and build the DataFrame once afterwards:
        # concatenating onto a growing (initially empty) DataFrame is quadratic and
        # relies on deprecated pandas dtype inference for empty frames.
        rows = []
        n_total = len(dataset)
        width = len(str(n_total))
        # NOTE(review): assumes both datasets enumerate the same samples in the same
        # order (row i of dataset.table); zip would silently truncate otherwise.
        for i, ((recon, _), (recon_nac, mu_map_ct)) in enumerate(zip(dataset, dataset_with_bed)):
            print(f"Process input {i + 1:>{width}}/{n_total}", end="\r")
            table_row = dataset.table.iloc[i]

            mu_map_syn = model(recon.unsqueeze(dim=0).to(device))
            mu_map_syn = mu_map_syn.squeeze().cpu().numpy()

            mu_map_ct = mu_map_ct.squeeze().cpu().numpy()
            # The model output has no bed; add_bed inserts it, taking the bed
            # region from the CT mu map and the contour stored for this sample.
            mu_map_syn = add_bed(mu_map_syn, mu_map_ct, bed_contour=dataset.bed_contours[table_row["id"]])

            recon_nac = recon_nac.squeeze().cpu().numpy()

            # Reference AC reconstruction, padded/cropped with the same transform as
            # the dataset images so that all arrays compared below share one geometry.
            recon_ac = load_dcm_img(os.path.join(dataset.dir_images, table_row[headers.file_recon_ac_nsc]))
            recon_ac = torch.from_numpy(recon_ac)
            recon_ac, _ = transform_pad_crop(recon_ac, recon_ac)
            recon_ac = recon_ac.cpu().numpy()

            # Copies are passed so that reconstruct cannot alter the arrays reused below.
            recon_ac_syn = reconstruct(recon_nac.copy(), mu_map=mu_map_syn.copy(), use_gpu=use_gpu)
            recon_ac_ct = reconstruct(recon_nac.copy(), mu_map=mu_map_ct.copy(), use_gpu=use_gpu)

            rows.append({
                "NMAE_NAC_TO_AC": nmae(recon_nac, recon_ac),
                "NMAE_SYN_TO_AC": nmae(recon_ac_syn, recon_ac),
                "NMAE_CT_TO_AC": nmae(recon_ac_ct, recon_ac),
                "NMAE_NAC_TO_CT": nmae(recon_nac, recon_ac_ct),
                "NMAE_SYN_TO_CT": nmae(recon_ac_syn, recon_ac_ct),
            })
        # Clear the carriage-returned progress line.
        print(" " * 100, end="\r")

        # Explicit columns keep the header even when the split is empty.
        values = pd.DataFrame(rows, columns=columns)

        if args.out:
            values.to_csv(args.out, index=False)

        print("Scores:")
        for measure_name, measure_values in values.items():
            mean = measure_values.mean()
            # np.std defaults to ddof=0, i.e. the population standard deviation.
            std = np.std(measure_values)
            print(f" - {measure_name:>20}: {mean:.6f}±{std:.6f}")