import os
from typing import Dict, Optional, Tuple
import tempfile

import numpy as np
import stir

from mu_map.file import load_as_interfile
from mu_map.file.interfile import (
    parse_interfile_header_str,
    load_interfile,
    write_interfile,
    KEY_DIM_1,
    KEY_DIM_2,
    KEY_SPACING_1,
    KEY_SPACING_2,
    TEMPLATE_HEADER_IMAGE,
)


# STIR OSMAPOSL (OSEM) parameter-file template for SPECT reconstruction using the
# "SPECT UB" projection matrix.
#
# Placeholders in braces -- {PROJECTION}, {ATTENUATION_TYPE}, {ATTENUATION_MAP},
# {MASK_TYPE}, {MASK_FILE}, {INIT_FILE}, {OUTPUT_PREFIX}, {N_SUBSETS},
# {N_SUBITERATIONS}, {SAVE_INTERVALS} -- are filled in by `reconstruct` with
# str.replace. str.format cannot be used here because several ";"-comment lines
# of the template contain literal braces (e.g. "{ 2D // 3D // Geometrical }").
#
# Lines starting with ";" are comments in the STIR parameter syntax; `reconstruct`
# relies on this to disable the "attenuation map" / "mask file" keywords when no
# such input is given. The PSF type is "Geometrical", for which the template's own
# comment states that the collimator slope/sigma values are ignored.
TEMPLATE_RECON_PARAMS = """
OSMAPOSLParameters :=

  objective function type:= PoissonLogLikelihoodWithLinearModelForMeanAndProjData
  PoissonLogLikelihoodWithLinearModelForMeanAndProjData Parameters:=

    input file := {PROJECTION}

    projector pair type := Matrix
      Projector Pair Using Matrix Parameters :=
      Matrix type := SPECT UB
      Projection Matrix By Bin SPECT UB Parameters:=
            ; width of PSF
            maximum number of sigmas:= 2.0

            ;PSF type of correction { 2D // 3D // Geometrical }
            psf type:= Geometrical
            ; next 2 parameters define the PSF. They are ignored if psf_type is "Geometrical"
            ; These values are mostly dependent on your collimator.
            ; the PSF is modelled as a Gaussian with sigma dependent on the distance from the collimator
            ; sigma_at_depth = collimator_slope * depth_in_cm + collimator sigma 0(cm)
            collimator slope := 0.0163
            collimator sigma 0(cm) := 0.1466

            ;Attenuation correction { Simple // Full // No }
            attenuation type := {ATTENUATION_TYPE}
            ;Values in attenuation map in cm-1
            attenuation map := {ATTENUATION_MAP}

            ;Mask properties { Cylinder // Attenuation Map // Explicit Mask // No}
            mask type := {MASK_TYPE}
            mask file := {MASK_FILE}

           ; if next variable is set to 0, only a single view is kept in memory
           keep all views in cache:=1

        End Projection Matrix By Bin SPECT UB Parameters:=
     End Projector Pair Using Matrix Parameters :=
  end PoissonLogLikelihoodWithLinearModelForMeanAndProjData Parameters:=

  initial estimate := {INIT_FILE}
  output filename prefix := {OUTPUT_PREFIX}

  number of subsets:= {N_SUBSETS}
  number of subiterations:= {N_SUBITERATIONS}
  Save estimates at subiteration intervals:= {SAVE_INTERVALS}

  ; keywords that specify the filtering that occurs after every subiteration
  ; warning: do not normally use together with a prior
  ;inter-iteration filter subiteration interval := 4
  ;inter-iteration filter type := Separable Gaussian
  ;post-filter type := Separable Gaussian
  ;separable gaussian filter parameters :=
    ;x-dir filter fwhm (in mm) := 6
    ;y-dir filter fwhm (in mm) := 6
    ;z-dir filter fwhm (in mm) := 6
    ;x-dir maximum kernel size := 129
    ;y-dir maximum kernel size := 129
    ;z-dir maximum kernel size := 31
    ;Normalise filter to 1 := 1
  ;end separable gaussian filter parameters :=

END :=
"""


def uniform_estimate(projection: Tuple[Dict[str, str], np.ndarray]):
    """
    Build a uniform (all ones) initial image estimate that fits a projection.

    The image gets one slice per entry along axis 1 of the projection array and a
    square in-plane matrix whose edge length equals the projection's last axis.
    In-plane voxel spacing is copied from the projection header's
    ``KEY_SPACING_1`` entry and slice spacing from ``KEY_SPACING_2``. Both in-plane
    offsets are set to minus half the in-plane extent, which centers the image.

    NOTE(review): presumably the projection array is laid out as
    (views, slices, bins) -- confirm against the INTERFILE loader.

    :param projection: tuple of (INTERFILE header dict, projection array)
    :return: tuple of (parsed INTERFILE image header, float32 image filled with ones)
    """
    proj_header, proj_data = projection
    n_slices = proj_data.shape[1]
    n_bins = proj_data.shape[2]

    image = np.ones((n_slices, n_bins, n_bins), dtype=np.float32)

    spacing_xy = proj_header[KEY_SPACING_1]
    # shift the origin by half the in-plane extent so the volume is centered
    offset_str = f"{-0.5 * n_bins * float(spacing_xy):.4f}"
    spacing_z = proj_header[KEY_SPACING_2]

    substitutions = (
        ("{ROWS}", str(image.shape[2])),
        ("{COLUMNS}", str(image.shape[1])),
        ("{SLICES}", str(image.shape[0])),
        ("{SPACING_X}", spacing_xy),
        ("{SPACING_Y}", spacing_xy),
        ("{SPACING_Z}", spacing_z),
        ("{OFFSET_X}", offset_str),
        ("{OFFSET_Y}", offset_str),
    )
    header_text = TEMPLATE_HEADER_IMAGE
    for placeholder, value in substitutions:
        header_text = header_text.replace(placeholder, value)

    return parse_interfile_header_str(header_text), image


def reconstruct(
    projection: Tuple[Dict[str, str], np.ndarray],
    mu_map: Optional[Tuple[Dict[str, str], np.ndarray]] = None,
    mask: Optional[Tuple[Dict[str, str], np.ndarray]] = None,
    init: Optional[Tuple[Dict[str, str], np.ndarray]] = None,
    n_subsets: int = 4,
    n_iterations: int = 10,
    **kwargs,
):
    """
    Reconstruct a SPECT projection with STIR's OSMAPOSL (OSEM) algorithm.

    All inputs are written as INTERFILE files into a temporary directory. A
    parameter file is generated from ``TEMPLATE_RECON_PARAMS`` and STIR is run on
    it. The temporary directory is removed before this function returns, also
    when an error occurs.

    :param projection: (header, image) of the projection to reconstruct
    :param mu_map: optional (header, image) attenuation map; if given, "Full"
        attenuation correction is configured, otherwise it is disabled ("No")
    :param mask: optional (header, image) explicit mask; if omitted, the mask type
        is "Attenuation Map" when a mu map is given and "No" otherwise
    :param init: optional (header, image) initial estimate; defaults to
        ``uniform_estimate(projection)``
    :param n_subsets: number of OSEM subsets, must be >= 1
    :param n_iterations: number of full OSEM iterations, must be >= 1
    :param kwargs: otherwise ignored (so that CLI argument dicts can be passed
        through); a truthy ``verbosity`` entry prints the generated parameter file
    :return: (header, image) of the estimate after the final subiteration, as
        returned by ``load_interfile``
    :raises ValueError: if ``n_subsets`` or ``n_iterations`` is smaller than 1
    """
    if n_subsets < 1:
        raise ValueError(f"n_subsets must be at least 1, got {n_subsets}")
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be at least 1, got {n_iterations}")

    n_subiterations = n_subsets * n_iterations
    # only the final estimate is needed, so save exactly once: after the last subiteration
    save_intervals = n_subiterations

    # context manager guarantees clean-up even if writing the inputs or STIR fails
    with tempfile.TemporaryDirectory() as dir_tmp:
        filename_projection = os.path.join(dir_tmp, "projection.hv")
        write_interfile(filename_projection, *projection)
        params = TEMPLATE_RECON_PARAMS.replace("{PROJECTION}", filename_projection)

        output_prefix = os.path.join(dir_tmp, "out")
        # file name under which the estimate of the last subiteration is expected
        filename_out = f"{output_prefix}_{save_intervals}.hv"
        params = params.replace("{OUTPUT_PREFIX}", output_prefix)
        params = params.replace("{N_SUBSETS}", str(n_subsets))
        params = params.replace("{N_SUBITERATIONS}", str(n_subiterations))
        params = params.replace("{SAVE_INTERVALS}", str(save_intervals))

        if mu_map is not None:
            filename_mu_map = os.path.join(dir_tmp, "mu_map.hv")
            write_interfile(filename_mu_map, *mu_map)
            params = params.replace("{ATTENUATION_TYPE}", "Full")
            params = params.replace("{ATTENUATION_MAP}", filename_mu_map)
        else:
            params = params.replace("{ATTENUATION_TYPE}", "No")
            # ";" starts a comment in the parameter file: disable the map keyword
            params = params.replace("attenuation map", ";attenuation map")

        if mask is not None:
            filename_mask = os.path.join(dir_tmp, "mask.hv")
            write_interfile(filename_mask, *mask)
            params = params.replace("{MASK_TYPE}", "Explicit Mask")
            params = params.replace("{MASK_FILE}", filename_mask)
        else:
            # no explicit mask file: comment out the keyword and fall back to the
            # attenuation map as mask if one is available
            params = params.replace("mask file", ";mask file")
            params = params.replace(
                "{MASK_TYPE}", "Attenuation Map" if mu_map is not None else "No"
            )

        init = uniform_estimate(projection) if init is None else init
        filename_init = os.path.join(dir_tmp, "init.hv")
        write_interfile(filename_init, *init)
        params = params.replace("{INIT_FILE}", filename_init)

        filename_params = os.path.join(dir_tmp, "OSEM_SPECT.par")
        with open(filename_params, mode="w") as f:
            f.write(params.strip())

        # only dump the (long) parameter file on request instead of on every call
        if kwargs.get("verbosity"):
            print(params)

        recon = stir.OSMAPOSLReconstruction3DFloat(filename_params)
        recon.reconstruct()

        # load before the temporary directory is removed
        result = load_interfile(filename_out)
    return result


if __name__ == "__main__":
    import argparse

    from mu_map.recon.project import forward_project

    parser = argparse.ArgumentParser(
        description="Reconstruct a projection or another reconstruction",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--projection", type=str, help="a projection in INTERFILE format"
    )
    parser.add_argument(
        "--recon",
        type=str,
        help="a reconstruction in DICOM or INTERFILE format - if specified the projection will be overwritten",
    )
    parser.add_argument(
        "--mu_map",
        type=str,
        help="a mu map for attenuation correction in DICOM or INTERFILE format",
    )
    parser.add_argument(
        "--out", type=str, help="the filename to store the reconstruction"
    )
    parser.add_argument(
        "--n_subsets",
        type=int,
        default=4,
        help="the number of subsets for OSEM reconstruction",
    )
    parser.add_argument(
        "--n_iterations",
        type=int,
        default=10,
        help="the number of iterations for OSEM reconstruction",
    )
    parser.add_argument(
        "-v", "--verbosity", type=int, default=0, help="configure the verbosity of STIR"
    )

    args = parser.parse_args()
    # use parser.error rather than assert: asserts are stripped under `python -O`
    if args.projection is None and args.recon is None:
        parser.error("You have to specify either a projection or a reconstruction")
    # fail before the (slow) reconstruction instead of when writing the result
    if args.out is None:
        parser.error("You have to specify an output filename via --out")
    stir.Verbosity_set(args.verbosity)

    mu_map = load_as_interfile(args.mu_map) if args.mu_map else None
    # the forward projection is given the mu map's slice count (axis 0 of its image)
    mu_map_slices = None if mu_map is None else mu_map[1].shape[0]

    if args.recon:
        recon = load_as_interfile(args.recon)
        projection = forward_project(*recon, n_slices=mu_map_slices)
    else:
        projection = load_as_interfile(args.projection)

    # copy: vars() returns args.__dict__ itself, so deleting keys would strip args
    kwargs = dict(vars(args))
    del kwargs["mu_map"]
    del kwargs["projection"]
    print(kwargs)
    header, image = reconstruct(projection, mu_map=mu_map, **kwargs)
    # NOTE(review): the last image axis is flipped before saving, presumably to
    # match the orientation expected by downstream tools -- confirm
    image = image[:, :, ::-1]

    write_interfile(args.out, header, image)