import os
from typing import Optional, Tuple

import cv2 as cv
import pandas as pd
import pydicom
import numpy as np
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import (
    DEFAULT_BED_CONTOURS_FILENAME,
    load_contours,
    remove_bed,
)
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images, load_dcm_img
from mu_map.logging import get_logger


class MuMapDataset(Dataset):
    """Dataset of paired reconstructions and mu maps loaded from DICOM files.

    The items are described by a CSV meta file inside ``dataset_dir``. Each row
    references a mu map file and a reconstruction file, relative to
    ``images_dir``. When an id is requested for the first time, its image pair
    is loaded, preprocessed, passed through ``transform_normalization`` and
    cached in memory. The cache is per instance and is never evicted.
    ``transform_augmentation`` is applied on every access, on top of the
    cached tensors.

    :param dataset_dir: directory containing the CSV, split and contour files
    :param csv_file: name of the meta CSV file inside ``dataset_dir``
    :param split_file: name of the split definition file inside ``dataset_dir``
    :param split_name: if given, restrict the table to this entry of the
        mapping returned by ``split_csv``
    :param images_dir: sub-directory of ``dataset_dir`` containing the images
    :param bed_contours_file: contour file (inside ``dataset_dir``) used by
        ``remove_bed`` on the mu maps; a falsy value disables bed removal
    :param discard_mu_map_slices: apply ``discard_slices`` to each mu map
    :param align: apply ``align_images`` to each reconstruction/mu map pair
    :param scatter_correction: read reconstructions from the
        ``headers.file_recon_nac_sc`` column instead of
        ``headers.file_recon_nac_nsc``
    :param transform_normalization: applied once per loaded pair before
        caching; defaults to a fresh ``Transform()``
    :param transform_augmentation: applied on every item access; defaults to
        a fresh ``Transform()``
    :param logger: logger for warnings; defaults to ``get_logger()``
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Optional[Transform] = None,
        transform_augmentation: Optional[Transform] = None,
        logger=None,
    ):
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        # Build the default transforms per instance. A ``Transform()`` default
        # argument would be created once at import time and shared by all
        # datasets.
        self.transform_normalization = (
            transform_normalization
            if transform_normalization is not None
            else Transform()
        )
        self.transform_augmentation = (
            transform_augmentation
            if transform_augmentation is not None
            else Transform()
        )
        self.logger = logger if logger is not None else get_logger()

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if self.bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        # Ids are used as cache keys and for bed-contour lookup, so normalize
        # them to integers.
        self.table[headers.id] = self.table[headers.id].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align
        self.scatter_correction = scatter_correction
        # Column of the meta table that holds the reconstruction file name.
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        # In-memory caches: id -> normalized tensor.
        self.reconstructions = {}
        self.mu_maps = {}

    def load_image(self, _id: int):
        """Load, preprocess, normalize and cache the image pair of ``_id``.

        The mu map optionally has slices discarded and the bed removed. The
        pair is optionally aligned. Both images are converted to ``float32``
        tensors with an extra leading dimension (``unsqueeze(dim=0)``), passed
        through ``transform_normalization`` and stored in
        ``self.reconstructions`` / ``self.mu_maps``.

        :param _id: id of the row in the meta table
        :raises IndexError: if no row of the table has this id
        """
        row = self.table[self.table[headers.id] == _id].iloc[0]
        _id = row[headers.id]

        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
        mu_map = load_dcm_img(mu_map_file)
        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)
        if self.bed_contours:
            if _id in self.bed_contours:
                bed_contour = self.bed_contours[_id]
                mu_map = remove_bed(mu_map, bed_contour)
            else:
                # Use the instance logger: a bare ``logger`` is undefined here
                # unless a global of that name happens to exist.
                self.logger.warning(f"Could not find bed contour for id {_id}")

        recon_file = os.path.join(self.dir_images, row[self.header_recon])
        recon = load_dcm_img(recon_file)
        if self.align:
            recon, mu_map = align_images(recon, mu_map)

        mu_map = mu_map.astype(np.float32)
        mu_map = torch.from_numpy(mu_map)
        mu_map = mu_map.unsqueeze(dim=0)

        recon = recon.astype(np.float32)
        recon = torch.from_numpy(recon)
        recon = recon.unsqueeze(dim=0)

        recon, mu_map = self.transform_normalization(recon, mu_map)

        self.mu_maps[_id] = mu_map
        self.reconstructions[_id] = recon

    def __getitem__(self, index: int):
        """Return ``(recon, mu_map)`` for the row at position ``index``."""
        row = self.table.iloc[index]
        _id = row[headers.id]
        return self.getitem_by_id(_id)

    def getitem_by_id(self, _id: int):
        """Return the augmented ``(recon, mu_map)`` pair of the given id.

        The pair is loaded and cached on first access. The augmentation
        transform is applied to the cached tensors on every call.
        """
        if _id not in self.reconstructions:
            self.load_image(_id)

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon, mu_map = self.transform_augmentation(recon, mu_map)

        return recon, mu_map

    def __len__(self):
        """Return the number of rows in the (possibly split) meta table."""
        return len(self.table)


__all__ = [MuMapDataset.__name__]


def main(dataset, ids, paused=False):
    """Interactively browse the image pairs of a dataset in an OpenCV window.

    For every item, three panels are shown side by side:
    - the reconstruction;
    - a blend of the mu map (weight 0.8) with the inferno colour-mapped
      reconstruction (weight 0.4);
    - the mu map.
    The slices are cycled automatically in running mode, or once per key
    press in paused mode.

    Keys:
        n      -- continue with the next item
        q      -- quit the program
        p      -- toggle between paused and running mode
        s      -- save the current view as ``NNN.png`` in the working directory
        up     -- advance only the mu map slice
        down   -- step the mu map slice back
        right  -- advance only the reconstruction slice
        left   -- step the reconstruction slice back
        any other key (or timeout) advances both slices

    :param dataset: dataset whose items are ``(recon, mu_map)`` tensors and
        which exposes its meta ``table``
    :param ids: optional collection of ids; if not None, only these are shown
    :param paused: start in paused mode
    """
    import sys

    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    TIMEOUT_PAUSED = 0
    TIMEOUT_RUNNING = 1000 // 15

    # Arrow key codes as returned by cv.waitKey in the original setup.
    # NOTE(review): these codes are HighGUI-backend dependent.
    KEY_LEFT, KEY_UP, KEY_RIGHT, KEY_DOWN = 81, 82, 83, 84
    # (recon step, mu map step) per key; any other key advances both by one.
    slice_steps = {
        KEY_UP: (0, 1),
        KEY_RIGHT: (1, 0),
        KEY_LEFT: (-1, 0),
        KEY_DOWN: (0, -1),
    }

    timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING

    def to_display_image(image, _slice):
        # Normalize over the whole volume so brightness is stable across slices.
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # The third positional argument of cv.resize is ``dst``, so the
        # interpolation flag must be passed by keyword.
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        n_slices = image.shape[0]
        _text = f"{str(_slice):>{len(str(n_slices))}}/{str(n_slices)}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])

        # Grayscale -> 3 channels so that the panels can be stacked with the
        # colour-mapped overlay.
        image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
        image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))

        image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
        image_3_1 = image_2.copy()
        image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)

        space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
        return np.hstack((image_1, space, image_3, space, image_2))

    wanted_ids = set(ids) if ids is not None else None
    n_items = len(dataset)
    # Shared across items so that screenshots of later items do not
    # overwrite earlier ones.
    screenshot_count = 0

    for i in range(n_items):
        row = dataset.table.iloc[i]
        _id = row[headers.id]

        if wanted_ids is not None and _id not in wanted_ids:
            continue

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()
        print(f"{(i+1):>{len(str(n_items))}}/{n_items} - ID: {_id}", end="\r")

        ir, im = 0, 0
        while True:
            # Show first, then react: this way the key pressed while the
            # first slice is displayed is not lost.
            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)
            key = cv.waitKey(timeout)

            if key == ord("n"):
                break
            if key == ord("q"):
                cv.destroyAllWindows()
                sys.exit(0)
            if key == ord("p"):
                timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
            elif key == ord("s"):
                cv.imwrite(f"{screenshot_count:03d}.png", to_show)
                screenshot_count += 1

            step_recon, step_mu_map = slice_steps.get(key, (1, 1))
            # Modulo wraps in both directions, so slice 0 is reachable backwards.
            ir = (ir + step_recon) % recon.shape[0]
            im = (im + step_mu_map) % mu_map.shape[0]


if __name__ == "__main__":
    # Command line entry point: build a MuMapDataset from the given options and
    # browse it interactively with ``main``.
    import argparse

    from mu_map.dataset.transform import PadCropTranform
    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    parser.add_argument(
        "--scatter_correction",
        action="store_true",
        help="show the scatter corrected reconstructions",
    )
    parser.add_argument(
        "--ids",
        type=int,
        nargs="*",
        help="only display certain ids",
    )
    parser.add_argument(
        "--paused",
        action="store_true",
        help="start in paused mode",
    )
    parser.add_argument(
        "--pad_crop",
        type=int,
        help="pad crop images to this size",
    )
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
    logger = get_logger_by_args(args)

    transform_normalization = (
        PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
    )

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        split_name=args.split,
        scatter_correction=args.scatter_correction,
        transform_normalization=transform_normalization,
        logger=logger,
    )
    main(dataset, args.ids, paused=args.paused)