import os
from typing import Optional

import pandas as pd
import pydicom
import numpy as np
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices


def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another on the first axis (z-axis).
    It is assumed that the second image does not have more slices than the first.
    Then, the first image is shortened in a way that the centers of both images lie on top of each other.
    If both images have the same number of slices, image_1 is returned unchanged.

    :param image_1: the image to be aligned
    :param image_2: the image to which image_1 is aligned
    :return: the aligned image_1 (a slice of image_1 with as many slices as image_2)
    :raises AssertionError: if image_1 has less slices than image_2
    """
    slices_1 = image_1.shape[0]
    slices_2 = image_2.shape[0]
    assert (
        slices_1 >= slices_2
    ), f"Alignment is based on the fact that image 1 has at least as many slices {slices_1} as image_2 {slices_2}"

    # central slice of image 2
    c_2 = slices_2 // 2
    # number of slices to the left and right of the center (right includes the center itself)
    left = c_2
    right = slices_2 - c_2

    # central slice of image 1
    c_1 = slices_1 // 2
    # select center and same amount to the left/right as image_2
    return image_1[(c_1 - left) : (c_1 + right)]


class MuMapDataset(Dataset):
    """
    Dataset of pairs of a reconstruction and a mu map which are read from DICOM files.

    The files are listed in a CSV table. All images are read and pre-processed once
    when the dataset is created and are kept in memory afterwards, so that
    accessing an item is only a dictionary lookup.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
    ):
        """
        Create the dataset and pre-load all of its images.

        :param dataset_dir: root directory of the dataset
        :param csv_file: name of the CSV table (relative to dataset_dir), it needs to contain
                         an `id` column as well as the columns `headers.file_mu_map` and
                         `headers.file_recon_nac_nsc`
        :param images_dir: name of the directory (relative to dataset_dir) containing the DICOM files
        :param bed_contours_file: name of the file (relative to dataset_dir) from which bed contours
                                  are loaded, if None (or empty) the bed is not removed from mu maps
        :param discard_mu_map_slices: if True, `discard_slices` is applied to every mu map
        :param align: if True, every reconstruction is center-aligned to its mu map on the first axis
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        self.table["id"] = self.table["id"].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align

        # caches of the pre-processed images, both are keyed by the `id` of a table row
        self.reconstructions = {}
        self.mu_maps = {}
        self.pre_load_images()

    @staticmethod
    def _remove_bed(mu_map: np.ndarray, bed_contour) -> np.ndarray:
        """
        Erase the bed from a mu map by filling its contour with 0 in every slice.

        :param mu_map: the mu map which is modified slice by slice
        :param bed_contour: contour of the bed as accepted by `cv.drawContours`
        :return: the mu map without the bed
        """
        # OpenCV is imported here because at module level it is only available
        # when this file is executed as a script.
        import cv2 as cv

        for _slice in range(mu_map.shape[0]):
            # contour index -1 draws all given contours, thickness -1 fills them
            mu_map[_slice] = cv.drawContours(
                mu_map[_slice], [bed_contour], -1, 0.0, -1
            )
        return mu_map

    def pre_load_images(self):
        """
        Read and pre-process the mu map and reconstruction of every table row.

        The results are stored in `self.mu_maps` and `self.reconstructions`.
        """
        print("Pre-loading images ...", end="\r")
        for index in range(len(self.table)):
            # rows are accessed the same way as in __getitem__ so that ids match as keys
            row = self.table.iloc[index]
            _id = row["id"]

            mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
            mu_map = pydicom.dcmread(mu_map_file).pixel_array
            if self.discard_mu_map_slices:
                mu_map = discard_slices(row, mu_map)
            if self.bed_contours:
                mu_map = self._remove_bed(mu_map, self.bed_contours[_id])
            self.mu_maps[_id] = mu_map

            recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
            recon = pydicom.dcmread(recon_file).pixel_array
            if self.align:
                # the mu map is aligned against after slices were discarded from it
                recon = align_images(recon, mu_map)
            self.reconstructions[_id] = recon
        print("Pre-loading images done!")

    def __getitem__(self, index: int):
        """
        Get the pre-loaded images of a table row.

        :param index: positional index of the row in the table
        :return: a tuple of the reconstruction and the mu map
        """
        _id = self.table.iloc[index]["id"]
        return self.reconstructions[_id], self.mu_maps[_id]

    def __len__(self):
        """
        :return: the number of rows in the table
        """
        return len(self.table)


__all__ = [MuMapDataset.__name__]

if __name__ == "__main__":
    # Interactive viewer which plays the slices of every reconstruction/mu map
    # pair of a dataset side by side.
    #
    # Keys:
    #   n           - go to the next pair of images
    #   q           - quit
    #   p           - toggle between playing and pausing (while paused, any key steps)
    #   right arrow - step one slice forward
    #   left arrow  - step one slice backward
    import argparse
    import sys

    import cv2 as cv

    from mu_map.util import to_grayscale, COLOR_WHITE

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction an mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    args = parser.parse_args()

    # all flags disable a pre-processing step which is enabled by default
    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
    )

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    # time in ms a slice is displayed, 0 makes cv.waitKey block until a key is pressed
    timeout = 100

    def to_display_image(image, _slice):
        """
        Render a slice of an image for display.

        :param image: the image from which a slice is displayed
        :param _slice: index of the slice on the first axis
        :return: the slice as a 1024x1024 grayscale image labeled with its index
        """
        # normalize with the range of the whole image so that brightness is comparable between slices
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # interpolation has to be a keyword argument, the third positional parameter is `dst`
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        """
        Render slices of two images side by side.

        :param images: pair of images
        :param slices: pair of slice indices, one for each image
        :return: both rendered slices stacked horizontally and separated by a light gray bar
        """
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])
        space = np.full((image_1.shape[0], 10), 239, np.uint8)
        return np.hstack((image_1, space, image_2))

    for i in range(len(dataset)):
        ir = 0
        im = 0

        recon, mu_map = dataset[i]
        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")

        while True:
            cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
            key = cv.waitKey(timeout)

            # slices advance by default, also on a timeout or an unmapped key (incl. right arrow: 83)
            step = 1
            if key == ord("n"):
                break
            elif key == ord("q"):
                sys.exit(0)
            elif key == ord("p"):
                timeout = 0 if timeout > 0 else 100
            elif key == 81:  # left arrow key
                step = -1

            # wrap around in both directions so that every slice can be reached
            ir = (ir + step) % recon.shape[0]
            im = (im + step) % mu_map.shape[0]