import os
from typing import Optional

import cv2 as cv
import pandas as pd
import pydicom
import numpy as np
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.logging import get_logger


# DICOM images can only hold pixel values in a short integer format. To keep
# precision, the Siemens scanner software multiplies the values by a factor
# before storing them; that factor is stored in this private DICOM tag
# (group 0x0033, element 0x1038). Divide the stored pixel data by it to
# recover the original values.
DCM_TAG_PIXEL_SCALE_FACTOR = 0x0033_1038


def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another on the first axis (z-axis).

    ``image_1`` is cropped along its first axis to the number of slices of
    ``image_2`` so that the centers of both images lie on top of each other.
    If both images have the same number of slices, all of ``image_1`` is
    returned.

    :param image_1: the image to be aligned; must have at least as many slices as image_2
    :param image_2: the image to which image_1 is aligned (only its shape is used)
    :return: the aligned image_1 (a slice of it) with ``image_2.shape[0]`` slices
    :raises ValueError: if image_1 has fewer slices than image_2
    """
    n_1 = image_1.shape[0]
    n_2 = image_2.shape[0]
    # Explicit check instead of `assert`: asserts are stripped under `python -O`
    # and a negative start index would then silently produce a wrong crop.
    if n_1 < n_2:
        raise ValueError(
            f"Alignment requires image_1 to have at least as many slices ({n_1}) as image_2 ({n_2})"
        )

    # central slice of image 2 and the number of slices to its left/right
    c_2 = n_2 // 2
    left = c_2
    right = n_2 - c_2

    # central slice of image 1
    c_1 = n_1 // 2
    # select center and same amount to the left/right as image_2
    return image_1[(c_1 - left) : (c_1 + right)]


class MuMapDataset(Dataset):
    """
    Dataset of pairs of reconstructions and mu-maps read from DICOM files.

    All image pairs listed in the meta CSV file (optionally restricted to one
    split) are loaded, pre-processed and normalized once during construction
    and kept in memory. ``__getitem__`` only applies the augmentation transform.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        transform_normalization: Optional[Transform] = None,
        transform_augmentation: Optional[Transform] = None,
        logger=None,
    ):
        """
        :param dataset_dir: root directory of the dataset
        :param csv_file: meta CSV file, relative to dataset_dir
        :param split_file: split CSV file, relative to dataset_dir
        :param split_name: name of the split the dataset is restricted to; all rows are used if not set
        :param images_dir: directory containing the DICOM images, relative to dataset_dir
        :param bed_contours_file: bed contours file, relative to dataset_dir; if not set, mu-maps are not modified by bed contours
        :param discard_mu_map_slices: apply `discard_slices` to each mu-map
        :param align: crop each reconstruction to the mu-map's slices with `align_images`
        :param transform_normalization: applied once to each (recon, mu_map) pair while pre-loading; defaults to a new `Transform()`
        :param transform_augmentation: applied to each (recon, mu_map) pair in `__getitem__`; defaults to a new `Transform()`
        :param logger: logger to use; defaults to `get_logger()`
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        # Create default transforms per instance: a `Transform()` default
        # argument would be evaluated once and shared by all datasets.
        self.transform_normalization = (
            Transform() if transform_normalization is None else transform_normalization
        )
        self.transform_augmentation = (
            Transform() if transform_augmentation is None else transform_augmentation
        )
        self.logger = logger if logger is not None else get_logger()

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table["id"] = self.table["id"].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align

        # pre-loaded tensors, keyed by the `id` column of the table
        self.reconstructions = {}
        self.mu_maps = {}
        self.pre_load_images()

    def _load_dicom(self, filename: str) -> np.ndarray:
        """Read a DICOM file from the images directory and divide its pixels by the stored scale factor."""
        dcm = pydicom.dcmread(os.path.join(self.dir_images, filename))
        return dcm.pixel_array / dcm[DCM_TAG_PIXEL_SCALE_FACTOR].value

    @staticmethod
    def _remove_bed(mu_map: np.ndarray, bed_contour) -> np.ndarray:
        """Set the area inside ``bed_contour`` to 0 in every slice of ``mu_map``."""
        for _slice in range(mu_map.shape[0]):
            # contourIdx -1: draw all given contours; thickness -1: fill the interior
            mu_map[_slice] = cv.drawContours(
                mu_map[_slice], [bed_contour], -1, 0.0, -1
            )
        return mu_map

    @staticmethod
    def _to_tensor(image: np.ndarray) -> torch.Tensor:
        """Convert an image to a float32 tensor with a new leading dimension of size 1."""
        return torch.from_numpy(image.astype(np.float32)).unsqueeze(dim=0)

    def pre_load_images(self):
        """
        Load, pre-process and normalize all image pairs of the table.

        Results are stored in ``self.mu_maps`` and ``self.reconstructions``,
        keyed by the ``id`` column of the table.
        """
        self.logger.debug("Pre-loading images ...")
        for _, row in self.table.iterrows():
            _id = row["id"]

            mu_map = self._load_dicom(row[headers.file_mu_map])
            if self.discard_mu_map_slices:
                mu_map = discard_slices(row, mu_map)
            if self.bed_contours:
                mu_map = self._remove_bed(mu_map, self.bed_contours[_id])

            recon = self._load_dicom(row[headers.file_recon_nac_nsc])
            if self.align:
                # crop the reconstruction to the (possibly reduced) mu-map slices
                recon = align_images(recon, mu_map)

            recon, mu_map = self.transform_normalization(
                self._to_tensor(recon), self._to_tensor(mu_map)
            )

            self.mu_maps[_id] = mu_map
            self.reconstructions[_id] = recon
        self.logger.debug("Pre-loading images done!")

    def __getitem__(self, index: int):
        """
        Get the pre-loaded pair at position ``index`` of the table, with augmentation applied.

        :param index: positional index into the table
        :return: tuple (recon, mu_map) as returned by the augmentation transform
        """
        _id = self.table.iloc[index]["id"]
        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]
        return self.transform_augmentation(recon, mu_map)

    def __len__(self):
        """Number of rows (image pairs) in the table."""
        return len(self.table)


# public API of this module
__all__ = ["MuMapDataset"]

def main(dataset):
    """
    Interactively show all (recon, mu_map) pairs of a dataset in an OpenCV window.

    Both images of a pair are shown side by side and cycled slice by slice.
    Keys:
        n: next dataset item
        q: quit
        p: toggle between auto-play and waiting for a key
        left arrow: previous slice (stays at the first slice)
        right arrow / no key: next slice
        s: save the current view as a numbered PNG in the working directory

    :param dataset: indexable dataset whose items are (recon, mu_map) tensor pairs
    """
    import sys

    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    timeout = 100
    # Screenshot counter; kept across dataset items so that screenshots of
    # later items do not overwrite those of earlier ones.
    running = 0

    def to_display_image(image, _slice):
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # `interpolation` must be passed by keyword: the third positional
        # parameter of cv.resize is `dst`.
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])
        space = np.full((image_1.shape[0], 10), 239, np.uint8)
        return np.hstack((image_1, space, image_2))

    for i in range(len(dataset)):
        ir = 0
        im = 0

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()
        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")

        while True:
            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)
            # handle the key for every shown frame, including the first one
            key = cv.waitKey(timeout)

            if key == ord("n"):
                break
            elif key == ord("q"):
                sys.exit(0)
            elif key == ord("p"):
                timeout = 0 if timeout > 0 else 100
            elif key == ord("s"):
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1

            # NOTE(review): 81/83 are raw arrow key codes as used originally;
            # they may differ between OpenCV GUI backends.
            if key == 81:  # left arrow key: one slice back, stop at the first slice
                ir = max(ir - 1, 0)
                im = max(im - 1, 0)
            else:  # right arrow key, timeout or any other key: next slice
                ir = (ir + 1) % recon.shape[0]
                im = (im + 1) % mu_map.shape[0]

    cv.destroyWindow(wname)


if __name__ == "__main__":
    # Command line entry point: build a MuMapDataset from the given directory
    # and visualize it with `main`.
    import argparse

    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    # translate the negative CLI flags into the dataset's positive options
    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
    logger = get_logger_by_args(args)

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        split_name=args.split,
        logger=logger,
    )
    main(dataset)