from dataclasses import dataclass
import os
from typing import Dict, Tuple

import numpy as np

# A loaded INTERFILE: the parsed header (key -> value strings) and the image array.
Interfile = Tuple[Dict[str, str], np.ndarray]

@dataclass
class _InterfileKeys:
    placeholder: str = "{_}"

    _dim: str = f"matrix size [{placeholder}]"
    _spacing: str = f"scaling factor (mm/pixel) [{placeholder}]"

    data_file: str = "name of data file"
    n_projections: str = "number of projections"
    bytes_per_pixel: str = "number of bytes per pixel"
    number_format: str = "number format"

    def dim(self, index: int) -> str:
        return self._dim.replace(self.placeholder, str(index))

    def spacing(self, index: int) -> str:
        return self._spacing.replace(self.placeholder, str(index))

InterfileKeys = _InterfileKeys()

"""
Several keys defined in INTERFILE headers.
"""
# Header keys as they appear in the dicts returned by the header parsers
# (i.e. without the optional leading "!").
KEY_DIM_1 = "matrix size [1]"
KEY_DIM_2 = "matrix size [2]"
KEY_DIM_3 = "matrix size [3]"
KEY_SPACING_1 = "scaling factor (mm/pixel) [1]"
KEY_SPACING_2 = "scaling factor (mm/pixel) [2]"
# Spacing of axis 3 (labelled "z" in TEMPLATE_HEADER_IMAGE); must not alias KEY_SPACING_2.
KEY_SPACING_3 = "scaling factor (mm/pixel) [3]"
KEY_NPROJECTIONS = "number of projections"

KEY_DATA_FILE = "name of data file"

KEY_BYTES_PER_PIXEL = "number of bytes per pixel"
KEY_NUMBER_FORMAT = "number format"


"""
A template of an INTERFILE header.
"""
# Header text for a single-frame 3D reconstructed SPECT image. The {NAME}
# fields are placeholders (presumably filled with str.format by callers):
# DATA_FILE, ROWS, COLUMNS, SLICES, SPACING_X, SPACING_Y, SPACING_Z,
# OFFSET_X and OFFSET_Y.
# The template fixes the data to 4-byte "float" pixels, LITTLEENDIAN byte
# order, a z first-pixel offset of 0, and starts with an empty line.
# NOTE(review): ROWS fills "matrix size [1]" (axis x) and COLUMNS fills
# "matrix size [2]" (axis y); confirm callers pass the values in that order.
TEMPLATE_HEADER_IMAGE = """
INTERFILE  :=
imaging modality := nucmed
version of keys := STIR4.0
name of data file := {DATA_FILE}
GENERAL DATA :=
GENERAL IMAGE DATA :=
type of data := Tomographic
imagedata byte order := LITTLEENDIAN
isotope name := ^99m^Technetium
SPECT STUDY (General) :=
process status := Reconstructed
number format := float
number of bytes per pixel := 4
number of dimensions := 3
matrix axis label [1] := x
matrix size [1] := {ROWS}
scaling factor (mm/pixel) [1] := {SPACING_X}
matrix axis label [2] := y
matrix size [2] := {COLUMNS}
scaling factor (mm/pixel) [2] := {SPACING_Y}
matrix axis label [3] := z
matrix size [3] := {SLICES}
scaling factor (mm/pixel) [3] := {SPACING_Z}
first pixel offset (mm) [1] := {OFFSET_X}
first pixel offset (mm) [2] := {OFFSET_Y}
first pixel offset (mm) [3] := 0
number of time frames := 1
END OF INTERFILE :=
"""



# (normalized number format, bytes per pixel) -> numpy scalar type.
# "float" is the spelling written by STIR-style headers; "short float",
# "long float", "signed integer" and "unsigned integer" are the INTERFILE 3.3
# spellings.
_NUMBER_FORMAT_TYPES = {
    ("float", 4): np.float32,
    ("float", 8): np.float64,
    ("short float", 4): np.float32,
    ("long float", 8): np.float64,
    ("signed integer", 1): np.int8,
    ("signed integer", 2): np.int16,
    ("signed integer", 4): np.int32,
    ("signed integer", 8): np.int64,
    ("unsigned integer", 1): np.uint8,
    ("unsigned integer", 2): np.uint16,
    ("unsigned integer", 4): np.uint32,
    ("unsigned integer", 8): np.uint64,
}


def type_by_format(number_format: str, bytes_per_pixel: int) -> type:
    """
    Get the corresponding numpy array type for a number format and bytes per
    pixel as defined in an INTERFILE header.

    Supported combinations are "float" (4 or 8 bytes), "short float" (4 bytes),
    "long float" (8 bytes), and "signed integer" / "unsigned integer" with 1, 2,
    4 or 8 bytes. The format name is matched ignoring case and surrounding
    whitespace.

    :param number_format: the number format as a string, e.g. "float"
    :param bytes_per_pixel: the amount of bytes stored for each pixel
    :return: a corresponding numpy data type
    :raises ValueError: if the combination has no known numpy type
    """
    key = (number_format.strip().lower(), bytes_per_pixel)
    try:
        return _NUMBER_FORMAT_TYPES[key]
    except KeyError:
        raise ValueError(
            f"Unknown mapping from format {number_format} with {bytes_per_pixel} bytes to numpy type"
        ) from None


def parse_interfile_header_str(header: str) -> Dict[str, str]:
    """
    Parse an INTERFILE header into a dict.

    Every line containing ":=" becomes one entry; other lines (e.g. blank
    lines) are ignored. Only the first ":=" separates key from value, so values
    may themselves contain ":=". The optional leading "!" that marks a key is
    removed from the key only (values keep any "!" they contain), and
    whitespace around keys and values is stripped. If a key occurs more than
    once, the last occurrence wins.

    :param header: the text of an INTERFILE header
    :return: a dictionary of value in the header
    """
    parsed: Dict[str, str] = {}
    # splitlines() also handles "\r\n" line endings.
    for line in header.splitlines():
        key, separator, value = line.partition(":=")
        if not separator:
            continue
        parsed[key.strip().lstrip("!").strip()] = value.strip()
    return parsed


def parse_interfile_header(filename: str) -> Dict[str, str]:
    """
    Read an INTERFILE header file from disk and parse it into a dict.

    The file's text is handed unchanged to ``parse_interfile_header_str``,
    which defines the parsing rules.

    :param filename: path of the INTERFILE header file
    :return: mapping from header key to its (string) value
    """
    with open(filename) as header_file:
        text = header_file.read()
    return parse_interfile_header_str(text)


def load_interfile(filename: str) -> Tuple[Dict[str, str], np.ndarray]:
    """
    Load an INTERFILE header and its image as a numpy array.

    The data file named in the header is resolved relative to the header's
    directory, and the pixel type comes from the header's number format and
    bytes per pixel. If the header declares "imagedata byte order" as
    LITTLEENDIAN or BIGENDIAN, the raw data is decoded in that byte order.
    Otherwise the machine's native order is assumed. The returned image is a
    writable array in native byte order with shape (z, y, x). z is taken from
    "matrix size [3]" or, when that key is absent, from "number of
    projections".

    :param filename: the filename of the INTERFILE header file
    :return: the header as a dict and the image as a numpy array
    :raises KeyError: if a required header key is missing
    :raises ValueError: if the data file size does not match the header
    """
    header = parse_interfile_header(filename)

    dim_x = int(header[KEY_DIM_1])
    dim_y = int(header[KEY_DIM_2])
    dim_z = int(header[KEY_DIM_3]) if KEY_DIM_3 in header else int(header[KEY_NPROJECTIONS])

    bytes_per_pixel = int(header[KEY_BYTES_PER_PIXEL])
    num_format = header[KEY_NUMBER_FORMAT]
    dtype = np.dtype(type_by_format(num_format, bytes_per_pixel))

    # Honour the declared byte order so that big-endian files are not silently
    # misread on little-endian machines (and vice versa).
    byte_order = header.get("imagedata byte order", "").strip().upper()
    if byte_order == "BIGENDIAN":
        file_dtype = dtype.newbyteorder(">")
    elif byte_order == "LITTLEENDIAN":
        file_dtype = dtype.newbyteorder("<")
    else:
        file_dtype = dtype

    data_file = os.path.join(os.path.dirname(filename), header[KEY_DATA_FILE])
    with open(data_file, mode="rb") as f:
        raw = f.read()

    image = np.frombuffer(raw, file_dtype)
    expected = dim_z * dim_y * dim_x
    if image.size != expected:
        raise ValueError(
            f"Data file {data_file} holds {image.size} values of type {dtype}, "
            f"but the header describes {dim_z}x{dim_y}x{dim_x} = {expected}"
        )
    # astype always copies, which turns the read-only frombuffer view into a
    # writable array and converts it to native byte order.
    image = image.reshape((dim_z, dim_y, dim_x)).astype(dtype)
    return header, image


def load_interfile_img(filename: str) -> np.ndarray:
    """
    Load only the image of an INTERFILE, discarding its header.

    :param filename: the filename of the INTERFILE header file
    :return: the image array, exactly as returned by ``load_interfile``
    """
    return load_interfile(filename)[1]


def write_interfile(filename: str, header: Dict[str, str], image: np.ndarray):
    """
    Write an image and its header as an INTERFILE pair ``<name>.hv`` / ``<name>.v``.

    Any extension on ``filename`` is dropped before the ``.hv`` (header) and
    ``.v`` (data) suffixes are appended. The header is written as
    ``key := value`` lines in the dict's order. The data file name, matrix
    sizes, number format, bytes per pixel and byte order entries are set so
    that they describe the data actually written: little-endian 4-byte floats
    in C order. An existing "END OF INTERFILE" entry is kept as the last line.
    The caller's ``header`` dict is not modified.

    :param filename: output path; its extension (if any) is replaced
    :param header: header entries, e.g. as returned by ``load_interfile``
    :param image: a 3D array indexed as (z, y, x)
    """
    base = os.path.splitext(filename)[0]
    filename_data = f"{base}.v"
    filename_header = f"{base}.hv"

    end_key = "END OF INTERFILE"
    # Work on a copy so the caller's dict is untouched, and hold back the end
    # marker so that keys added below cannot end up after it.
    out = {key: value for key, value in header.items() if key != end_key}
    out[KEY_DATA_FILE] = os.path.basename(filename_data)
    out[KEY_DIM_3] = str(image.shape[0])
    out[KEY_DIM_2] = str(image.shape[1])
    out[KEY_DIM_1] = str(image.shape[2])
    # Describe the data as it is written below rather than trusting the input header.
    out[KEY_NUMBER_FORMAT] = "float"
    out[KEY_BYTES_PER_PIXEL] = "4"
    out["imagedata byte order"] = "LITTLEENDIAN"
    if end_key in header:
        out[end_key] = header[end_key]

    data = image.astype("<f4")

    header_str = "\n".join(f"{key} := {value}" for key, value in out.items()) + "\n"
    with open(filename_header, mode="w") as f:
        f.write(header_str)
    with open(filename_data, mode="wb") as f:
        f.write(data.tobytes())


if __name__ == "__main__":
    import argparse
    import math

    parser = argparse.ArgumentParser(description="Modify an interfile image")
    parser.add_argument("--interfile", type=str, required=True, help="the interfile input")
    parser.add_argument("--out", type=str, help="the interfile output - if not set the input file will be overwritten")
    parser.add_argument("--cmd", choices=["fill", "crop", "pad", "flip"], help="the modification to perform")
    parser.add_argument(
        "--param",
        type=int,
        help="the parameters for the modification [fill: value, crop & pad: target size of z dimension]; "
        "required for fill, crop and pad",
    )
    args = parser.parse_args()
    args.out = args.interfile if args.out is None else args.out

    # Validate with parser.error rather than assert: asserts are removed under `python -O`.
    if args.cmd in ("fill", "crop", "pad") and args.param is None:
        parser.error(f"--param is required for --cmd {args.cmd}")

    header, image = load_interfile(args.interfile)

    if args.cmd == "fill":
        image.fill(args.param)
    elif args.cmd == "pad":
        target_size = args.param
        real_size = image.shape[0]
        if target_size < real_size:
            parser.error(f"Cannot pad from a larger size {real_size} to a smaller size {target_size}")

        # Split the padding between both ends of z; an odd remainder goes to the lower end.
        diff = target_size - real_size
        pad_lower = math.ceil(diff / 2)
        pad_upper = math.floor(diff / 2)
        image = np.pad(image, ((pad_lower, pad_upper), (0, 0), (0, 0)))
    elif args.cmd == "crop":
        target_size = args.param
        real_size = image.shape[0]
        if target_size > real_size:
            parser.error(f"Cannot crop from a smaller size {real_size} to a larger size {target_size}")
        if target_size < 1:
            parser.error(f"Cannot crop to a non-positive size {target_size}")

        # Keep target_size slices around the middle slice of z.
        half = target_size / 2
        center = real_size // 2
        crop_lower = center - math.floor(half)
        crop_upper = center + math.ceil(half)
        image = image[crop_lower:crop_upper]
    elif args.cmd == "flip":
        # Mirror along the last axis (x, since images are loaded as (z, y, x)).
        image = image[:, :, ::-1]

    write_interfile(args.out, header, image)