import argparse
from datetime import datetime, timedelta
from enum import Enum
import os
from typing import List, Dict

import numpy as np
import pandas as pd
import pydicom

from mu_map.logging import add_logging_args, get_logger_by_args


# Marker substring that must occur in the SeriesDescription of a reconstruction or an
# attenuation map for get_reconstruction / get_attenuation_map to select that image.
STUDY_DESCRIPTION = "µ-map_study"


class MyocardialProtocol(Enum):
    """
    Acquisition protocol of a myocardial scintigraphy study.

    The member name (e.g. "Stress") is searched for in the SeriesDescription of DICOM
    images to select the images belonging to a protocol; its lower-cased form is used
    in the names of the stored image files. Protocols are processed per patient in the
    definition order below.
    """

    Stress = 1
    Rest = 2


# Column names of the meta-information CSV. Every attribute of ``headers`` holds its own
# name as value, and the order below is the column order of a newly created CSV.
headers = argparse.Namespace(
    **{
        column: column
        for column in (
            "id",
            "patient_id",
            "age",
            "weight",
            "size",
            "protocol",
            "datetime_acquisition",
            "datetime_reconstruction",
            "pixel_spacing_x",
            "pixel_spacing_y",
            "pixel_spacing_z",
            "shape_x",
            "shape_y",
            "shape_z",
            "radiopharmaceutical",
            "radionuclide_dose",
            "radionuclide_code",
            "radionuclide_meaning",
            "energy_window_peak_lower",
            "energy_window_peak_upper",
            "energy_window_scatter_lower",
            "energy_window_scatter_upper",
            "detector_count",
            "collimator_type",
            "rotation_start",
            "rotation_step",
            "rotation_scan_arc",
            "file_projection",
            "file_recon_ac",
            "file_recon_no_ac",
            "file_mu_map",
        )
    }
)


def parse_series_time(dicom_image: pydicom.dataset.FileDataset) -> datetime:
    """
    Parse the date and time of a DICOM series object into a datetime object.

    SeriesDate is expected in the DICOM DA format ``YYYYMMDD`` and SeriesTime in the
    DICOM TM format ``HHMMSS.FFFFFF``, where minutes, seconds and the fractional part
    may be omitted (missing components are taken as 0).

    :param dicom_image: the dicom file to parse the series date and time from
    :return: an according python datetime object.
    :raises ValueError: if the date or time contains non-numeric components
    """
    _date = str(dicom_image.SeriesDate).strip()
    _time = str(dicom_image.SeriesTime).strip()

    _hms, _, _fraction = _time.partition(".")
    # The fraction is a decimal fraction of a second with 1-6 digits, so it has to be
    # right-padded to microseconds: ".5" means 500000 us, not 5 us.
    _microsecond = int(_fraction[:6].ljust(6, "0")) if _fraction else 0

    return datetime(
        year=int(_date[0:4]),
        month=int(_date[4:6]),
        day=int(_date[6:8]),
        hour=int(_hms[0:2]),
        # trailing TM components may be omitted, e.g. "1030" or "10"
        minute=int(_hms[2:4] or 0),
        second=int(_hms[4:6] or 0),
        microsecond=_microsecond,
    )


def parse_age(patient_age: str) -> int:
    """
    Parse an age string as defined in the DICOM standard into an integer representing the age in years.

    A DICOM age string has four characters: a three-digit number followed by a unit
    character (e.g. ``"045Y"``). Currently only the unit ``Y`` (years) is supported.

    :param patient_age: age string as defined in the DICOM standard
    :return: the age in years as a number
    :raises TypeError: if patient_age is not a string
    :raises ValueError: if patient_age is not four characters long, its number part is
        not numeric, or its unit is not years
    """
    # explicit checks instead of ``assert`` so validation is not stripped under ``python -O``
    if not isinstance(patient_age, str):
        raise TypeError(
            f"patient age needs to be a string and not {type(patient_age)}"
        )
    if len(patient_age) != 4:
        raise ValueError(
            f"patient age [{patient_age}] has to be four characters long"
        )
    _num, _format = patient_age[:3], patient_age[3]
    if _format != "Y":
        raise ValueError(
            f"currently, only patient ages in years [Y] is supported, not [{_format}]"
        )
    return int(_num)


def get_projection(
    dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
) -> pydicom.dataset.FileDataset:
    """
    Select the single SPECT projection for a protocol from DICOM images of a myocardial scintigraphy study.

    An image qualifies when "TOMO" is contained in its ImageType and the protocol
    name (e.g. "Stress") occurs in its SeriesDescription.

    :param dicom_images: DICOM images to search
    :param protocol: the protocol whose projection is requested
    :return: the unique matching DICOM image
    :raises ValueError: if no image or more than one image matches
    """
    candidates = [
        image
        for image in dicom_images
        if "TOMO" in image.ImageType and protocol.name in image.SeriesDescription
    ]
    if len(candidates) == 1:
        return candidates[0]

    raise ValueError(
        f"No or multiple projections {len(candidates)} for protocol {protocol.name} available"
    )


def get_reconstruction(
    dicom_images: List[pydicom.dataset.FileDataset],
    protocol: MyocardialProtocol,
    corrected: bool = True,
) -> pydicom.dataset.FileDataset:
    """
    Select a SPECT reconstruction for a protocol from DICOM images of a myocardial scintigraphy study.

    A candidate must contain "RECON TOMO" in its ImageType, contain both the protocol
    name and STUDY_DESCRIPTION in its SeriesDescription, and must not carry the
    attribute SliceProgressionDirection. With ``corrected=True`` the SeriesDescription
    has to contain "AC" but not "NoAC"; with ``corrected=False`` it has to contain "NoAC".
    Among several candidates the one with the latest series date and time is returned
    (on a tie, the first of them in input order).

    :param dicom_images: DICOM images to search
    :param protocol: the protocol whose reconstruction is requested
    :param corrected: select the attenuation corrected (True) or the non-corrected (False) reconstruction
    :return: the selected DICOM image
    :raises ValueError: if no image matches
    """

    def _is_candidate(image) -> bool:
        if "RECON TOMO" not in image.ImageType:
            return False
        description = image.SeriesDescription
        if protocol.name not in description or STUDY_DESCRIPTION not in description:
            return False
        if corrected:
            wanted_correction = "AC" in description and "NoAC" not in description
        else:
            wanted_correction = "NoAC" in description
        # for SPECT reconstructions created in clinical studies this attribute exists and is
        # set to 'APEX_TO_BASE'; for the reconstructions with attenuation maps it does not exist
        return wanted_correction and not hasattr(image, "SliceProgressionDirection")

    candidates = [image for image in dicom_images if _is_candidate(image)]
    if not candidates:
        label = "AC" if corrected else "NoAC"
        raise ValueError(
            f"{label} Reconstruction for protocol {protocol.name} is not available"
        )

    # max returns the first maximal element, matching a stable descending sort
    return max(candidates, key=parse_series_time)


def get_attenuation_map(
    dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
) -> pydicom.dataset.FileDataset:
    """
    Select an attenuation map for a protocol from DICOM images of a myocardial scintigraphy study.

    A candidate must contain "RECON TOMO" in its ImageType, and its SeriesDescription
    must contain the protocol name, STUDY_DESCRIPTION and " µ-map]". Among several
    candidates the one with the latest series date and time is returned (on a tie,
    the first of them in input order).

    :param dicom_images: DICOM images to search
    :param protocol: the protocol whose attenuation map is requested
    :return: the selected DICOM image
    :raises ValueError: if no image matches
    """

    def _is_mu_map(image) -> bool:
        if "RECON TOMO" not in image.ImageType:
            return False
        description = image.SeriesDescription
        return (
            protocol.name in description
            and STUDY_DESCRIPTION in description
            and " µ-map]" in description
        )

    mu_maps = [image for image in dicom_images if _is_mu_map(image)]
    if not mu_maps:
        raise ValueError(
            f"Attenuation map for protocol {protocol.name} is not available"
        )

    # max returns the first maximal element, matching a stable descending sort
    return max(mu_maps, key=parse_series_time)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare a dataset from DICOM directories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dicom_dirs",
        type=str,
        nargs="+",
        help="paths to DICOMDIR files or directories containing one of them",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="directory where images, meta-information and the logs are stored",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory of --dataset_dir where images are stored",
    )
    parser.add_argument(
        "--meta_csv",
        type=str,
        default="meta.csv",
        help="CSV file under --dataset_dir where meta-information is stored",
    )
    parser.add_argument(
        "--prefix_projection",
        type=str,
        default="projection",
        help="prefix used to store DICOM images of projections - format <id>-<protocol>-<prefix>.dcm",
    )
    parser.add_argument(
        "--prefix_mu_map",
        type=str,
        default="mu_map",
        help="prefix used to store DICOM images of attenuation maps - format <id>-<protocol>-<prefix>.dcm",
    )
    parser.add_argument(
        "--prefix_recon_ac",
        type=str,
        default="recon_ac",
        help="prefix used to store DICOM images of reconstructions with attenuation correction - format <id>-<protocol>-<prefix>.dcm",
    )
    parser.add_argument(
        "--prefix_recon_no_ac",
        type=str,
        default="recon_no_ac",
        help="prefix used to store DICOM images of reconstructions without attenuation correction - format <id>-<protocol>-<prefix>.dcm",
    )
    add_logging_args(
        parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
    )
    args = parser.parse_args()

    # accept both DICOMDIR files and the directories that contain them
    args.dicom_dirs = [
        (os.path.dirname(_file) if os.path.isfile(_file) else _file)
        for _file in args.dicom_dirs
    ]
    args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
    args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
    args.logfile = os.path.join(args.dataset_dir, args.logfile)

    # makedirs also creates missing parent directories; the log file is placed in
    # --dataset_dir, so that directory has to exist before the logger is created
    os.makedirs(args.dataset_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)

    global logger
    logger = get_logger_by_args(args)

    try:
        # collect the patient records of all DICOMDIRs and remember where each patient is stored
        patients = []
        dicom_dir_by_patient: Dict[str, str] = {}
        for dicom_dir in args.dicom_dirs:
            dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
            for patient in dataset.patient_records:
                assert (
                    patient.PatientID not in dicom_dir_by_patient
                ), f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
                dicom_dir_by_patient[patient.PatientID] = dicom_dir
                patients.append(patient)

        if os.path.exists(args.meta_csv):
            # read patient ids as strings: otherwise ids such as "0012345" are parsed as
            # integers, never equal the DICOM PatientID, and the patient is processed again
            data = pd.read_csv(args.meta_csv, dtype={headers.patient_id: str})
        else:
            data = pd.DataFrame({key: [] for key in vars(headers)})

        # continue after the largest id already in the CSV: image files are named by id,
        # so reusing an existing id would overwrite the images of an existing entry
        _existing_ids = pd.to_numeric(data[headers.id], errors="coerce").dropna()
        _id = int(_existing_ids.max()) + 1 if len(_existing_ids) > 0 else 1

        for i, patient in enumerate(patients, start=1):
            logger.debug(f"Process patient {str(i):>3}/{len(patients)}:")

            # all studies of the patient (a restriction to
            # StudyDescription == "Myokardszintigraphie" is currently disabled)
            studies = [
                child
                for child in patient.children
                if child.DirectoryRecordType == "STUDY"
            ]

            # read the headers of the SPECT images of ALL studies of the patient before
            # any protocol is extracted, so that every study is taken into account
            dicom_images = []
            for study in studies:
                series = [
                    child
                    for child in study.children
                    if child.DirectoryRecordType == "SERIES"
                ]
                for _series in series:
                    images = [
                        child
                        for child in _series.children
                        if child.DirectoryRecordType == "IMAGE"
                    ]

                    # all SPECT data is stored as a single 3D array which means that it is a series with a single image
                    # this is not the case for CTs, which are skipped here
                    if len(images) != 1:
                        continue

                    # pydicom returns a single-valued ReferencedFileID as a plain string,
                    # which must not be unpacked character by character
                    _file_id = images[0].ReferencedFileID
                    _file_id = [_file_id] if isinstance(_file_id, str) else list(_file_id)
                    dicom_images.append(
                        pydicom.dcmread(
                            os.path.join(
                                dicom_dir_by_patient[patient.PatientID], *_file_id
                            ),
                            stop_before_pixels=True,
                        )
                    )

            for protocol in MyocardialProtocol:
                _already_contained = (
                    (data[headers.patient_id] == patient.PatientID)
                    & (data[headers.protocol] == protocol.name)
                ).any()
                if _already_contained:
                    logger.info(
                        f"Skip {patient.PatientID}:{protocol.name} since it is already contained in the dataset"
                    )
                    continue

                try:
                    projection_image = get_projection(dicom_images, protocol=protocol)
                    recon_ac = get_reconstruction(
                        dicom_images, protocol=protocol, corrected=True
                    )
                    recon_noac = get_reconstruction(
                        dicom_images, protocol=protocol, corrected=False
                    )
                    attenuation_map = get_attenuation_map(
                        dicom_images, protocol=protocol
                    )
                except ValueError as e:
                    logger.info(f"Skip {patient.PatientID}:{protocol.name} because {e}")
                    continue

                recon_images = [recon_ac, recon_noac, attenuation_map]

                # extract date times and assert that all reconstruction images were
                # created less than 5 minutes before the newest one
                datetimes = [parse_series_time(image) for image in recon_images]
                _newest = max(datetimes)
                assert all(
                    _newest - dt < timedelta(seconds=300) for dt in datetimes
                ), f"Not all dates and times of the reconstructions are equal: {datetimes}"

                # extract pixel spacings and assert that they are equal for all reconstruction images
                pixel_spacings = [
                    np.array(
                        [
                            float(value)
                            for value in (*image.PixelSpacing, image.SliceThickness)
                        ]
                    )
                    for image in recon_images
                ]
                assert all(
                    (_spacing == pixel_spacings[0]).all() for _spacing in pixel_spacings
                ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
                pixel_spacing = pixel_spacings[0]

                # use the shape with the fewest slices, all other images will be aligned to that
                shapes = [
                    np.array(
                        [int(image.Rows), int(image.Columns), int(image.NumberOfSlices)]
                    )
                    for image in recon_images
                ]
                shape = min(shapes, key=lambda _shape: _shape[2])

                # extract energy windows as (lower, upper) sorted by descending lower limit;
                # the first is stored as peak and the second as scatter window
                energy_windows = sorted(
                    (
                        (
                            float(ew.EnergyWindowRangeSequence[0].EnergyWindowLowerLimit),
                            float(ew.EnergyWindowRangeSequence[0].EnergyWindowUpperLimit),
                        )
                        for ew in projection_image.EnergyWindowInformationSequence
                    ),
                    key=lambda ew: ew[0],
                    reverse=True,
                )
                if len(energy_windows) < 2:
                    # skip (instead of aborting the whole run) before any file is written
                    logger.info(
                        f"Skip {patient.PatientID}:{protocol.name} because the projection has less than two energy windows"
                    )
                    continue

                # re-read images with pixel-level data and save accordingly
                projection_image = pydicom.dcmread(projection_image.filename)
                recon_ac = pydicom.dcmread(recon_ac.filename)
                recon_noac = pydicom.dcmread(recon_noac.filename)
                attenuation_map = pydicom.dcmread(attenuation_map.filename)

                _filename_base = f"{_id:04d}-{protocol.name.lower()}"
                _ext = "dcm"
                _filename_projection = f"{_filename_base}-{args.prefix_projection}.{_ext}"
                _filename_recon_ac = f"{_filename_base}-{args.prefix_recon_ac}.{_ext}"
                _filename_recon_no_ac = f"{_filename_base}-{args.prefix_recon_no_ac}.{_ext}"
                _filename_mu_map = f"{_filename_base}-{args.prefix_mu_map}.{_ext}"
                for _filename, _image in (
                    (_filename_projection, projection_image),
                    (_filename_recon_ac, recon_ac),
                    (_filename_recon_no_ac, recon_noac),
                    (_filename_mu_map, attenuation_map),
                ):
                    pydicom.dcmwrite(os.path.join(args.images_dir, _filename), _image)

                _radiopharmaceutical = projection_image.RadiopharmaceuticalInformationSequence[0]
                _radionuclide = _radiopharmaceutical.RadionuclideCodeSequence[0]
                _rotation = projection_image.RotationInformationSequence[0]
                row = {
                    headers.id: _id,
                    headers.patient_id: projection_image.PatientID,
                    headers.age: parse_age(projection_image.PatientAge),
                    headers.weight: float(projection_image.PatientWeight),
                    headers.size: float(projection_image.PatientSize),
                    headers.protocol: protocol.name,
                    headers.datetime_acquisition: parse_series_time(projection_image),
                    headers.datetime_reconstruction: datetimes[0],
                    headers.pixel_spacing_x: pixel_spacing[0],
                    headers.pixel_spacing_y: pixel_spacing[1],
                    headers.pixel_spacing_z: pixel_spacing[2],
                    headers.shape_x: shape[0],
                    headers.shape_y: shape[1],
                    headers.shape_z: shape[2],
                    headers.radiopharmaceutical: _radiopharmaceutical.Radiopharmaceutical,
                    headers.radionuclide_dose: _radiopharmaceutical.RadionuclideTotalDose,
                    headers.radionuclide_code: _radionuclide.CodeValue,
                    headers.radionuclide_meaning: _radionuclide.CodeMeaning,
                    headers.energy_window_peak_lower: energy_windows[0][0],
                    headers.energy_window_peak_upper: energy_windows[0][1],
                    headers.energy_window_scatter_lower: energy_windows[1][0],
                    headers.energy_window_scatter_upper: energy_windows[1][1],
                    headers.detector_count: len(
                        projection_image.DetectorInformationSequence
                    ),
                    headers.collimator_type: projection_image.DetectorInformationSequence[
                        0
                    ].CollimatorType,
                    headers.rotation_start: float(_rotation.StartAngle),
                    headers.rotation_step: float(_rotation.AngularStep),
                    headers.rotation_scan_arc: float(_rotation.ScanArc),
                    headers.file_projection: _filename_projection,
                    headers.file_recon_ac: _filename_recon_ac,
                    headers.file_recon_no_ac: _filename_recon_no_ac,
                    headers.file_mu_map: _filename_mu_map,
                }
                _id += 1

                data = pd.concat(
                    (data, pd.DataFrame(row, index=[0])), ignore_index=True
                )
                # store the meta-information after every entry so that already written
                # images stay recorded even if a later patient aborts the run
                data.to_csv(args.meta_csv, index=False)

        # also writes the CSV (with headers only) when no new entry was added
        data.to_csv(args.meta_csv, index=False)
    except Exception as e:
        logger.error(e)
        # bare raise keeps the original traceback
        raise