import os
from typing import Dict, Tuple

import numpy as np


# Keys used in INTERFILE headers, exactly as they appear (after stripping
# whitespace) on the left-hand side of ":=" in a header line.
KEY_DIM_1 = "!matrix size [1]"
KEY_DIM_2 = "!matrix size [2]"
KEY_DIM_3 = "!matrix size [3]"
KEY_SPACING_1 = "scaling factor (mm/pixel) [1]"
KEY_SPACING_2 = "scaling factor (mm/pixel) [2]"
KEY_SPACING_3 = "scaling factor (mm/pixel) [3]"
# Fallback for the third dimension when "!matrix size [3]" is absent,
# e.g. in projection-data headers.
KEY_NPROJECTIONS = "!number of projections"

KEY_DATA_FILE = "name of data file"

KEY_BYTES_PER_PIXEL = "!number of bytes per pixel"
KEY_NUMBER_FORMAT = "!number format"


"""
A template of an INTERFILE header.
"""
# Header template for a 3D image. The pixel format is hard-coded below as
# float, 4 bytes per pixel, LITTLEENDIAN.
# The {PLACEHOLDERS} are presumably filled via str.format(...) by callers
# outside this view (TODO confirm). Mapping to header keys:
#   {DATA_FILE}                             -> "name of data file"
#   {ROWS} / {COLUMNS} / {SLICES}           -> "!matrix size" [1] / [2] / [3]
#                                              (axes labelled x / y / z)
#   {SPACING_X} / {SPACING_Y} / {SPACING_Z} -> "scaling factor (mm/pixel)" [1] / [2] / [3]
#   {OFFSET_X} / {OFFSET_Y}                 -> "first pixel offset (mm)" [1] / [2];
#                                              offset [3] is fixed to 0
# NOTE(review): {ROWS} fills the axis labelled x and {COLUMNS} the axis labelled
# y. Confirm that this naming matches what callers actually pass in.
HEADER_TEMPLATE = """
!INTERFILE  :=
!imaging modality := nucmed
!version of keys := STIR4.0
name of data file := {DATA_FILE}
!GENERAL DATA :=
!GENERAL IMAGE DATA :=
!type of data := Tomographic
imagedata byte order := LITTLEENDIAN
isotope name := ^99m^Technetium
!SPECT STUDY (General) :=
process status := Reconstructed
!number format := float
!number of bytes per pixel := 4
number of dimensions := 3
matrix axis label [1] := x
!matrix size [1] := {ROWS}
scaling factor (mm/pixel) [1] := {SPACING_X}
matrix axis label [2] := y
!matrix size [2] := {COLUMNS}
scaling factor (mm/pixel) [2] := {SPACING_Y}
matrix axis label [3] := z
!matrix size [3] := {SLICES}
scaling factor (mm/pixel) [3] := {SPACING_Z}
first pixel offset (mm) [1] := {OFFSET_X}
first pixel offset (mm) [2] := {OFFSET_Y}
first pixel offset (mm) [3] := 0
number of time frames := 1
!END OF INTERFILE :=
"""

# Header template for SPECT projection data. The pixel format is hard-coded
# below as float, 4 bytes per pixel, LITTLEENDIAN.
# Lines starting with ';' are presumably INTERFILE comments, i.e. disabled keys
# (TODO confirm). parse_interfile_header in this file still keeps them as keys
# that begin with ';'.
# The {PLACEHOLDERS} are presumably filled via str.format(...) (TODO confirm):
#   {ROWS}    -> "!matrix size [2]" (labelled axial coordinate)
#   {COLUMNS} -> "!matrix size [1]" (labelled bin coordinate)
#   {SPACING_X} -> scaling factor [2], {SPACING_Y} -> scaling factor [1]
#   {N_PROJECTIONS}, {ROTATION}, {START_ANGLE}, {RADIUS} -> the acquisition keys below
# NOTE(review): the scaling-factor keys here carry a leading '!', unlike
# KEY_SPACING_1/KEY_SPACING_2, so lookups with those constants will not find
# them in a header written from this template.
HEADER_TEMPLATE_PROJ = """
!INTERFILE  :=
!imaging modality := nucmed
!version of keys := 3.3
name of data file := {DATA_FILE}
;data offset in bytes := 0

!GENERAL IMAGE DATA :=
!type of data := Tomographic
imagedata byte order := LITTLEENDIAN
!number format := float
!number of bytes per pixel := 4

!SPECT STUDY (General) := 
;number of dimensions := 2
;matrix axis label [2] := axial coordinate
!matrix size [2] := {ROWS}
!scaling factor (mm/pixel) [2] := {SPACING_X}
;matrix axis label [1] := bin coordinate
!matrix size [1] := {COLUMNS}
!scaling factor (mm/pixel) [1] := {SPACING_Y}
!number of projections := {N_PROJECTIONS}
!extent of rotation := {ROTATION}
!process status := acquired

!SPECT STUDY (acquired data) :=
!direction of rotation := CW
start angle := {START_ANGLE}
orbit := circular
radius := {RADIUS}

!END OF INTERFILE :=
"""


# (INTERFILE number format, bytes per pixel) -> numpy scalar type.
_NUMPY_TYPE_BY_FORMAT = {
    ("float", 4): np.float32,
    ("short float", 4): np.float32,
    ("float", 8): np.float64,
    ("long float", 8): np.float64,
    ("signed integer", 1): np.int8,
    ("signed integer", 2): np.int16,
    ("signed integer", 4): np.int32,
    ("signed integer", 8): np.int64,
    ("unsigned integer", 1): np.uint8,
    ("unsigned integer", 2): np.uint16,
    ("unsigned integer", 4): np.uint32,
    ("unsigned integer", 8): np.uint64,
}


def type_by_format(number_format: str, bytes_per_pixel: int) -> type:
    """
    Get the corresponding numpy array type for a number format and bytes per
    pixel as defined in an INTERFILE header.

    Supported combinations:

    * "float" / "short float" with 4 bytes -> ``np.float32``
    * "float" / "long float" with 8 bytes -> ``np.float64``
    * "signed integer" with 1, 2, 4 or 8 bytes -> ``np.int8`` .. ``np.int64``
    * "unsigned integer" with 1, 2, 4 or 8 bytes -> ``np.uint8`` .. ``np.uint64``

    The number format is matched case-insensitively and ignoring surrounding
    whitespace.

    :param number_format: the number format as a string, e.g. "float"
    :param bytes_per_pixel: the amount of bytes stored for each pixel
    :return: a corresponding numpy data type
    :raises ValueError: if the combination has no known numpy equivalent
    """
    key = (str(number_format).strip().lower(), bytes_per_pixel)
    try:
        return _NUMPY_TYPE_BY_FORMAT[key]
    except KeyError:
        raise ValueError(
            f"Unknown mapping from format {number_format} with {bytes_per_pixel} bytes to numpy type"
        ) from None


def parse_interfile_header(filename: str) -> Dict[str, str]:
    """
    Read an INTERFILE header file into a dictionary.

    Only lines containing exactly one ":=" separator are kept. The text on
    either side is stripped and becomes the key and the value. When a key
    occurs several times, the last occurrence wins.

    :param filename: path of the INTERFILE header
    :return: mapping from header key to (string) value
    """
    parsed: Dict[str, str] = {}
    with open(filename, mode="r") as header_file:
        for raw_line in header_file:
            fields = raw_line.split(":=")
            if len(fields) != 2:
                continue
            key, value = fields
            parsed[key.strip()] = value.strip()
    return parsed


def load_interfile(filename: str) -> Tuple[Dict[str, str], np.ndarray]:
    """
    Load an INTERFILE header and its image as a numpy array.

    How the header is interpreted:

    * The third dimension comes from "!matrix size [3]". When that key is
      absent (e.g. projection data), it comes from "!number of projections".
    * The data file named in the header is resolved relative to the
      directory of the header file.
    * If the header declares "imagedata byte order" (LITTLEENDIAN or
      BIGENDIAN), the raw data is read in that byte order and converted to
      native order. Otherwise, native byte order is assumed.

    :param filename: the filename of the INTERFILE header file
    :return: the header as a dict and the image as a writable numpy array of
        shape (dim_3, dim_2, dim_1)
    :raises KeyError: if a required header key is missing
    :raises ValueError: if the number format is unsupported or the data file
        size does not match the header dimensions
    """
    header = parse_interfile_header(filename)

    dim_x = int(header[KEY_DIM_1])
    dim_y = int(header[KEY_DIM_2])
    dim_z = int(header[KEY_DIM_3]) if KEY_DIM_3 in header else int(header[KEY_NPROJECTIONS])

    bytes_per_pixel = int(header[KEY_BYTES_PER_PIXEL])
    num_format = header[KEY_NUMBER_FORMAT]
    native_dtype = np.dtype(type_by_format(num_format, bytes_per_pixel))

    # Interpret the file in its declared byte order so it is read correctly
    # regardless of host endianness; without the key, keep native order.
    byte_order = header.get("imagedata byte order", "").upper()
    order_char = {"LITTLEENDIAN": "<", "BIGENDIAN": ">"}.get(byte_order)
    file_dtype = native_dtype if order_char is None else native_dtype.newbyteorder(order_char)

    data_file = os.path.join(os.path.dirname(filename), header[KEY_DATA_FILE])
    # np.fromfile returns a writable array directly (no intermediate bytes
    # buffer and no extra copy).
    image = np.fromfile(data_file, dtype=file_dtype)

    expected = dim_x * dim_y * dim_z
    if image.size != expected:
        raise ValueError(
            f"Data file {data_file} holds {image.size} values of type {file_dtype}, "
            f"but the header dimensions {dim_z}x{dim_y}x{dim_x} require {expected}"
        )

    # Convert to native byte order; this is a no-op when the orders already match.
    image = image.reshape((dim_z, dim_y, dim_x)).astype(native_dtype, copy=False)
    return header, image


def load_interfile_img(filename: str) -> np.ndarray:
    """
    Read only the pixel data of an INTERFILE image, discarding its header.

    :param filename: path of the INTERFILE header file
    :return: the image array produced by ``load_interfile``
    """
    return load_interfile(filename)[1]


def write_interfile(filename: str, header: Dict[str, str], image: np.ndarray):
    """
    Write an image and its INTERFILE header to disk.

    Any extension of ``filename`` is replaced. The header is written to
    ``<name>.hv`` and the raw pixel data to ``<name>.v``. Header entries are
    written as ``key := value`` lines in dict order, with these updates:

    * The data file name and the matrix sizes are set to match ``image``.
    * The third dimension goes to "!number of projections" when the header
      uses that key, and to "!matrix size [3]" when that key is present or
      neither key is present.
    * An existing "!END OF INTERFILE" entry is kept as the last line.

    Pixel data is stored in the header's declared "imagedata byte order",
    if any; otherwise in native order. The caller's ``header`` dict is not
    modified.

    NOTE: "!number format" and "!number of bytes per pixel" are not updated;
    ``image.dtype`` must already match them.

    :param filename: output path; its extension is replaced by ".hv" / ".v"
    :param header: header entries, e.g. as returned by ``load_interfile``
    :param image: 3D image of shape (dim_3, dim_2, dim_1)
    :raises ValueError: if ``image`` is not 3-dimensional
    """
    if image.ndim != 3:
        raise ValueError(f"Expected a 3D image, got an array with shape {image.shape}")

    base = os.path.splitext(filename)[0]
    filename_data = f"{base}.v"
    filename_header = f"{base}.hv"

    # Work on a copy so the caller's header is left untouched.
    out = dict(header)
    # Take the terminator out so keys added below are not placed after it.
    end_key = "!END OF INTERFILE"
    end_value = out.pop(end_key, None)

    out[KEY_DATA_FILE] = os.path.basename(filename_data)
    dim_z, dim_y, dim_x = image.shape
    out[KEY_DIM_1] = str(dim_x)
    out[KEY_DIM_2] = str(dim_y)
    if KEY_NPROJECTIONS in out:
        out[KEY_NPROJECTIONS] = str(dim_z)
    if KEY_DIM_3 in out or KEY_NPROJECTIONS not in out:
        out[KEY_DIM_3] = str(dim_z)

    if end_value is not None:
        out[end_key] = end_value

    # Honour the declared byte order; astype(copy=False) is a no-op when it
    # already matches the array's dtype.
    byte_order = out.get("imagedata byte order", "").upper()
    order_char = {"LITTLEENDIAN": "<", "BIGENDIAN": ">"}.get(byte_order)
    data = image if order_char is None else image.astype(image.dtype.newbyteorder(order_char), copy=False)

    header_str = "\n".join(f"{key} := {value}" for key, value in out.items())
    with open(filename_header, mode="w") as f:
        f.write(header_str + "\n")
    with open(filename_data, mode="wb") as f:
        f.write(data.tobytes())


if __name__ == "__main__":
    # Command-line tool: load an INTERFILE image, apply one modification along
    # the first array axis (dim_3, e.g. slices), and write the result.
    import argparse

    parser = argparse.ArgumentParser(description="Modify an interfile image")
    parser.add_argument("--interfile", type=str, required=True, help="the interfile input")
    parser.add_argument("--out", type=str, help="the interfile output - if not set the input file will be overwritten")
    parser.add_argument("--cmd", choices=["fill", "crop", "pad"], help="the modification to perform")
    parser.add_argument("--param", type=int, required=True, help="the parameters for the modification [fill: value, crop & pad: target size of z dimension]")
    args = parser.parse_args()
    args.out = args.interfile if args.out is None else args.out

    header, image = load_interfile(args.interfile)

    # parser.error (exit status 2) is used instead of assert, because asserts
    # are stripped when Python runs with -O.
    if args.cmd == "fill":
        image.fill(args.param)
    elif args.cmd == "pad":
        target_size = args.param
        real_size = image.shape[0]
        if target_size <= real_size:
            parser.error(f"Cannot pad from a larger size {real_size} to a smaller size {target_size}")

        diff = target_size - real_size
        pad_upper = diff // 2
        # For an odd difference, the extra slice goes to the lower end.
        pad_lower = diff - pad_upper
        image = np.pad(image, ((pad_lower, pad_upper), (0, 0), (0, 0)))
    elif args.cmd == "crop":
        target_size = args.param
        real_size = image.shape[0]
        if target_size >= real_size:
            parser.error(f"Cannot crop from a smaller size {real_size} to a larger size {target_size}")
        if target_size < 1:
            parser.error(f"Cannot crop to a non-positive size {target_size}")

        # Keep target_size slices around the centre. For an odd target, the
        # extra slice lies above the centre.
        center = real_size // 2
        crop_lower = center - target_size // 2
        image = image[crop_lower:crop_lower + target_size]

    write_interfile(args.out, header, image)