import os

import pandas as pd
import pydicom
import numpy as np
from torch.utils.data import Dataset


class MuMapDataset(Dataset):
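    """Dataset of paired non-attenuation-corrected (NoAC) reconstructions and
    attenuation maps (mu maps) loaded from DICOM files.

    ``dataset_dir`` is expected to contain a metadata CSV file (``meta.csv`` by
    default) with the columns ``file_recon_no_ac`` and ``file_mu_map``, and a
    sub-directory holding the referenced DICOM files.
    """
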
    def __init__(
        self, dataset_dir: str, csv_file: str = "meta.csv", images_dir: str = "images"
    ):
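        """
        :param dataset_dir: directory containing the metadata CSV and the images directory
        :param csv_file: name of the CSV file listing the DICOM file pairs
        :param images_dir: name of the sub-directory containing the DICOM files
        """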
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)

        # the CSV file references the DICOM files of each reconstruction/mu-map pair
        self.table = pd.read_csv(self.csv_file)

    def __getitem__(self, index: int):
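        """Load the DICOM pair referenced by row ``index`` of the CSV file and
        return the aligned (reconstruction, mu map) pixel arrays."""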
        row = self.table.iloc[index]

        recon_file = os.path.join(self.dir_images, row["file_recon_no_ac"])
        mu_map_file = os.path.join(self.dir_images, row["file_mu_map"])

        recon = pydicom.dcmread(recon_file).pixel_array
        mu_map = pydicom.dcmread(mu_map_file).pixel_array
        recon, mu_map = self.align(recon, mu_map)

        return recon, mu_map

    def __len__(self):
        return len(self.table)

    def align(self, recon, mu_map):
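        """Center-crop the reconstruction along the slice axis so that it has
        the same number of slices as the attenuation map."""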
        assert recon.shape[0] > mu_map.shape[0], (
            f"Alignment assumes that the NoAC reconstruction has more slices "
            f"({recon.shape[0]}) than the attenuation map ({mu_map.shape[0]})"
        )

        # the mu map spans `left` slices below and `right` slices above its central slice
        cm = mu_map.shape[0] // 2
        left = cm
        right = mu_map.shape[0] - cm

        # take the same window around the central slice of the reconstruction
        cr = recon.shape[0] // 2
        recon = recon[(cr - left):(cr + right)]
        return recon, mu_map


__all__ = [MuMapDataset.__name__]

if __name__ == "__main__":
    dataset = MuMapDataset("data/tmp")
    print(f"Images: {len(dataset)}")

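    # visual check: show paired reconstruction and mu-map slices side by side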
    import cv2 as cv

    wname = "Images"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1024, 512)
    # light-gray vertical gap drawn between the two slices (assumes 128x128 slices)
    space = np.full((128, 10), 239, np.uint8)

    def to_grayscale(img: np.ndarray, min_val=None, max_val=None):
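        """Min-max normalize ``img`` to an 8-bit grayscale image in [0, 255].
        Assumes ``max_val > min_val``."""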
        if min_val is None:
            min_val = img.min()

        if max_val is None:
            max_val = img.max()

        _img = (img - min_val) / (max_val - min_val)
        _img = (_img * 255).astype(np.uint8)
        return _img

    for i in range(len(dataset)):
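        # ir/im index the currently displayed reconstruction and mu-map slice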
        ir = 0
        im = 0

        recon, mu_map = dataset[i]
        print(f"{i+1}/{len(dataset)} - {recon.shape} - {mu_map.shape}")

        to_show = np.hstack(
            (
                to_grayscale(recon[ir], min_val=recon.min(), max_val=recon.max()),
                space,
                to_grayscale(mu_map[im], min_val=mu_map.min(), max_val=mu_map.max()),
            )
        )
        cv.imshow(wname, to_show)
        cv.waitKey(100)

        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = np.hstack(
                (
                    to_grayscale(recon[ir], min_val=recon.min(), max_val=recon.max()),
                    space,
                    to_grayscale(
                        mu_map[im], min_val=mu_map.min(), max_val=mu_map.max()
                    ),
                )
            )
            cv.imshow(wname, to_show)

            key = cv.waitKey(100)

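            # "n" advances to the next study, "q" exits the viewer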
            if key == ord("n"):
                break
            if key == ord("q"):
                exit(0)