Skip to content
Snippets Groups Projects
recon.py 8.16 KiB
Newer Older
  • Learn to ignore specific revisions
  • import os
    from typing import Dict, Optional, Tuple
    import tempfile
    
    import numpy as np
    import stir
    
    from mu_map.file import load_as_interfile
    from mu_map.file.interfile import (
        parse_interfile_header_str,
        load_interfile,
        write_interfile,
        KEY_DIM_1,
        KEY_DIM_2,
        KEY_SPACING_1,
        KEY_SPACING_2,
        TEMPLATE_HEADER_IMAGE,
    )
    
    
    # Template of the STIR OSMAPOSL parameter file used by `reconstruct` for
    # SPECT reconstruction with the "SPECT UB" projection matrix.
    # `reconstruct` fills the placeholders {PROJECTION}, {ATTENUATION_TYPE},
    # {ATTENUATION_MAP}, {MASK_TYPE}, {MASK_FILE}, {INIT_FILE}, {OUTPUT_PREFIX},
    # {N_SUBSETS}, {N_SUBITERATIONS} and {SAVE_INTERVALS} via str.replace.
    # Lines starting with ";" are comments in this parameter syntax.
    # NOTE: `reconstruct` disables optional keys by replacing the plain substrings
    # "attenuation map" and "mask file" with ";attenuation map" / ";mask file",
    # so any other line containing those exact substrings is affected as well.
    TEMPLATE_RECON_PARAMS = """
    OSMAPOSLParameters :=
    
      objective function type:= PoissonLogLikelihoodWithLinearModelForMeanAndProjData
      PoissonLogLikelihoodWithLinearModelForMeanAndProjData Parameters:=
    
        input file := {PROJECTION}
    
        projector pair type := Matrix
          Projector Pair Using Matrix Parameters :=
          Matrix type := SPECT UB
          Projection Matrix By Bin SPECT UB Parameters:=
                ; width of PSF
                maximum number of sigmas:= 2.0
    
                ;PSF type of correction { 2D // 3D // Geometrical }
                psf type:= Geometrical
                ; next 2 parameters define the PSF. They are ignored if psf_type is "Geometrical"
                ; These values are mostly dependent on your collimator.
                ; the PSF is modelled as a Gaussian with sigma dependent on the distance from the collimator
                ; sigma_at_depth = collimator_slope * depth_in_cm + collimator sigma 0(cm)
                collimator slope := 0.0163
                collimator sigma 0(cm) := 0.1466
    
                ;Attenuation correction { Simple // Full // No }
                attenuation type := {ATTENUATION_TYPE}
                ;Values in attenuation map in cm-1
                attenuation map := {ATTENUATION_MAP}
    
                ;Mask properties { Cylinder // Attenuation Map // Explicit Mask // No}
                mask type := {MASK_TYPE}
                mask file := {MASK_FILE}
    
               ; if next variable is set to 0, only a single view is kept in memory
               keep all views in cache:=1
    
            End Projection Matrix By Bin SPECT UB Parameters:=
         End Projector Pair Using Matrix Parameters :=
      end PoissonLogLikelihoodWithLinearModelForMeanAndProjData Parameters:=
    
      initial estimate := {INIT_FILE}
      output filename prefix := {OUTPUT_PREFIX}
    
      number of subsets:= {N_SUBSETS}
      number of subiterations:= {N_SUBITERATIONS}
      Save estimates at subiteration intervals:= {SAVE_INTERVALS}
    
      ; keywords that specify the filtering that occurs after every subiteration
      ; warning: do not normally use together with a prior
      ;inter-iteration filter subiteration interval := 4
      ;inter-iteration filter type := Separable Gaussian
    
      ;post-filter type := Separable Gaussian
      ;separable gaussian filter parameters :=
        ;x-dir filter fwhm (in mm) := 6
        ;y-dir filter fwhm (in mm) := 6
        ;z-dir filter fwhm (in mm) := 6
        ;x-dir maximum kernel size := 129
        ;y-dir maximum kernel size := 129
        ;z-dir maximum kernel size := 31
        ;Normalise filter to 1 := 1
      ;end separable gaussian filter parameters :=
    
    
    END :=
    """
    
    
    def uniform_estimate(projection: Tuple[Dict[str, str], np.ndarray]):
        """Create a constant (all-ones) initial image that fits a projection.

        The image has ``projection[1].shape[1]`` slices and a square in-plane
        grid of ``projection[1].shape[2]`` voxels per side. The in-plane spacing
        (x and y) is copied from the projection header's ``KEY_SPACING_1`` entry
        and the slice spacing (z) from ``KEY_SPACING_2``. The x/y offsets are
        set to minus half of the in-plane extent.

        :param projection: tuple of (interfile header dict, projection array)
        :return: tuple of (parsed interfile image header, float32 array of ones)
        """
        proj_header, proj_data = projection
        n_slices = proj_data.shape[1]
        size = proj_data.shape[2]

        volume = np.ones((n_slices, size, size), np.float32)

        spacing_xy = proj_header[KEY_SPACING_1]
        # minus half the in-plane extent (presumably centres the grid — confirm
        # against TEMPLATE_HEADER_IMAGE's offset convention)
        offset_str = f"{-0.5 * size * float(spacing_xy):.4f}"

        # applied in order, each placeholder replaced by its value
        substitutions = (
            ("{ROWS}", str(volume.shape[2])),
            ("{COLUMNS}", str(volume.shape[1])),
            ("{SLICES}", str(volume.shape[0])),
            ("{SPACING_X}", spacing_xy),
            ("{SPACING_Y}", spacing_xy),
            ("{SPACING_Z}", proj_header[KEY_SPACING_2]),
            ("{OFFSET_X}", offset_str),
            ("{OFFSET_Y}", offset_str),
        )
        header_str = TEMPLATE_HEADER_IMAGE
        for placeholder, value in substitutions:
            header_str = header_str.replace(placeholder, value)

        return parse_interfile_header_str(header_str), volume
    
    
    def reconstruct(
        projection: Tuple[Dict[str, str], np.ndarray],
        mu_map: Optional[Tuple[Dict[str, str], np.ndarray]] = None,
        mask: Optional[Tuple[Dict[str, str], np.ndarray]] = None,
        init: Optional[Tuple[Dict[str, str], np.ndarray]] = None,
        n_subsets: int = 4,
        n_iterations: int = 10,
    ):
        """Reconstruct a projection with STIR's OSMAPOSL (OSEM) algorithm.

        All inputs are written as interfiles into a temporary directory, a
        parameter file is generated from ``TEMPLATE_RECON_PARAMS``, and the
        reconstruction is run with ``stir.OSMAPOSLReconstruction3DFloat``. The
        temporary directory is removed before this function returns.

        :param projection: (header, image) of the projection to reconstruct
        :param mu_map: optional (header, image) attenuation map; if given,
            "Full" attenuation correction is used, otherwise it is set to "No"
        :param mask: optional (header, image) explicit mask; if omitted, the mask
            type is "Attenuation Map" when a mu map is given and "No" otherwise
        :param init: optional (header, image) initial estimate; defaults to
            ``uniform_estimate(projection)``
        :param n_subsets: number of OSEM subsets (must be >= 1)
        :param n_iterations: number of full iterations (must be >= 1); in total
            ``n_subsets * n_iterations`` sub-iterations are run
        :return: (header, image) of the final estimate as returned by
            ``load_interfile``
        :raises ValueError: if ``n_subsets`` or ``n_iterations`` is smaller than 1
        """
        # with 0 sub-iterations the expected output file "<prefix>_0.hv" would
        # never be produced, so reject such values up front
        if n_subsets < 1:
            raise ValueError(f"n_subsets must be >= 1, got {n_subsets}")
        if n_iterations < 1:
            raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

        n_subiterations = n_subsets * n_iterations
        # only the final estimate is needed: save once, after the last sub-iteration
        save_intervals = n_subiterations

        with tempfile.TemporaryDirectory() as dir_tmp:
            filename_projection = os.path.join(dir_tmp, "projection.hv")
            write_interfile(filename_projection, *projection)
            params = TEMPLATE_RECON_PARAMS.replace("{PROJECTION}", filename_projection)

            # the estimate saved after sub-iteration N is expected at "<prefix>_<N>.hv"
            output_prefix = os.path.join(dir_tmp, "out")
            filename_out = f"{output_prefix}_{save_intervals}.hv"
            params = params.replace("{OUTPUT_PREFIX}", output_prefix)
            params = params.replace("{N_SUBSETS}", str(n_subsets))
            params = params.replace("{N_SUBITERATIONS}", str(n_subiterations))
            params = params.replace("{SAVE_INTERVALS}", str(save_intervals))

            if mu_map is not None:
                filename_mu_map = os.path.join(dir_tmp, "mu_map.hv")
                write_interfile(filename_mu_map, *mu_map)
                params = params.replace("{ATTENUATION_TYPE}", "Full")
                params = params.replace("{ATTENUATION_MAP}", filename_mu_map)
            else:
                params = params.replace("{ATTENUATION_TYPE}", "No")
                # comment out the unused key (";" starts a comment in the template)
                params = params.replace("attenuation map", ";attenuation map")

            if mask is not None:
                filename_mask = os.path.join(dir_tmp, "mask.hv")
                write_interfile(filename_mask, *mask)
                params = params.replace("{MASK_TYPE}", "Explicit Mask")
                params = params.replace("{MASK_FILE}", filename_mask)
            else:
                # no explicit mask: comment out the "mask file" key and use the
                # attenuation map as mask when one is available
                params = params.replace("mask file", ";mask file")
                params = params.replace(
                    "{MASK_TYPE}", "Attenuation Map" if mu_map is not None else "No"
                )

            if init is None:
                init = uniform_estimate(projection)
            filename_init = os.path.join(dir_tmp, "init.hv")
            write_interfile(filename_init, *init)
            params = params.replace("{INIT_FILE}", filename_init)

            filename_params = os.path.join(dir_tmp, "OSEM_SPECT.par")
            with open(filename_params, mode="w") as f:
                f.write(params.strip())

            recon = stir.OSMAPOSLReconstruction3DFloat(filename_params)
            recon.reconstruct()

            # load inside the block: the temporary directory is deleted on exit.
            # NOTE(review): assumes load_interfile reads the data eagerly
            # (no lazy/memory-mapped access to the file) — confirm.
            return load_interfile(filename_out)
    
    
    if __name__ == "__main__":
        import argparse

        from mu_map.recon.project import forward_project

        parser = argparse.ArgumentParser(
            description="Reconstruct a projection or another reconstruction",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "--projection", type=str, help="a projection in INTERFILE format"
        )
        parser.add_argument(
            "--recon",
            type=str,
            help="a reconstruction in DICOM or INTERFILE format - if specified the projection will be overwritten",
        )
        parser.add_argument(
            "--mu_map",
            type=str,
            help="a mu map for attenuation correction in DICOM or INTERFILE format",
        )
        # required: the result is written to this file at the very end, so a
        # missing value must fail before the (long) reconstruction starts
        parser.add_argument(
            "--out",
            type=str,
            required=True,
            help="the filename to store the reconstruction",
        )
        parser.add_argument(
            "--n_subsets",
            type=int,
            default=4,
            help="the number of subsets for OSEM reconstruction",
        )
        parser.add_argument(
            "--n_iterations",
            type=int,
            default=10,
            help="the number of iterations for OSEM reconstruction",
        )
        parser.add_argument(
            "-v", "--verbosity", type=int, default=0, help="configure the verbosity of STIR"
        )

        args = parser.parse_args()
        # parser.error prints usage and exits; unlike assert it is not stripped by -O
        if args.projection is None and args.recon is None:
            parser.error("You have to specify either a projection or a reconstruction")
        stir.Verbosity_set(args.verbosity)

        mu_map = load_as_interfile(args.mu_map) if args.mu_map else None
        # a given reconstruction is forward-projected with as many slices as the mu map
        mu_map_slices = None if mu_map is None else mu_map[1].shape[0]

        if args.recon:
            recon = load_as_interfile(args.recon)
            projection = forward_project(*recon, n_slices=mu_map_slices)
        else:
            projection = load_as_interfile(args.projection)

        # Pass only the options reconstruct() accepts. Forwarding vars(args) also
        # passed recon/out/verbosity (TypeError) and mutated the namespace itself.
        header, image = reconstruct(
            projection,
            mu_map=mu_map,
            n_subsets=args.n_subsets,
            n_iterations=args.n_iterations,
        )

        # reverse the last image axis before writing
        # NOTE(review): presumably an orientation convention fix — confirm
        image = image[:, :, ::-1]

        write_interfile(args.out, header, image)