import argparse
import os
from typing import List

from mu_map.file.dicom import DICOM, DICOMTime


def get_files_recursive(_dir: str) -> List[str]:
    """
    Collect every file below a directory.

    Sub-directories are descended into recursively; only non-directory
    entries end up in the result. Entries are visited in the order returned
    by ``os.listdir``, so the result is not sorted.

    :param _dir: the directory to search
    :return: paths (joined onto ``_dir``) of all files in the directory tree
    """
    found: List[str] = []
    for entry in os.listdir(_dir):
        path = os.path.join(_dir, entry)
        if not os.path.isdir(path):
            found.append(path)
            continue
        # Directory: splice in everything found below it.
        found += get_files_recursive(path)
    return found


def is_scatter_corrected(dcm: DICOM) -> bool:
    """
    Tell whether the polar map's series description marks it as scatter corrected.

    The check is a plain substring test for ``"SC "`` in ``SeriesDescription``.
    NOTE(review): any description containing ``"SC "`` counts (e.g. a "NoSC "
    token would also match) — confirm against the actual description format.

    :param dcm: the DICOM header of the polar map
    :return: True if ``"SC "`` occurs in the series description
    """
    description = dcm.SeriesDescription
    return description.find("SC ") != -1


def is_attenuation_corrected(dcm: DICOM) -> bool:
    """
    Tell whether the polar map is attenuation corrected.

    A map counts as attenuation corrected unless its series description
    contains the (case-sensitive) marker ``"NoAC"``.

    :param dcm: the DICOM header of the polar map
    :return: False if ``"NoAC"`` occurs in the series description, else True
    """
    return "NoAC" not in dcm.SeriesDescription


def shows_segments(dcm: DICOM) -> bool:
    """
    Tell whether the polar map is rendered with segments.

    The series description is searched case-insensitively for ``"segment"``.

    :param dcm: the DICOM header of the polar map
    :return: True if the lower-cased series description contains ``"segment"``
    """
    description = dcm.SeriesDescription.lower()
    return "segment" in description


def get_type(dcm: DICOM) -> str:
    """
    Determine which mu map type the polar map was reconstructed with.

    The series description is searched case-insensitively:

    * ``"dlac"`` -> ``"dl"``
    * ``"ctac"`` -> ``"ct"``
    * otherwise  -> ``"symbia"``

    :param dcm: the DICOM header of the polar map
    :return: one of ``"dl"``, ``"ct"`` or ``"symbia"``
    """
    description = dcm.SeriesDescription.lower()
    # The description is lower-cased, so the markers must be lower-case too;
    # upper-case markers could never match and every map became "symbia".
    if "dlac" in description:
        return "dl"
    if "ctac" in description:
        return "ct"
    return "symbia"


# Column names of the polar-map CSV written by this script. The short names
# (sc, ac, type) are aliases holding the same string as their long counterpart.
headers = argparse.Namespace(
    id="id",
    scatter_correction="scatter_correction",
    sc="scatter_correction",  # alias of scatter_correction
    attenuation_correction="attenuation_correction",
    ac="attenuation_correction",  # alias of attenuation_correction
    segments="segments",
    mu_map_type="mu_map_type",
    type="mu_map_type",  # alias of mu_map_type
    file="file",
)


if __name__ == "__main__":
    # Script entry point: crop polar maps out of DICOM images, store them as
    # PNG files and record their meta information in a CSV file.
    import sys

    import cv2 as cv
    import pandas as pd

    from mu_map.data.prepare import get_protocol
    from mu_map.data.prepare import headers as meta_headers
    from mu_map.file.dicom import load_dcm
    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Store/Sort in polar maps for further processing",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--polar_map_dir",
        type=str,
        required=True,
        help="directory where the raw DICOM files of polar maps are stored",
    )
    parser.add_argument(
        "--out_dir", type=str, required=True, help="directory where files are output to"
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="directory under <out_dir> where images of polar maps are stored",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="polar_maps.csv",
        help="file under <out_dir> where meta information is stored",
    )
    parser.add_argument(
        "--meta_csv",
        type=str,
        required=True,
        help="the csv file containing meta information of the dataset of which the generated polar maps are",
    )
    parser.add_argument(
        "--rect_upper",
        type=int,
        nargs=4,
        default=[22, 1499, 22 + 588, 1500 + 588],
        help="rectangle as [top, left, bottom, right] coordinates where the polar map is found in the DICOM images",
    )
    parser.add_argument(
        "--rect_lower",
        type=int,
        nargs=4,
        default=[634, 1499, 634 + 588, 1500 + 588],
        help="rectangle as [top, left, bottom, right] coordinates where the polar map is found in the DICOM images",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="start interactive mode where every polar map is displayed and confirmation is queried",
    )
    add_logging_args(parser, defaults={"--logfile": "prepare.log"})
    args = parser.parse_args()

    # All outputs (images, CSV and log file) are placed below <out_dir>.
    args.images_dir = os.path.join(args.out_dir, args.images_dir)
    args.csv = os.path.join(args.out_dir, args.csv)
    args.logfile = os.path.join(args.out_dir, args.logfile)
    args.loglevel = "DEBUG" if args.interactive else args.loglevel

    # Create the output directories BEFORE setting up the logger: the log file
    # lives inside <out_dir>, so that directory must already exist.
    os.makedirs(args.out_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)

    if args.interactive:
        print(
            """
    Start in interactive mode.
    This mode automatically sets the loglevel to debug so that you can see what is going on.
    Use the following keys on the displayed polar map:
        
        n(ext): go to the next image without saving the polar map
        s(ave): save the polar map and go to the next
        q(uit): quit the application

    """
        )

    logger = get_logger_by_args(args)
    logger.info(args)

    meta = pd.read_csv(args.meta_csv)
    meta[meta_headers.datetime_acquisition] = pd.to_datetime(
        meta[meta_headers.datetime_acquisition]
    )

    dcm_files = sorted(get_files_recursive(args.polar_map_dir))

    # Start with an empty but typed table; one row is appended per stored polar map.
    data = pd.DataFrame(
        {
            headers.id: [],
            headers.sc: [],
            headers.ac: [],
            headers.segments: [],
            headers.type: [],
            headers.file: [],
        }
    )
    data[headers.id] = data[headers.id].astype(int)
    data[headers.sc] = data[headers.sc].astype(bool)
    data[headers.ac] = data[headers.ac].astype(bool)
    data[headers.segments] = data[headers.segments].astype(bool)
    data[headers.type] = data[headers.type].astype(str)
    data[headers.file] = data[headers.file].astype(str)
    for dcm_file in dcm_files:
        logger.debug(f"Process file {dcm_file}")

        dcm, img = load_dcm(dcm_file)

        # Match the polar map to exactly one study of the dataset via patient,
        # protocol and acquisition time.
        protocol = get_protocol(dcm)
        patient_id = int(dcm.PatientID)
        acquisition_time = DICOMTime.Acquisition.to_datetime(dcm)
        meta_rows = meta[
            (meta[meta_headers.patient_id] == patient_id)
            & (meta[meta_headers.protocol] == protocol)
            & (meta[meta_headers.datetime_acquisition] == acquisition_time)
        ]
        if len(meta_rows) != 1:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            raise ValueError(
                f"Expected exactly one meta entry for {dcm_file} (patient {patient_id}, "
                f"protocol {protocol}, acquisition {acquisition_time}), "
                f"but found {len(meta_rows)}"
            )
        meta_row = meta_rows.iloc[0]

        row = {
            headers.id: meta_row[meta_headers.id],
            headers.sc: is_scatter_corrected(dcm),
            headers.ac: is_attenuation_corrected(dcm),
            headers.segments: shows_segments(dcm),
        }

        # The mu map types shown in the image are parsed from the last "-"
        # separated part of the series description, separated by "/".
        # NOTE(review): the final character is dropped, presumably a closing
        # bracket — confirm against the actual description format.
        types = dcm.SeriesDescription.split("-")[-1]
        types = types[:-1].strip()
        types = types.split("/")

        # The first type belongs to the upper polar map, the second to the lower one.
        for _type, rect in zip(types, [args.rect_upper, args.rect_lower]):
            top, left, bottom, right = rect
            polar_map = img[top:bottom, left:right]
            polar_map = cv.cvtColor(polar_map, cv.COLOR_RGB2BGR)

            row[headers.type] = _type[:2].lower()
            _seg = "segments" if row[headers.segments] else "gray"
            # Encode the actual correction flags in the filename so that
            # differently corrected maps of the same study do not overwrite
            # each other (previously "ac_nsc" was hard-coded).
            _ac = "ac" if row[headers.ac] else "nac"
            _sc = "sc" if row[headers.sc] else "nsc"
            _file = f"{row[headers.id]:04d}-{_seg}_{_ac}_{_sc}_{row[headers.type]}.png"
            row[headers.file] = _file
            _file = os.path.join(args.images_dir, _file)

            logger.debug(f"For series {dcm.SeriesDescription} store row {row}")

            store = not args.interactive
            if args.interactive:
                while True:
                    cv.imshow("Polar Map", cv.resize(polar_map, (1024, 1024)))
                    key = cv.waitKey(0)

                    if key == ord("q"):
                        cv.destroyAllWindows()
                        sys.exit(0)
                    elif key == ord("n"):
                        break
                    elif key == ord("s"):
                        store = True
                        break

            if store:
                cv.imwrite(_file, polar_map)

                # The CSV is rewritten after every stored map so that progress
                # is kept when the interactive session is quit early.
                data = pd.concat((data, pd.DataFrame(row, index=[0])), ignore_index=True)
                data = data.sort_values(by=[headers.id, headers.file])
                data.to_csv(args.csv, index=False)

    if args.interactive:
        cv.destroyAllWindows()