import os
from typing import Dict, Optional, Tuple
import tempfile

import numpy as np
import stir

from mu_map.file.interfile import (
    Interfile,
    parse_interfile_header_str,
    load_interfile,
    write_interfile,
    InterfileKeys,
    TEMPLATE_HEADER_IMAGE,
)
from mu_map.recon.filter import GaussianFilter


"""
Template for a STIR OSEM reconstruction configuration.
"""
# Parameter file template consumed by `reconstruct`, which fills it in via
# `str.replace`, writes it to disk and passes the file name to
# `stir.OSMAPOSLReconstruction3DFloat`.
#
# Placeholders (written in curly braces) that `reconstruct` substitutes:
#   {PROJECTION}        file name of the projection to reconstruct
#   {ATTENUATION_TYPE}  "Full" if an attenuation map is given, else "No"
#   {ATTENUATION_MAP}   file name of the attenuation map
#   {MASK_TYPE}         "Explicit Mask", "Attenuation Map" or "No"
#   {MASK_FILE}         file name of the explicit mask
#   {INIT_FILE}         file name of the initial estimate
#   {OUTPUT_PREFIX}     prefix of the files written by the reconstruction
#   {N_SUBSETS}         number of subsets
#   {N_SUBITERATIONS}   number of sub-iterations
#   {SAVE_INTERVALS}    sub-iteration interval at which estimates are saved
#
# Lines starting with `;` are comments inside the template. `reconstruct`
# disables the `attenuation map` and `mask file` keys by prefixing them with
# `;` when no attenuation map or no mask is provided.
TEMPLATE_RECON_PARAMS = """
OSMAPOSLParameters :=

  objective function type:= PoissonLogLikelihoodWithLinearModelForMeanAndProjData
  PoissonLogLikelihoodWithLinearModelForMeanAndProjData Parameters:=

    input file := {PROJECTION}

    projector pair type := Matrix
      Projector Pair Using Matrix Parameters :=
      Matrix type := SPECT UB
      Projection Matrix By Bin SPECT UB Parameters:=
            ; width of PSF
            maximum number of sigmas:= 2.0

            ;PSF type of correction { 2D // 3D // Geometrical }
            psf type:= Geometrical
            ; next 2 parameters define the PSF. They are ignored if psf_type is "Geometrical"
            ; These values are mostly dependent on your collimator.
            ; the PSF is modelled as a Gaussian with sigma dependent on the distance from the collimator
            ; sigma_at_depth = collimator_slope * depth_in_cm + collimator sigma 0(cm)
            collimator slope := 0.0163
            collimator sigma 0(cm) := 0.1466

            ;Attenuation correction { Simple // Full // No }
            attenuation type := {ATTENUATION_TYPE}
            ;Values in attenuation map in cm-1
            attenuation map := {ATTENUATION_MAP}

            ;Mask properties { Cylinder // Attenuation Map // Explicit Mask // No}
            mask type := {MASK_TYPE}
            mask file := {MASK_FILE}

           ; if next variable is set to 0, only a single view is kept in memory
           keep all views in cache:=1

        End Projection Matrix By Bin SPECT UB Parameters:=
     End Projector Pair Using Matrix Parameters :=
  end PoissonLogLikelihoodWithLinearModelForMeanAndProjData Parameters:=

  initial estimate := {INIT_FILE}
  output filename prefix := {OUTPUT_PREFIX}

  number of subsets:= {N_SUBSETS}
  number of subiterations:= {N_SUBITERATIONS}
  Save estimates at subiteration intervals:= {SAVE_INTERVALS}

END :=
"""


def uniform_estimate(projection: Interfile) -> Interfile:
    """
    Build an all-ones image that can be used as the starting point of the
    EM reconstruction of a projection.

    The image is a float32 volume whose first axis has the length of the
    projection's second axis and whose remaining two axes both have the
    length of the projection's third axis. The header is derived from
    `TEMPLATE_HEADER_IMAGE` using the spacing, patient orientation and
    patient rotation stored in the projection header.

    :param projection: the projection for which the uniform estimate is created
    :return: an interfile image with all ones in a size matching the projection
    """
    proj_header, proj_image = projection

    n_slices, n_bins = proj_image.shape[1], proj_image.shape[2]
    estimate = np.ones((n_slices, n_bins, n_bins), np.float32)

    # the same header entry is used as the voxel spacing along every axis
    spacing = proj_header[InterfileKeys.spacing(1)]
    # shift by half of the extent (number of bins times spacing)
    offset_str = f"{-0.5 * n_bins * float(spacing):.4f}"

    # placeholder -> value; applied in insertion order
    substitutions = {
        "{ROWS}": str(estimate.shape[2]),
        "{COLUMNS}": str(estimate.shape[1]),
        "{SLICES}": str(estimate.shape[0]),
        "{SPACING_X}": spacing,
        "{SPACING_Y}": spacing,
        "{SPACING_Z}": spacing,
        "{OFFSET_X}": offset_str,
        "{OFFSET_Y}": offset_str,
        "{PATIENT_ORIENTATION}": proj_header[InterfileKeys.patient_orientation],
        "{PATIENT_ROTATION}": proj_header[InterfileKeys.patient_rotation],
    }

    header_str = TEMPLATE_HEADER_IMAGE
    for placeholder, value in substitutions.items():
        header_str = header_str.replace(placeholder, value)

    return Interfile(parse_interfile_header_str(header_str), estimate)


def reconstruct(
    projection: Interfile,
    mu_map: Optional[Interfile] = None,
    mask: Optional[Interfile] = None,
    init: Optional[Interfile] = None,
    n_subsets: Optional[int] = 4,
    n_iterations: Optional[int] = 10,
    postfilter: Optional[GaussianFilter] = None,
):
    """
    Perform OSEM reconstruction with STIR.

    All inputs are written as interfiles into a temporary directory,
    `TEMPLATE_RECON_PARAMS` is filled in and stored as a parameter file in
    the same directory, and STIR is run on that parameter file. The
    temporary directory is removed before the function returns, also if
    the reconstruction fails.

    :param projection: the projection to reconstruct
    :param mu_map: the attenuation map used for reconstruction
    :param mask: a mask defining which voxels should be reconstructed
                    - if not given all voxels are reconstructed
                    - if an attenuation map is given only positive pixels in the
                      attenuation map are reconstructed
    :param init: an initial estimate for the reconstruction which defaults to
                 all ones
    :param n_subsets: number of subsets, None selects the default of 4
    :param n_iterations: number of iterations, None selects the default of 10
    :param postfilter: optional filter applied after the reconstruction
    :returns: a reconstruction of the projection
    """
    # sanitize parameters for STIR: both values are annotated as optional,
    # so an explicit None falls back to the default instead of failing in
    # the multiplication below
    n_subsets = 4 if n_subsets is None else n_subsets
    n_iterations = 10 if n_iterations is None else n_iterations

    # STIR counts sub-iterations (one per subset and iteration); only the
    # final estimate is saved
    n_subiterations = n_subsets * n_iterations
    save_intervals = n_subiterations

    # the context manager guarantees that the temporary files are removed
    # even if writing the inputs or the reconstruction itself raises
    with tempfile.TemporaryDirectory() as dir_tmp:
        filename_projection = os.path.join(dir_tmp, "projection.hv")
        write_interfile(filename_projection, projection)
        params = TEMPLATE_RECON_PARAMS.replace("{PROJECTION}", filename_projection)

        output_prefix = os.path.join(dir_tmp, "out")
        # file name of the estimate saved at the last sub-iteration
        filename_out = f"{output_prefix}_{save_intervals}.hv"
        params = params.replace("{OUTPUT_PREFIX}", output_prefix)
        params = params.replace("{N_SUBSETS}", str(n_subsets))
        params = params.replace("{N_SUBITERATIONS}", str(n_subiterations))
        params = params.replace("{SAVE_INTERVALS}", str(save_intervals))

        # prepare attenuation parameters
        if mu_map is not None:
            filename_mu_map = os.path.join(dir_tmp, "mu_map.hv")
            write_interfile(filename_mu_map, mu_map)
            params = params.replace("{ATTENUATION_TYPE}", "Full")
            params = params.replace("{ATTENUATION_MAP}", filename_mu_map)
        else:
            params = params.replace("{ATTENUATION_TYPE}", "No")
            # comment out the key in the parameter file; this also touches
            # the template's own comment mentioning the attenuation map,
            # which stays a comment
            params = params.replace("attenuation map", ";attenuation map")

        # prepare mask parameters
        if mask is not None:
            filename_mask = os.path.join(dir_tmp, "mask.hv")
            write_interfile(filename_mask, mask)
            params = params.replace("{MASK_TYPE}", "Explicit Mask")
            params = params.replace("{MASK_FILE}", filename_mask)
        else:
            # comment out the key in the parameter file
            params = params.replace("mask file", ";mask file")
            params = params.replace(
                "{MASK_TYPE}", "Attenuation Map" if mu_map is not None else "No"
            )

        init = uniform_estimate(projection) if init is None else init
        filename_init = os.path.join(dir_tmp, "init.hv")
        write_interfile(filename_init, init)
        params = params.replace("{INIT_FILE}", filename_init)

        if postfilter:
            params = postfilter.insert(params)

        filename_params = os.path.join(dir_tmp, "OSEM_SPECT.par")
        with open(filename_params, mode="w") as f:
            f.write(params.strip())

        recon = stir.OSMAPOSLReconstruction3DFloat(filename_params)
        recon.reconstruct()

        # load the result while the temporary directory still exists
        return load_interfile(filename_out)


if __name__ == "__main__":
    # Command line entry point: reconstruct a projection (or the forward
    # projection of an existing reconstruction) and store the result.
    import argparse

    from mu_map.file.util import load_as_interfile
    from mu_map.recon.project import forward_project

    parser = argparse.ArgumentParser(
        description="Reconstruct a projection or another reconstruction with attenuation correction",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--projection", type=str, help="a projection in INTERFILE format"
    )
    parser.add_argument(
        "--recon",
        type=str,
        help="a reconstruction in DICOM or Interfile format - if specified the projection will be overwritten",
    )
    parser.add_argument(
        "--mu_map",
        type=str,
        help="a mu map for attenuation correction in DICOM or Interfile format",
    )
    parser.add_argument(
        "--out",
        type=str,
        required=True,
        help="the filename to store the reconstruction",
    )
    parser.add_argument(
        "--n_subsets",
        type=int,
        default=3,
        help="the number of subsets for OSEM reconstruction",
    )
    parser.add_argument(
        "--n_iterations",
        type=int,
        default=10,
        help="the number of iterations for OSEM reconstruction",
    )
    parser.add_argument(
        "--postfilter",
        choices=["gaussian", "none"],
        default="gaussian",
        help="apply a postfilter to the reconstruction",
    )
    parser.add_argument(
        "--postfilter_width",
        type=float,
        default=1.0,
        help="the filter witdth for the postfilter is based on the spacing in the projection and can be modified with this factor",
    )
    parser.add_argument(
        "-v", "--verbosity", type=int, default=1, help="configure the verbosity of STIR"
    )

    args = parser.parse_args()
    # Report missing input through the parser instead of an assert: asserts
    # are stripped when Python runs with -O, and parser.error prints the
    # usage message and exits with a proper status code.
    if args.projection is None and args.recon is None:
        parser.error("You have to specify either a projection or a reconstruction")
    stir.Verbosity_set(args.verbosity)

    mu_map = load_as_interfile(args.mu_map) if args.mu_map else None
    # number of slices of the attenuation map, passed on to the forward
    # projection when a reconstruction is given as input
    mu_map_slices = None if mu_map is None else mu_map.image.shape[0]

    if args.recon:
        # a given reconstruction takes precedence over --projection
        recon = load_as_interfile(args.recon)
        projection = forward_project(recon, n_slices=mu_map_slices)
    else:
        projection = load_as_interfile(args.projection)

    postfilter = None
    if args.postfilter == "gaussian":
        postfilter = GaussianFilter(projection, args.postfilter_width)

    recon_header, recon_image = reconstruct(
        projection,
        mu_map=mu_map,
        n_subsets=args.n_subsets,
        n_iterations=args.n_iterations,
        postfilter=postfilter,
    )

    # STIR creates reconstructions flipped on the x-axes
    # (https://sourceforge.net/p/stir/mailman/message/36938120/),
    # this is reverted here
    recon_image = recon_image[:, :, ::-1]
    recon = Interfile(recon_header, recon_image)
    write_interfile(args.out, recon)