import argparse
import os
from typing import List

from mu_map.file.dicom import dcm_type


def get_files_recursive(_dir: str) -> List[str]:
    """
    Collect every non-directory entry below a directory.

    Sub-directories are descended into recursively. Entries are returned in
    the order produced by ``os.listdir``, with the contents of a
    sub-directory inserted at the position where that sub-directory was
    encountered.

    :param _dir: the directory to search
    :return: paths (joined with ``_dir``) of all files found in ``_dir`` and
             its sub-directories
    """
    collected: List[str] = []
    for name in os.listdir(_dir):
        path = os.path.join(_dir, name)
        # directories contribute their (recursive) contents, anything else itself
        collected += get_files_recursive(path) if os.path.isdir(path) else [path]
    return collected


def is_scatter_corrected(dcm: dcm_type) -> bool:
    """
    Determine whether a DICOM image is scatter corrected.

    The decision is made from the series description only: a description
    containing the (case-sensitive) marker ``NoSC`` means "not scatter
    corrected".

    :param dcm: the DICOM image
    :return: False if the series description contains "NoSC", True otherwise
    """
    no_sc_marker = "NoSC"
    return no_sc_marker not in dcm.SeriesDescription


def is_attenuation_corrected(dcm: dcm_type) -> bool:
    """
    Determine whether a DICOM image is attenuation corrected.

    The decision is made from the series description only: a description
    containing the (case-sensitive) marker ``NoAC`` means "not attenuation
    corrected".

    :param dcm: the DICOM image
    :return: False if the series description contains "NoAC", True otherwise
    """
    no_ac_marker = "NoAC"
    return no_ac_marker not in dcm.SeriesDescription


def get_type(dcm: dcm_type) -> str:
    """
    Classify the mu map type used for a DICOM image by its series description.

    The (lower-cased) series description is searched for markers in a fixed
    priority order: "syn" -> "synthetic", then "ct" -> "ct". If neither
    marker is present, "symbia" is returned.

    NOTE(review): these are plain substring matches, so any description
    containing "ct" anywhere (e.g. inside a word such as "correction") is
    classified as "ct" -- confirm this cannot occur for the dataset's series.

    :param dcm: the DICOM image
    :return: one of "synthetic", "ct" or "symbia"
    """
    lowered = dcm.SeriesDescription.lower()
    # order matters: "syn" takes precedence over "ct"
    for marker, mu_map_type in (("syn", "synthetic"), ("ct", "ct")):
        if marker in lowered:
            return mu_map_type
    return "symbia"


# Column names of the polar map CSV produced by this script.
headers = argparse.Namespace(
    id="id",
    scatter_correction="scatter_correction",
    attenuation_correction="attenuation_correction",
    mu_map_type="mu_map_type",
    file="file",
)
# Short aliases referring to the same column names.
headers.sc = headers.scatter_correction
headers.ac = headers.attenuation_correction
headers.type = headers.mu_map_type


if __name__ == "__main__":
    import sys

    import cv2 as cv
    import pandas as pd

    from mu_map.data.prepare import get_protocol
    from mu_map.data.prepare import headers as meta_headers
    from mu_map.file.dicom import load_dcm
    from mu_map.logging import add_logging_args, get_logger_by_args

    # Script: crop polar maps out of DICOM images, store them as images under
    # <out_dir>/<images_dir> and record their meta information (id, scatter
    # and attenuation correction, mu map type, file name) in <out_dir>/<csv>.

    parser = argparse.ArgumentParser(
        description="Store/Sort in polar maps for further processing",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--polar_map_dir",
        type=str,
        required=True,
        help="directory where the raw DICOM files of polar maps are stored",
    )
    parser.add_argument(
        "--out_dir", type=str, required=True, help="directory where files are output to"
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="directory under <out_dir> where images of polar maps are stored",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="polar_maps.csv",
        help="file under <out_dir> where meta information is stored",
    )
    parser.add_argument(
        "--meta_csv",
        type=str,
        required=True,
        help="the csv file containing meta information of the dataset of which the generated polar maps are",
    )
    parser.add_argument(
        "--rect",
        type=int,
        nargs=4,
        default=[21, 598, 425, 1002],
        help="rectangle as [top, left, bottom, right] coordinates where the polar map is found in the DICOM images",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="start interactive mode where every polar map is displayed and confirmation is queried",
    )
    add_logging_args(parser, defaults={"--logfile": "prepare.log"})
    args = parser.parse_args()

    args.images_dir = os.path.join(args.out_dir, args.images_dir)
    args.csv = os.path.join(args.out_dir, args.csv)
    args.logfile = os.path.join(args.out_dir, args.logfile)
    args.loglevel = "DEBUG" if args.interactive else args.loglevel

    if args.interactive:
        print(
            """
    Start in interactive mode.
    This mode automatically sets the loglevel to debug so that you can see what is going on.
    Use the following keys on the displayed polar map:
        
        n(ext): go to the next image without saving the polar map
        s(ave): save the polar map and go to the next
        q(uit): quit the application

    """
        )

    # Create the output directories BEFORE the logger is set up: the log file
    # path above points into out_dir. makedirs also handles nested paths.
    os.makedirs(args.out_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)

    logger = get_logger_by_args(args)
    logger.info(args)

    meta = pd.read_csv(args.meta_csv)
    dcm_files = sorted(get_files_recursive(args.polar_map_dir))

    top, left, bottom, right = args.rect

    # Rows of stored polar maps (written to the CSV after each store) and the
    # (id, sc, ac, type) keys already stored, used to skip duplicates.
    rows = []
    stored_keys = set()
    for dcm_file in dcm_files:
        logger.debug(f"Process file {dcm_file}")

        dcm, img = load_dcm(dcm_file)

        protocol = get_protocol(dcm)
        patient_id = int(dcm.PatientID)
        matches = meta[
            (meta[meta_headers.patient_id] == patient_id)
            & (meta[meta_headers.protocol] == protocol)
        ]
        if len(matches) == 0:
            raise ValueError(
                f"No entry in {args.meta_csv} for patient {patient_id} "
                f"with protocol {protocol} (file {dcm_file})"
            )
        meta_row = matches.iloc[0]

        row = {
            headers.id: meta_row[meta_headers.id],
            headers.sc: is_scatter_corrected(dcm),
            headers.ac: is_attenuation_corrected(dcm),
            headers.type: get_type(dcm),
        }

        polar_map = img[top:bottom, left:right]
        # OpenCV expects BGR channel order for display and writing
        polar_map = cv.cvtColor(polar_map, cv.COLOR_RGB2BGR)

        _ac = "ac" if row[headers.ac] else "nac"
        _sc = "sc" if row[headers.sc] else "nsc"
        _file = f"{row[headers.id]:04d}-{_ac}_{_sc}_{row[headers.type]}.png"
        row[headers.file] = _file
        _file = os.path.join(args.images_dir, _file)

        dedup_key = (row[headers.id], row[headers.sc], row[headers.ac], row[headers.type])
        if dedup_key in stored_keys:
            logger.warning(f"Skip {dcm_file} as it is a duplicate for row: {row}")
            continue

        logger.debug(f"For series {dcm.SeriesDescription} store row {row}")

        store = not args.interactive
        if args.interactive:
            while True:
                cv.imshow("Polar Map", cv.resize(polar_map, (512, 512)))
                # mask to the low byte: some platforms set higher modifier bits
                key = cv.waitKey(0) & 0xFF

                if key == ord("q"):
                    cv.destroyAllWindows()
                    sys.exit(0)
                elif key == ord("n"):
                    break
                elif key == ord("s"):
                    store = True
                    break

        if store:
            if not cv.imwrite(_file, polar_map):
                raise OSError(f"Could not write polar map of {dcm_file} to {_file}")

            rows.append(row)
            stored_keys.add(dedup_key)

            # rewrite the CSV after every stored map so progress survives quitting
            data = pd.DataFrame(rows).sort_values(by=[headers.id, headers.file])
            data.to_csv(args.csv, index=False)

    if args.interactive:
        cv.destroyAllWindows()