import os
from typing import Optional

import pandas as pd
import pydicom
import numpy as np
from torch.utils.data import Dataset

from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices


def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another on the first axis (z-axis).
    It is assumed that the first image has at least as many slices as the second.
    The first image is then shortened so that the centers of both images lie on top of each other.

    :param image_1: the image to be aligned
    :param image_2: the image to which image_1 is aligned
    :return: the aligned image_1, a view with exactly image_2.shape[0] slices
    :raises AssertionError: if image_1 has fewer slices than image_2
    """
    assert image_1.shape[0] >= image_2.shape[0], (
        f"Alignment is based on the fact that image 1 has at least as many slices "
        f"{image_1.shape[0]} as image_2 {image_2.shape[0]}"
    )

    # central slice of image 2
    c_2 = image_2.shape[0] // 2
    # number of slices to the left and right of the center (right includes the center)
    left = c_2
    right = image_2.shape[0] - c_2

    # central slice of image 1
    c_1 = image_1.shape[0] // 2
    # select center and same amount to the left/right as image_2;
    # c_1 >= left holds because image_1 has at least as many slices as image_2
    return image_1[(c_1 - left) : (c_1 + right)]


class MuMapDataset(Dataset):
    """
    Dataset of (reconstruction, mu map) pairs described by a CSV file.

    Each CSV row references two DICOM files in the images directory:
    a reconstruction (column ``file_recon_no_ac``) and a mu map (column ``file_mu_map``).
    Items are returned as pixel arrays, with the reconstruction cropped along the
    first axis by ``align_images`` so that its center matches that of the mu map.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
    ):
        """
        :param dataset_dir: root directory of the dataset
        :param csv_file: name of the CSV file (relative to dataset_dir) listing the samples;
            it must contain an ``id`` column
        :param images_dir: name of the directory (relative to dataset_dir) containing the DICOM files
        :param bed_contours_file: name of the file (relative to dataset_dir) loaded with
            ``load_contours``; if falsy, no bed removal is performed
        :param discard_mu_map_slices: if set, mu map slices are filtered with ``discard_slices``
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        # presumably a mapping from sample id to a bed contour - verify against load_contours
        self.bed_contours = (
            load_contours(self.bed_contours_file) if self.bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        # ids are used as keys into self.bed_contours, so normalize them to int
        self.table["id"] = self.table["id"].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices

    def __getitem__(self, index: int):
        """
        Load the sample at position ``index`` of the CSV table.

        :param index: positional index into the table (passed to ``DataFrame.iloc``)
        :return: tuple (recon, mu_map) of pixel arrays; recon is cropped to the
            number of slices of mu_map via ``align_images``
        """
        row = self.table.iloc[index]

        recon_file = os.path.join(self.dir_images, row["file_recon_no_ac"])
        mu_map_file = os.path.join(self.dir_images, row["file_mu_map"])

        recon = pydicom.dcmread(recon_file).pixel_array
        mu_map = pydicom.dcmread(mu_map_file).pixel_array

        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)

        if self.bed_contours:
            # Imported here because cv2 was previously only imported by the
            # __main__ script, which made this branch raise NameError when the
            # module was used as a library. Importing lazily keeps cv2 optional
            # for users who disable bed removal.
            import cv2 as cv

            bed_contour = self.bed_contours[row["id"]]
            for i in range(mu_map.shape[0]):
                # thickness -1 fills the contour, i.e. the bed region is set to 0;
                # the result is assigned back in case OpenCV returns a copy
                mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)

        recon = align_images(recon, mu_map)

        return recon, mu_map

    def __len__(self):
        """Return the number of samples (rows in the CSV table)."""
        return len(self.table)


__all__ = [MuMapDataset.__name__]

if __name__ == "__main__":
    import sys

    dataset = MuMapDataset("data/tmp")

    import cv2 as cv

    wname = "Images"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1024, 512)

    def to_grayscale(
        img: np.ndarray,
        min_val: Optional[float] = None,
        max_val: Optional[float] = None,
    ) -> np.ndarray:
        """
        Linearly map an image to uint8 grayscale in [0, 255].

        :param img: the image to convert
        :param min_val: value mapped to 0 (defaults to the minimum of img)
        :param max_val: value mapped to 255 (defaults to the maximum of img)
        :return: the converted image as uint8
        """
        if min_val is None:
            min_val = img.min()
        if max_val is None:
            max_val = img.max()

        # compute in float so unsigned pixel types (e.g. uint16 DICOM data) cannot wrap around
        value_range = float(max_val) - float(min_val)
        if value_range == 0:
            # constant image: avoid a division by zero that would produce NaNs
            return np.zeros(img.shape, np.uint8)

        _img = (img.astype(np.float64) - float(min_val)) / value_range
        _img = np.clip(_img, 0.0, 1.0)
        return (_img * 255).astype(np.uint8)

    for i in range(len(dataset)):
        recon, mu_map = dataset[i]
        print(f"{i+1}/{len(dataset)} - {recon.shape} - {mu_map.shape}")

        # intensity ranges are computed once per volume so all slices share one scaling
        recon_range = (recon.min(), recon.max())
        mu_map_range = (mu_map.min(), mu_map.max())
        # separator between both images; its height follows the displayed slices
        # (assumes recon and mu_map slices have the same height, as hstack requires)
        space = np.full((recon.shape[1], 10), 239, np.uint8)

        ir = 0
        im = 0
        while True:
            to_show = np.hstack(
                (
                    to_grayscale(recon[ir], *recon_range),
                    space,
                    to_grayscale(mu_map[im], *mu_map_range),
                )
            )
            cv.imshow(wname, to_show)

            # mask to the low byte: some platforms set extra bits in the key code
            key = cv.waitKey(100) & 0xFF
            if key == ord("n"):
                break
            if key == ord("q"):
                cv.destroyAllWindows()
                sys.exit(0)

            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

    cv.destroyAllWindows()