"""
Utility script to prepare raw data for further processing.
It is highly dependent on the way the raw DICOM data were exported.
The script creates a folder containing the file `meta.csv` and another
folder containing the DICOM images sorted by an id for each study.
"""
import argparse
from datetime import datetime, timedelta
from enum import Enum
import os
from typing import List, Dict, Callable

import numpy as np
import pandas as pd
import pydicom

from mu_map.file.dicom import DICOM, DICOMTime, parse_age
from mu_map.logging import add_logging_args, get_logger_by_args


# Substring that the series descriptions of the reconstructions and µ-maps
# used for the dataset must contain.
STUDY_DESCRIPTION = "µ-map_study"


# Column names of the meta CSV file. Every attribute of `headers` holds its own
# name as value; the attribute order is the column order of a newly created CSV.
headers = argparse.Namespace(
    **{
        column: column
        for column in (
            "id",
            "patient_id",
            "age",
            "sex",
            "weight",
            "size",
            "patient_position",
            "protocol",
            "datetime_acquisition",
            "datetime_reconstruction",
            "pixel_spacing_x",
            "pixel_spacing_y",
            "pixel_spacing_z",
            "shape_x",
            "shape_y",
            "shape_z",
            "radiopharmaceutical",
            "radionuclide_dose",
            "radionuclide_code",
            "radionuclide_meaning",
            "energy_window_peak_lower",
            "energy_window_peak_upper",
            "energy_window_scatter_lower",
            "energy_window_scatter_upper",
            "detector_count",
            "collimator_type",
            "rotation_start",
            "rotation_step",
            "rotation_scan_arc",
            "file_projection",
            "file_recon_ac_sc",
            "file_recon_nac_sc",
            "file_recon_ac_nsc",
            "file_recon_nac_nsc",
            "file_mu_map",
        )
    }
)


def get_protocol(projection: DICOM) -> str:
    """
    Get the protocol (stress, rest) of a projection image by checking if
    it is part of the series description (case-insensitive).

    Parameters
    ----------
    projection: DICOM
        DICOM image of the projection

    Returns
    -------
    str
        the protocol as a string ("Stress" or "Rest")

    Raises
    ------
    ValueError
        if the series description contains neither "stress" nor "rest"
    """
    description = projection.SeriesDescription.lower()

    # "stress" is checked first, so a description containing both words
    # is classified as a stress protocol
    if "stress" in description:
        return "Stress"

    if "rest" in description:
        return "Rest"

    raise ValueError(f"Unknown protocol in projection {projection.SeriesDescription}")


def find_projections(
    dicom_images: List[DICOM],
) -> List[DICOM]:
    """
    Find all projections in a list of DICOM images belonging to a study.

    An image counts as a projection if "TOMO" is contained in its image type
    and its series description is one of "Stress", "Rest" or "Stress_2".
    Tomographic images with any other series description are skipped and a
    warning is logged for each of them.

    Parameters
    ----------
    dicom_images: list of DICOM
        DICOM images of a study

    Returns
    -------
    list of DICOM
        all projection images in the input list, in input order

    Raises
    ------
    ValueError
        if no projection is found
    """
    known_protocols = ("Stress", "Rest", "Stress_2")

    projections = []
    for candidate in dicom_images:
        if "TOMO" not in candidate.ImageType:
            continue

        if candidate.SeriesDescription in known_protocols:
            projections.append(candidate)
        else:
            logger.warning(
                f"Skip projection with unknown protocol [{candidate.SeriesDescription}]"
            )

    if not projections:
        raise ValueError("No projections available")

    return projections


def is_recon_type(
    scatter_corrected: bool, attenuation_corrected: bool
) -> Callable[[DICOM], bool]:
    """
    Build a filter for reconstructions with the given scatter and attenuation
    corrections.

    The filter checks whether a marker text, which depends on the requested
    combination of corrections, is contained in the series description.

    Parameters
    ----------
    scatter_corrected: bool
        whether the filter should accept scatter corrected (True) or
        non-scatter corrected (False) reconstructions
    attenuation_corrected: bool
        whether the filter should accept attenuation corrected (True) or
        non-attenuation corrected (False) reconstructions

    Returns
    -------
    Callable[DICOM, bool]
        a filter function that returns true if the DICOM image has the specified corrections
    """
    # marker per (scatter corrected, attenuation corrected) combination;
    # NOTE(review): the SC/NoAC marker contains a double space - presumably it
    # mirrors the exported series descriptions, confirm against the data
    markers = {
        (True, True): " SC - AC ",
        (False, True): " NoSC - AC ",
        (True, False): " SC  - NoAC ",
        (False, False): " NoSC - NoAC ",
    }
    marker = markers[(bool(scatter_corrected), bool(attenuation_corrected))]

    def _has_marker(dicom) -> bool:
        return marker in dicom.SeriesDescription

    return _has_marker


def find_reconstruction(
    dicom_images: List[DICOM],
    projection: DICOM,
    scatter_corrected: bool,
    attenuation_corrected: bool,
) -> DICOM:
    """
    Find the reconstruction belonging to a projection in a list of DICOM images of a study.

    Candidates must satisfy all of the following:

    - the image type contains "RECON TOMO";
    - the series description contains the protocol of the projection and
      `STUDY_DESCRIPTION`;
    - the series description does not contain "CT", which excludes µ-maps;
    - the series description matches the requested scatter/attenuation
      correction;
    - the acquisition time equals that of the projection.

    Parameters
    ----------
    dicom_images: list of DICOM
        DICOM images belonging to the study
    projection: DICOM
        the DICOM image of the projection the reconstructions belong to
    scatter_corrected: bool
        if it should be searched for a scatter corrected reconstruction
    attenuation_corrected: bool
        if it should be searched for an attenuation corrected reconstruction

    Returns
    -------
    DICOM
        the matching reconstruction; if several match, a warning is logged
        and the one with the latest series time is returned

    Raises
    ------
    ValueError
        if no reconstruction matches, or if the protocol of the projection
        is unknown
    """
    protocol = get_protocol(projection)
    # computed once instead of once per candidate image
    projection_time = DICOMTime.Acquisition.to_datetime(projection)
    has_requested_corrections = is_recon_type(scatter_corrected, attenuation_corrected)

    # the conditions are evaluated left to right and short-circuit
    candidates = [
        image
        for image in dicom_images
        if "RECON TOMO" in image.ImageType
        and protocol in image.SeriesDescription
        and STUDY_DESCRIPTION in image.SeriesDescription
        and "CT" not in image.SeriesDescription  # remove µ-maps
        and has_requested_corrections(image)
        and DICOMTime.Acquisition.to_datetime(image) == projection_time
    ]

    if not candidates:
        raise ValueError(
            f"No reconstruction with SC={scatter_corrected}, AC={attenuation_corrected} available"
        )

    if len(candidates) == 1:
        return candidates[0]

    logger.warning(
        f"Multiple reconstructions ({len(candidates)}) with SC={scatter_corrected}, AC={attenuation_corrected} for projection {projection.SeriesDescription} of patient {projection.PatientID}"
    )
    # choose the candidate with the latest series time (the first one in input
    # order on ties). NOTE(review): a former comment claimed "oldest first",
    # but the code has always picked the newest - confirm which is intended
    return max(candidates, key=DICOMTime.Series.to_datetime)


def find_attenuation_map(
    dicom_images: List[DICOM],
    projection: DICOM,
    reconstructions: List[DICOM],
    max_time_diff: int = 30,
) -> DICOM:
    """
    Find an attenuation map in a list of DICOM images of a study belonging to a projection and reconstructions.

    Candidates must satisfy all of the following:

    - the image type contains "RECON TOMO";
    - the series description contains the protocol of the projection,
      `STUDY_DESCRIPTION` and " µ-map]";
    - the series time lies less than `max_time_diff` seconds (in either
      direction) from the series time of at least one reconstruction.

    Parameters
    ----------
    dicom_images: list of DICOM
        the list of DICOM images belonging to the study
    projection: DICOM
        the DICOM image of the projection
    reconstructions: list of DICOM
        DICOM images of reconstructions belonging to the projection
    max_time_diff: int, optional
        filter out DICOM files whose series time differs by this many seconds
        or more from the series time of every reconstruction

    Returns
    -------
    DICOM
        the matching attenuation map; if several match, a warning is logged
        and the one with the latest series time is returned

    Raises
    ------
    ValueError
        if no attenuation map matches, or if the protocol of the projection
        is unknown
    """
    protocol = get_protocol(projection)
    recon_times = [DICOMTime.Series.to_datetime(recon) for recon in reconstructions]

    def _close_to_a_reconstruction(image) -> bool:
        image_time = DICOMTime.Series.to_datetime(image)
        # use the absolute total difference: `timedelta.seconds` is only the
        # seconds component (0..86399); it ignores whole days and turns a
        # negative difference of a few seconds into a value close to 86400
        return any(
            abs((image_time - recon_time).total_seconds()) < max_time_diff
            for recon_time in recon_times
        )

    # the conditions are evaluated left to right and short-circuit
    candidates = [
        image
        for image in dicom_images
        if "RECON TOMO" in image.ImageType
        and protocol in image.SeriesDescription
        and STUDY_DESCRIPTION in image.SeriesDescription
        and " µ-map]" in image.SeriesDescription
        and _close_to_a_reconstruction(image)
    ]

    if not candidates:
        raise ValueError("No Attenuation map available")

    if len(candidates) == 1:
        return candidates[0]

    logger.warning(
        f"Multiple attenuation maps ({len(candidates)}) for projection {projection.SeriesDescription} of patient {projection.PatientID}"
    )
    # choose the candidate with the latest series time (the first one in input
    # order on ties). NOTE(review): a former comment claimed "oldest first",
    # but the code has always picked the newest - confirm which is intended
    return max(candidates, key=DICOMTime.Series.to_datetime)


def get_relevant_images(patient: DICOM, dicom_dir: str) -> List[DICOM]:
    """
    Get all relevant images of a patient.

    The DICOMDIR record tree is walked patient -> STUDY -> SERIES -> IMAGE.
    Only series that consist of exactly one image are kept. SPECT data is
    stored as a single 3D array per series; multi-image series such as CTs
    are skipped. The header of each kept image is read without pixel data.

    Parameters
    ----------
    patient: DICOM
        DICOMDIR patient record
    dicom_dir: str
        the directory of the DICOMDIR the patient record was read from; the
        referenced file IDs of the image records are resolved relative to it

    Returns
    -------
    list of DICOM
        the headers (read without pixel data) of all relevant images
    """

    def _children_of_type(record, record_type: str) -> list:
        # direct child records of a DICOMDIR record with the given record type
        return [
            child
            for child in record.children
            if child.DirectoryRecordType == record_type
        ]

    dicom_images = []
    for study in _children_of_type(patient, "STUDY"):
        for series in _children_of_type(study, "SERIES"):
            images = _children_of_type(series, "IMAGE")

            # all SPECT data is stored as a single 3D array which means that it is a series with a single image
            # this is not the case for CTs, which are skipped here
            if len(images) != 1:
                continue

            # Referenced File ID holds one value per path component; pydicom
            # returns a plain string when there is only one component, which
            # must not be unpacked into single characters
            file_id = images[0].ReferencedFileID
            path_parts = [file_id] if isinstance(file_id, str) else list(file_id)

            dicom_images.append(
                pydicom.dcmread(
                    os.path.join(dicom_dir, *path_parts),
                    stop_before_pixels=True,
                )
            )
    return dicom_images


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare a dataset from DICOM directories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dicom_dirs",
        type=str,
        nargs="+",
        help="paths to DICOMDIR files or directories containing one of them",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="directory where images, meta-information and the logs are stored",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory of --dataset_dir where images are stored",
    )
    parser.add_argument(
        "--meta_csv",
        type=str,
        default="meta.csv",
        help="CSV file under --dataset_dir where meta-information is stored",
    )
    add_logging_args(
        parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
    )
    args = parser.parse_args()

    # DICOMDIR files may be passed directly: use the directory containing them
    args.dicom_dirs = [
        (os.path.dirname(_file) if os.path.isfile(_file) else _file)
        for _file in args.dicom_dirs
    ]
    # images, meta CSV and log file are all placed under --dataset_dir
    args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
    args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
    args.logfile = os.path.join(args.dataset_dir, args.logfile)

    # makedirs also creates missing parent directories and accepts existing ones
    os.makedirs(args.dataset_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)

    # assigned at module level on purpose: the functions of this script log
    # through the module-global name `logger`
    logger = get_logger_by_args(args)

    try:
        # collect the patient records of all DICOMDIRs and remember the
        # directory each patient was found in (kept as a module-level name)
        patients = []
        dicom_dir_by_patient: Dict[str, str] = {}
        for dicom_dir in args.dicom_dirs:
            dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
            for patient in dataset.patient_records:
                if patient.PatientID in dicom_dir_by_patient:
                    raise ValueError(
                        f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
                    )
                dicom_dir_by_patient[patient.PatientID] = dicom_dir
                patients.append(patient)

        if os.path.exists(args.meta_csv):
            # extend an existing dataset: continue numbering after the highest
            # id (re-using the maximum would duplicate that id and overwrite
            # the image files written for it)
            data = pd.read_csv(args.meta_csv)
            _id = int(data[headers.id].max()) + 1 if len(data) > 0 else 1
        else:
            data = pd.DataFrame({key: [] for key in vars(headers)})
            _id = 1

        for i, patient in enumerate(patients, start=1):
            logger.debug(
                f"Process patient {str(i):>3}/{len(patients)} - {patient.PatientName.given_name}, {patient.PatientName.family_name}:"
            )

            dicom_images = get_relevant_images(
                patient, dicom_dir_by_patient[patient.PatientID]
            )

            projections = find_projections(dicom_images)
            logger.debug(
                f"- Found {len(projections)}: projections with protocols {list(map(lambda p: get_protocol(p), projections))}"
            )

            for projection in projections:
                protocol = get_protocol(projection)

                # (attenuation corrected, scatter corrected, CSV header, file postfix)
                recon_specs = [
                    (False, False, headers.file_recon_nac_nsc, "recon_nac_nsc"),
                    (False, True, headers.file_recon_nac_sc, "recon_nac_sc"),
                    (True, False, headers.file_recon_ac_nsc, "recon_ac_nsc"),
                    (True, True, headers.file_recon_ac_sc, "recon_ac_sc"),
                ]
                reconstructions = [
                    find_reconstruction(
                        dicom_images,
                        projection,
                        scatter_corrected=sc,
                        attenuation_corrected=ac,
                    )
                    for ac, sc, _, _ in recon_specs
                ]
                recon_headers = [spec[2] for spec in recon_specs]
                recon_postfixes = [spec[3] for spec in recon_specs]
                mu_map = find_attenuation_map(dicom_images, projection, reconstructions)
                volumes = [*reconstructions, mu_map]

                # extract pixel spacings and make sure they are equal for all reconstruction images
                pixel_spacings = [
                    np.array(
                        [
                            float(value)
                            for value in (*image.PixelSpacing, image.SliceThickness)
                        ]
                    )
                    for image in volumes
                ]
                if not all(
                    (spacing == pixel_spacings[0]).all() for spacing in pixel_spacings
                ):
                    raise ValueError(
                        f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
                    )
                pixel_spacing = pixel_spacings[0]

                # use the shape with the fewest slices, all other images will be aligned to that
                shapes = [
                    np.array(
                        [int(image.Rows), int(image.Columns), int(image.NumberOfSlices)]
                    )
                    for image in volumes
                ]
                shape = min(shapes, key=lambda _shape: _shape[2])

                # extract energy windows as (lower, upper) limits and sort them by
                # descending lower limit: the first is stored as peak window, the
                # second as scatter window
                energy_windows = []
                for window in projection.EnergyWindowInformationSequence:
                    window_range = window.EnergyWindowRangeSequence[0]
                    energy_windows.append(
                        (
                            float(window_range.EnergyWindowLowerLimit),
                            float(window_range.EnergyWindowUpperLimit),
                        )
                    )
                energy_windows.sort(key=lambda ew: ew[0], reverse=True)

                radiopharmaceutical_info = (
                    projection.RadiopharmaceuticalInformationSequence[0]
                )
                radionuclide_code = radiopharmaceutical_info.RadionuclideCodeSequence[0]
                rotation = projection.RotationInformationSequence[0]

                row = {
                    headers.id: _id,
                    headers.patient_id: projection.PatientID,
                    headers.age: parse_age(projection.PatientAge),
                    headers.sex: projection.PatientSex,
                    headers.weight: float(projection.PatientWeight),
                    headers.size: float(projection.PatientSize),
                    headers.patient_position: projection.PatientPosition,
                    headers.protocol: protocol,
                    # NOTE(review): the series time (not the acquisition time) of
                    # the projection is stored as acquisition datetime - confirm
                    headers.datetime_acquisition: DICOMTime.Series.to_datetime(
                        projection
                    ),
                    headers.datetime_reconstruction: DICOMTime.Series.to_datetime(
                        reconstructions[0]
                    ),
                    headers.pixel_spacing_x: pixel_spacing[0],
                    headers.pixel_spacing_y: pixel_spacing[1],
                    headers.pixel_spacing_z: pixel_spacing[2],
                    headers.shape_x: shape[0],
                    headers.shape_y: shape[1],
                    headers.shape_z: shape[2],
                    headers.radiopharmaceutical: radiopharmaceutical_info.Radiopharmaceutical,
                    headers.radionuclide_dose: radiopharmaceutical_info.RadionuclideTotalDose,
                    headers.radionuclide_code: radionuclide_code.CodeValue,
                    headers.radionuclide_meaning: radionuclide_code.CodeMeaning,
                    headers.energy_window_peak_lower: energy_windows[0][0],
                    headers.energy_window_peak_upper: energy_windows[0][1],
                    headers.energy_window_scatter_lower: energy_windows[1][0],
                    headers.energy_window_scatter_upper: energy_windows[1][1],
                    headers.detector_count: len(projection.DetectorInformationSequence),
                    headers.collimator_type: projection.DetectorInformationSequence[
                        0
                    ].CollimatorType,
                    headers.rotation_start: float(rotation.StartAngle),
                    headers.rotation_step: float(rotation.AngularStep),
                    headers.rotation_scan_arc: float(rotation.ScanArc),
                }

                # copy the full DICOM files (including pixel data) into the
                # images directory and store their file names in the row
                _filename_base = f"{_id:04d}-{protocol.lower()}"
                _ext = "dcm"
                img_headers = [
                    headers.file_projection,
                    *recon_headers,
                    headers.file_mu_map,
                ]
                img_postfixes = ["projection", *recon_postfixes, "mu_map"]
                images = [projection, *reconstructions, mu_map]
                for header, image, postfix in zip(img_headers, images, img_postfixes):
                    _image = pydicom.dcmread(image.filename)
                    filename = f"{_filename_base}-{postfix}.{_ext}"
                    pydicom.dcmwrite(os.path.join(args.images_dir, filename), _image)
                    row[header] = filename

                _id += 1
                data = pd.concat(
                    (data, pd.DataFrame(row, index=[0])), ignore_index=True
                )

        data.to_csv(args.meta_csv, index=False)
    except Exception as e:
        # log the error, then re-raise with the original traceback
        logger.error(e)
        raise