Skip to content
Snippets Groups Projects
osem.py 9.78 KiB
Newer Older
  • Learn to ignore specific revisions
  • import os
    from typing import Dict, Optional, Tuple
    import tempfile
    
    import numpy as np
    import stir
    
    from mu_map.file.interfile import (
        Interfile,
        InterfileKeys,
        parse_interfile_header_str,
        load_interfile,
        write_interfile,
        TEMPLATE_HEADER_IMAGE,
    )
    
    from mu_map.recon.filter import GaussianFilter
    
    """
    Template for a STIR OSEM reconstruction configuration.
    """
    
    TEMPLATE_RECON_PARAMS = """
    OSMAPOSLParameters :=
    
      objective function type:= PoissonLogLikelihoodWithLinearModelForMeanAndProjData
      PoissonLogLikelihoodWithLinearModelForMeanAndProjData Parameters:=
    
        input file := {PROJECTION}
    
        projector pair type := Matrix
          Projector Pair Using Matrix Parameters :=
          Matrix type := SPECT UB
          Projection Matrix By Bin SPECT UB Parameters:=
                ; width of PSF
                maximum number of sigmas:= 2.0
    
                ;PSF type of correction { 2D // 3D // Geometrical }
                psf type:= Geometrical
                ; next 2 parameters define the PSF. They are ignored if psf_type is "Geometrical"
                ; These values are mostly dependent on your collimator.
                ; the PSF is modelled as a Gaussian with sigma dependent on the distance from the collimator
                ; sigma_at_depth = collimator_slope * depth_in_cm + collimator sigma 0(cm)
                collimator slope := 0.0163
                collimator sigma 0(cm) := 0.1466
    
                ;Attenuation correction { Simple // Full // No }
                attenuation type := {ATTENUATION_TYPE}
                ;Values in attenuation map in cm-1
                attenuation map := {ATTENUATION_MAP}
    
                ;Mask properties { Cylinder // Attenuation Map // Explicit Mask // No}
                mask type := {MASK_TYPE}
                mask file := {MASK_FILE}
    
               ; if next variable is set to 0, only a single view is kept in memory
               keep all views in cache:=1
    
            End Projection Matrix By Bin SPECT UB Parameters:=
         End Projector Pair Using Matrix Parameters :=
      end PoissonLogLikelihoodWithLinearModelForMeanAndProjData Parameters:=
    
      initial estimate := {INIT_FILE}
      output filename prefix := {OUTPUT_PREFIX}
    
      number of subsets:= {N_SUBSETS}
      number of subiterations:= {N_SUBITERATIONS}
      Save estimates at subiteration intervals:= {SAVE_INTERVALS}
    
    END :=
    """
    
    
    
    def uniform_estimate(projection: Interfile) -> Interfile:
        """
        Build the default initial estimate for an EM reconstruction.

        The returned image is filled with ones (``float32``). Its geometry is
        derived from the projection image: axis 1 gives the number of slices and
        axis 2 gives the edge length of the square in-plane grid. The voxel
        spacing in x, y and z is copied from the projection header entry
        ``InterfileKeys.spacing(1)``. Patient orientation and rotation are taken
        over unchanged from the projection header.

        :param projection: the projection for which the uniform estimate is created
        :return: an interfile image with all ones in a size matching the projection
        """
        proj_header, proj_image = projection
        n_slices = proj_image.shape[1]
        n_bins = proj_image.shape[2]

        ones = np.ones((n_slices, n_bins, n_bins), np.float32)

        spacing = proj_header[InterfileKeys.spacing(1)]
        # negative half of the in-plane extent (presumably to centre the grid
        # on the origin); formatted with 4 decimals for the header
        offset = f"{-0.5 * n_bins * float(spacing):.4f}"

        # applied in insertion order, which matches the order of the template fields
        substitutions = {
            "{ROWS}": str(ones.shape[2]),
            "{COLUMNS}": str(ones.shape[1]),
            "{SLICES}": str(ones.shape[0]),
            "{SPACING_X}": spacing,
            "{SPACING_Y}": spacing,
            "{SPACING_Z}": spacing,
            "{OFFSET_X}": offset,
            "{OFFSET_Y}": offset,
            "{PATIENT_ORIENTATION}": proj_header[InterfileKeys.patient_orientation],
            "{PATIENT_ROTATION}": proj_header[InterfileKeys.patient_rotation],
        }

        header_text = TEMPLATE_HEADER_IMAGE
        for placeholder, value in substitutions.items():
            header_text = header_text.replace(placeholder, value)

        return Interfile(parse_interfile_header_str(header_text), ones)
    
    
    
    def reconstruct(
        projection: Interfile,
        mu_map: Optional[Interfile] = None,
        mask: Optional[Interfile] = None,
        init: Optional[Interfile] = None,
        n_subsets: int = 4,
        n_iterations: int = 10,
        postfilter: Optional[GaussianFilter] = None,
    ) -> Interfile:
        """
        Perform OSEM reconstruction with STIR.

        All inputs are written as Interfile files into a temporary directory,
        together with a parameter file generated from ``TEMPLATE_RECON_PARAMS``.
        STIR runs ``n_subsets * n_iterations`` subiterations, and only the final
        estimate is loaded and returned. The temporary directory is removed
        before this function returns, also when an exception is raised.

        :param projection: the projection to reconstruct
        :param mu_map: the attenuation map used for reconstruction; if not given,
                       no attenuation correction is performed
        :param mask: a mask defining which voxels should be reconstructed
                        - if given, it is used as an explicit mask
                        - if not given but an attenuation map is given, only
                          positive pixels in the attenuation map are reconstructed
                        - otherwise all voxels are reconstructed
        :param init: an initial estimate for the reconstruction which defaults to
                     all ones (see ``uniform_estimate``)
        :param n_subsets: number of subsets (positive integer)
        :param n_iterations: number of full iterations (positive integer)
        :param postfilter: optional filter applied after the reconstruction
        :return: a reconstruction of the projection
        :raises ValueError: if ``n_subsets`` or ``n_iterations`` is None or < 1
        """
        # sanitize parameters for STIR: both counts go into the parameter file
        # and determine the name of the output file, so they must be positive
        for name, value in (("n_subsets", n_subsets), ("n_iterations", n_iterations)):
            if value is None or value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")

        n_subiterations = n_subsets * n_iterations
        # only save the last subiteration, i.e. the final estimate
        save_intervals = n_subiterations

        # the context manager guarantees cleanup (no implicit-cleanup
        # ResourceWarning, no leftover files if STIR fails)
        with tempfile.TemporaryDirectory() as dir_tmp:
            filename_projection = os.path.join(dir_tmp, "projection.hv")
            write_interfile(filename_projection, projection)

            params = TEMPLATE_RECON_PARAMS.replace("{PROJECTION}", filename_projection)

            output_prefix = os.path.join(dir_tmp, "out")
            # the final estimate is expected at "<prefix>_<last subiteration>.hv"
            filename_out = f"{output_prefix}_{save_intervals}.hv"
            params = params.replace("{OUTPUT_PREFIX}", output_prefix)
            params = params.replace("{N_SUBSETS}", str(n_subsets))
            params = params.replace("{N_SUBITERATIONS}", str(n_subiterations))
            params = params.replace("{SAVE_INTERVALS}", str(save_intervals))

            # prepare attenuation parameters
            if mu_map is not None:
                filename_mu_map = os.path.join(dir_tmp, "mu_map.hv")
                write_interfile(filename_mu_map, mu_map)
                params = params.replace("{ATTENUATION_TYPE}", "Full")
                params = params.replace("{ATTENUATION_MAP}", filename_mu_map)
            else:
                params = params.replace("{ATTENUATION_TYPE}", "No")
                # comment out the (unfilled) map entry with ";"
                params = params.replace("attenuation map", ";attenuation map")

            # prepare mask parameters
            if mask is not None:
                filename_mask = os.path.join(dir_tmp, "mask.hv")
                write_interfile(filename_mask, mask)
                params = params.replace("{MASK_TYPE}", "Explicit Mask")
                params = params.replace("{MASK_FILE}", filename_mask)
            else:
                # comment out the (unfilled) mask file entry with ";"
                params = params.replace("mask file", ";mask file")
                params = params.replace(
                    "{MASK_TYPE}", "Attenuation Map" if mu_map is not None else "No"
                )

            init = uniform_estimate(projection) if init is None else init
            filename_init = os.path.join(dir_tmp, "init.hv")
            write_interfile(filename_init, init)
            params = params.replace("{INIT_FILE}", filename_init)

            if postfilter is not None:
                params = postfilter.insert(params)

            filename_params = os.path.join(dir_tmp, "OSEM_SPECT.par")
            with open(filename_params, mode="w") as f:
                f.write(params.strip())

            recon = stir.OSMAPOSLReconstruction3DFloat(filename_params)
            recon.reconstruct()

            # load before the temporary directory is removed
            return load_interfile(filename_out)
    
    
    if __name__ == "__main__":
        import argparse

        from mu_map.file.util import load_as_interfile
        from mu_map.recon.project import forward_project

        parser = argparse.ArgumentParser(
            description="Reconstruct a projection or another reconstruction with attenuation correction",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "--projection", type=str, help="a projection in INTERFILE format"
        )
        parser.add_argument(
            "--recon",
            type=str,
            help="a reconstruction in DICOM or Interfile format - if specified the projection will be overwritten",
        )
        parser.add_argument(
            "--mu_map",
            type=str,
            help="a mu map for attenuation correction in DICOM or Interfile format",
        )
        parser.add_argument(
            "--out",
            type=str,
            required=True,
            help="the filename to store the reconstruction",
        )
        parser.add_argument(
            "--n_subsets",
            type=int,
            default=3,
            help="the number of subsets for OSEM reconstruction",
        )
        parser.add_argument(
            "--n_iterations",
            type=int,
            default=10,
            help="the number of iterations for OSEM reconstruction",
        )
        parser.add_argument(
            "--postfilter",
            choices=["gaussian", "none"],
            default="gaussian",
            help="apply a postfilter to the reconstruction",
        )
        parser.add_argument(
            "--postfilter_width",
            type=float,
            default=1.0,
            help="the filter width for the postfilter is based on the spacing in the projection and can be modified with this factor",
        )
        parser.add_argument(
            "-v", "--verbosity", type=int, default=1, help="configure the verbosity of STIR"
        )

        args = parser.parse_args()
        # argparse cannot express "at least one of these options"; use
        # parser.error (usage + exit code 2) instead of an assert, which would
        # be stripped under `python -O`
        if not args.projection and not args.recon:
            parser.error("You have to specify either a projection or a reconstruction")
        stir.Verbosity_set(args.verbosity)

        mu_map = load_as_interfile(args.mu_map) if args.mu_map else None
        # assumes axis 0 of the mu map holds the slices; the forward projection
        # below is matched to this slice count
        mu_map_slices = None if mu_map is None else mu_map.image.shape[0]

        if args.recon:
            # a given reconstruction takes precedence over --projection: it is
            # forward projected to obtain the projection that is reconstructed
            recon = load_as_interfile(args.recon)
            projection = forward_project(recon, n_slices=mu_map_slices)
        else:
            projection = load_as_interfile(args.projection)

        postfilter = None
        if args.postfilter == "gaussian":
            postfilter = GaussianFilter(projection, args.postfilter_width)

        recon_header, recon_image = reconstruct(
            projection,
            mu_map=mu_map,
            n_subsets=args.n_subsets,
            n_iterations=args.n_iterations,
            postfilter=postfilter,
        )

        # STIR creates reconstructions flipped on the x-axis
        # (https://sourceforge.net/p/stir/mailman/message/36938120/); this is
        # reverted by reversing the last image axis (assumed to be x)
        recon_image = recon_image[:, :, ::-1]

        recon = Interfile(recon_header, recon_image)
        write_interfile(args.out, recon)