import os
from typing import List, Optional, Tuple

import cv2 as cv
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import (
    DEFAULT_BED_CONTOURS_FILENAME,
    load_contours,
    remove_bed,
)
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images
from mu_map.file.dicom import load_dcm_img
from mu_map.logging import get_logger


class MuMapDataset(Dataset):
    """
    A dataset to map reconstructions to attenuation maps (mu maps).

    The dataset is lazy. This means that dataset creation is fast and
    images are only read into memory when they are first accessed.
    Thus, after the first iteration, accessing images becomes a lot
    faster.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Optional[Transform] = None,
        transform_augmentation: Optional[Transform] = None,
        logger=None,
    ):
        """
        Create a new mu map dataset.

        Parameters
        ----------
        dataset_dir: str
            directory of the dataset to be loaded
        csv_file: str
            name of the csv file in the dataset directory containing meta information (created by mu_map.data.prepare)
        split_file: str
            csv file defining a split of the dataset in train/validation/test (created by mu_map.data.split)
        split_name: str, optional
            the name of the split which is loaded; if not set, all rows of `csv_file` are used
        images_dir: str
            directory under `dataset_dir` containing the actual images in DICOM format
        bed_contours_file: str, optional
            json file containing contours around the bed for each mu map (see mu_map.data.remove_bed);
            if not set, the bed is not removed
        discard_mu_map_slices: bool
            remove defective slices from mu maps (have to be labeled by mu_map.data.review_mu_map)
        align: bool
            center align reconstructions and mu maps
        scatter_correction: bool
            use scatter corrected reconstructions
        transform_normalization: Transform, optional
            transform used for image normalization which is applied once when the image is loaded;
            defaults to a new `Transform()`
        transform_augmentation: Transform, optional
            transform used for augmentation which is applied every time `__getitem__` is called;
            defaults to a new `Transform()`
        logger: Logger, optional
            logger used for warnings; defaults to a logger named after this class
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        # Create fresh default transforms per dataset instead of sharing a single
        # instance that is created once when the function is defined.
        self.transform_normalization = (
            transform_normalization
            if transform_normalization is not None
            else Transform()
        )
        self.transform_augmentation = (
            transform_augmentation
            if transform_augmentation is not None
            else Transform()
        )
        self.logger = (
            logger if logger is not None else get_logger(name=MuMapDataset.__name__)
        )

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.split_name = split_name
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table[headers.id] = self.table[headers.id].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align
        self.scatter_correction = scatter_correction
        # column of the meta table holding the reconstruction file name
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        # lazily filled caches: image id -> tensor (filled by `load_image`)
        self.reconstructions = {}
        self.mu_maps = {}

    def copy(self, split_name: str, **kwargs) -> "MuMapDataset":
        """
        Create a copy of the dataset and modify parameters.

        The copy uses the same files, pre-processing flags, transforms and
        logger as this dataset, but starts with empty image caches.

        Parameters
        ----------
        split_name: str
            the split with which the copy is created
        kwargs:
            Modify parameters by name.
            Currently, only `bed_contours_file` is supported; other keys are ignored.

        Returns
        -------
        MuMapDataset
            the newly created dataset
        """
        if "bed_contours_file" in kwargs:
            bed_contours_file = kwargs["bed_contours_file"]
        elif self.bed_contours_file is not None:
            # the constructor joins this name with `dataset_dir` again
            bed_contours_file = os.path.basename(self.bed_contours_file)
        else:
            # bed removal is disabled for this dataset, so keep it disabled in
            # the copy (os.path.basename(None) would raise a TypeError)
            bed_contours_file = None

        return MuMapDataset(
            dataset_dir=self.dir,
            csv_file=os.path.basename(self.csv_file),
            split_file=os.path.basename(self.split_file),
            split_name=split_name,
            images_dir=os.path.basename(self.dir_images),
            bed_contours_file=bed_contours_file,
            discard_mu_map_slices=self.discard_mu_map_slices,
            align=self.align,
            scatter_correction=self.scatter_correction,
            transform_normalization=self.transform_normalization,
            transform_augmentation=self.transform_augmentation,
            logger=self.logger,
        )

    def load_image(self, _id: int) -> None:
        """
        Load an image into memory.

        This function also performs all of the pre-processing (discard slices, remove bed, alignment ...).
        Afterwards, the reconstruction and mu map are available from the local
        dicts: `self.reconstructions` and `self.mu_maps`.

        Parameters
        ----------
        _id: int
            the id of the image to be loaded
        """
        row = self.table[self.table[headers.id] == _id].iloc[0]
        _id = row[headers.id]

        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
        mu_map = load_dcm_img(mu_map_file, direction=1)

        recon_file = os.path.join(self.dir_images, row[self.header_recon])
        recon = load_dcm_img(recon_file, direction=1)

        if self.discard_mu_map_slices:
            mu_map, recon = discard_slices(row, mu_map, recon)
        if self.bed_contours:
            if _id in self.bed_contours:
                bed_contour = self.bed_contours[_id]
                mu_map = remove_bed(mu_map, bed_contour)
            else:
                # use the dataset's own logger; there is no module-level logger
                self.logger.warning(f"Could not find bed contour for id {_id}")

        if self.align:
            recon, mu_map = align_images(recon, mu_map)

        # convert to float32 tensors with a leading channel dimension
        mu_map = mu_map.astype(np.float32)
        mu_map = torch.from_numpy(mu_map)
        mu_map = mu_map.unsqueeze(dim=0)

        recon = recon.astype(np.float32)
        recon = torch.from_numpy(recon)
        recon = recon.unsqueeze(dim=0)

        recon, mu_map = self.transform_normalization(recon, mu_map)

        self.mu_maps[_id] = mu_map
        self.reconstructions[_id] = recon

    def pre_load_images(self):
        """
        Load all images of the table into memory.
        """
        for _id in self.table[headers.id]:
            self.load_image(_id)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a reconstruction mu map pair by index.

        This method retrieves the image id of this index and calls `get_item_by_id`.

        Parameters
        ----------
        index: int
            positional index into the dataset table

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            a pair of a reconstruction and the according mu map
        """
        row = self.table.iloc[index]
        _id = row[headers.id]
        return self.get_item_by_id(_id)

    def get_item_by_id(self, _id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a reconstruction and mu map pair by their id.

        This method loads the images if they are not yet in memory and applies the
        augmentation transform before returning them.

        Parameters
        ----------
        _id: int
            the id of the image pair

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            a pair of a reconstruction and the according mu map
        """
        if _id not in self.reconstructions:
            self.load_image(_id)

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        # augmentation is applied on every access, the cached tensors stay untouched
        recon, mu_map = self.transform_augmentation(recon, mu_map)

        return recon, mu_map

    def __len__(self) -> int:
        """
        Get the number of elements in this dataset.
        """
        return len(self.table)


# public API of this module
__all__ = ["MuMapDataset"]


def main(dataset: MuMapDataset, ids: Optional[List[int]] = None, paused: bool = False):
    """
    Display reconstructions and mu maps in a dataset.

    For each element, three panels are shown side by side: the reconstruction,
    an overlay of the mu map with the color-mapped reconstruction, and the mu map.
    Slices advance automatically (about 15 per second) or, when paused, after
    every key press.

    Keys: `n` next element, `q` quit, `p` toggle pause, `s` save the current
    frame as `NNN.png` in the working directory, up/down arrow keep the
    reconstruction slice and step the mu map forward/backward, right/left
    arrow keep the mu map slice and step the reconstruction forward/backward.

    Parameters
    ----------
    dataset: MuMapDataset
        the dataset of which elements are displayed
    ids: list of int, optional
        only display these ids
    paused: bool
        start display in paused mode
    """
    import sys

    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    timeout_paused = 0
    timeout_running = 1000 // 15

    timeout = timeout_paused if paused else timeout_running

    # (reconstruction step, mu map step) applied after a key press; any other
    # key (or no key) advances both images by one slice.
    # NOTE: arrow key codes as used originally; they may be backend-specific.
    slice_steps = {
        82: (0, 1),  # up arrow
        83: (1, 0),  # right arrow
        81: (-1, 0),  # left arrow
        84: (0, -1),  # down arrow
    }

    def to_display_image(image, _slice):
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # The interpolation flag must be passed by keyword: the third positional
        # parameter of cv.resize is the output array `dst`.
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])

        # replicate the gray channel so the panels can be stacked with the color overlay
        image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
        image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))

        image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
        image_3_1 = image_2.copy()
        image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)

        space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
        return np.hstack((image_1, space, image_3, space, image_2))

    ids_to_show = set(ids) if ids is not None else None
    # shared across all elements so that saved frames do not overwrite each other
    n_saved = 0

    for i in range(len(dataset)):
        row = dataset.table.iloc[i]
        _id = row[headers.id]

        if ids_to_show is not None and _id not in ids_to_show:
            continue

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()
        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - ID: {_id}", end="\r")

        ir, im = 0, 0
        while True:
            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)
            # every key press is handled, including the one on the first frame
            key = cv.waitKey(timeout)

            if key == ord("n"):
                break
            elif key == ord("q"):
                sys.exit(0)
            elif key == ord("p"):
                timeout = timeout_paused if timeout > 0 else timeout_running
            elif key == ord("s"):
                cv.imwrite(f"{n_saved:03d}.png", to_show)
                n_saved += 1

            # Python's modulo keeps indices in range and wraps in both directions
            step_r, step_m = slice_steps.get(key, (1, 1))
            ir = (ir + step_r) % recon.shape[0]
            im = (im + step_m) % mu_map.shape[0]


if __name__ == "__main__":
    import argparse

    from mu_map.dataset.transform import PadCropTranform
    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    # switches that disable individual pre-processing steps
    for flag, flag_help in (
        (
            "--unaligned",
            "do not perform center alignment of reconstruction an mu-map slices",
        ),
        ("--show_bed", "do not remove the bed contour from the mu map"),
        ("--full_mu_map", "do not remove broken slices of the mu map"),
    ):
        parser.add_argument(flag, action="store_true", help=flag_help)
    parser.add_argument("--ids", type=int, nargs="*", help="only display certain ids")
    parser.add_argument("--paused", action="store_true", help="start in paused mode")
    parser.add_argument("--pad_crop", type=int, help="pad crop images to this size")
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    # module-level name `logger` is kept on purpose
    logger = get_logger_by_args(args)

    if args.pad_crop:
        transform_normalization = PadCropTranform(dim=3, size=args.pad_crop)
    else:
        transform_normalization = Transform()

    dataset = MuMapDataset(
        args.dataset_dir,
        align=not args.unaligned,
        discard_mu_map_slices=not args.full_mu_map,
        bed_contours_file=None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME,
        split_name=args.split,
        transform_normalization=transform_normalization,
        logger=logger,
    )
    main(dataset, args.ids, paused=args.paused)