import math
import random
from typing import Optional

import numpy as np
from tomolab.Reconstruction.SPECT import SPECT_Static_Scan

from mu_map.dataset.util import align_images

# 8-bit color triples. All three channel values are equal, so each triple reads the
# same in RGB and BGR channel order.
COLOR_BLACK = (0, 0, 0)
COLOR_WHITE = (255, 255, 255)


def to_grayscale(
    img: np.ndarray, min_val: Optional[float] = None, max_val: Optional[float] = None
) -> np.ndarray:
    """
    Convert an arbitrary image to a grayscale image with a value range of [0, 255].

    Values are mapped linearly so that ``min_val`` becomes 0 and ``max_val`` becomes
    255. Scaled values are truncated (not rounded) when cast to ``np.uint8``. Values
    outside ``[min_val, max_val]`` are clipped to 0 or 255 instead of wrapping around.

    :param img: the image to be converted to grayscale
    :param min_val: minimum value used for normalization (defaults to ``img.min()``), this is helpful if the image has to be normalized relative to others
    :param max_val: maximum value used for normalization (defaults to ``img.max()``), this is helpful if the image has to be normalized relative to others
    :return: the image in grayscale as an ``np.uint8`` array with the shape of ``img``;
        all zeros if ``max_val == min_val``
    """
    if min_val is None:
        min_val = img.min()

    if max_val is None:
        max_val = img.max()

    # Compute in float so that unsigned integer inputs/bounds cannot underflow.
    value_range = float(max_val) - float(min_val)
    if value_range == 0:
        return np.zeros(img.shape, np.uint8)

    _img = (np.asarray(img, dtype=np.float64) - float(min_val)) / value_range
    # Clip before the cast: casting out-of-range floats to uint8 wraps around (or is
    # undefined), which would turn values below min_val into bright pixels.
    _img = np.clip(_img, 0.0, 1.0)
    return (_img * 255).astype(np.uint8)


def grayscale_to_rgb(img: np.ndarray) -> np.ndarray:
    """
    Convert a grayscale image to an rgb image by repeating it three times.

    :param img: the 2D grayscale image to be converted to rgb
    :return: an array of shape ``(*img.shape, 3)`` with the dtype of ``img`` whose
        three channels are copies of ``img``
    """
    assert img.ndim == 2, f"grayscale image must have exactly 2 dimensions, got shape {img.shape}"
    # Add a trailing channel axis and copy the image into each of the three channels.
    return np.repeat(img[..., np.newaxis], 3, axis=-1)


def reconstruct(
    recon: np.ndarray,
    mu_map: Optional[np.ndarray] = None,
    iterations: int = 10,
    use_gpu: bool = True,
    seed: int = 42
):
    """
    Forward project a reconstruction with tomolab and reconstruct it again.

    ``recon`` is forward projected with :class:`SPECT_Static_Scan`, and the projection
    is set as the measurement. It is then reconstructed with EM using random subsets.
    If ``mu_map`` is given, it is set as the attenuation map after the forward
    projection but before the reconstruction. Presumably the reconstruction then
    models attenuation while the simulated measurement does not; confirm against
    tomolab.

    :param recon: 3D volume; it is transposed with axes (2, 1, 0) and cast to float32.
        The last axis of the transposed volume (``recon.shape[0]``) is zero-padded up
        to ``recon.shape[2]``. If ``recon.shape[0] > recon.shape[2]``, the padding is
        negative and ``np.pad`` raises. NOTE(review): ``set_n_pixels(n, n)`` suggests
        ``recon.shape[1] == recon.shape[2]`` is expected; confirm.
    :param mu_map: optional attenuation map. It is transposed and padded with the
        padding computed from ``recon``, so it is expected to have the same shape.
    :param iterations: number of EM iterations passed to ``estimate_activity``
    :param use_gpu: forwarded to ``spect.set_use_gpu``
    :param seed: seed for Python's ``random`` module
    :return: the first element returned by ``align_images(activity, recon)``, where
        ``activity`` is the estimated activity transposed back to the axis order of
        ``recon``. Presumably this removes the padding again; confirm against
        ``align_images``.
    """
    # NOTE(review): only the stdlib ``random`` module is seeded. If tomolab draws its
    # random subsets from NumPy's RNG, this does not make results reproducible; confirm.
    random.seed(seed)

    spect = SPECT_Static_Scan()
    spect.set_use_gpu(use_gpu)

    # Reverse the axis order for tomolab and work in float32.
    recon_t = np.transpose(recon, (2, 1, 0)).astype(np.float32)
    n_pixels = recon_t.shape[0]

    # Zero-pad the last axis symmetrically up to n_pixels; for an odd difference the
    # extra slice goes to the lower side.
    padding = (n_pixels - recon_t.shape[2]) / 2
    padding_lower = math.ceil(padding)
    padding_upper = math.floor(padding)
    recon_t = np.pad(recon_t, [(0, 0), (0, 0), (padding_lower, padding_upper)])

    spect.set_n_pixels(n_pixels, n_pixels)
    # presumably 59 gantry positions spanning 0 to 360 degrees; confirm against tomolab
    spect.set_gantry_angular_positions(0.0, 360.0, 59)

    # Simulate the acquisition by forward projecting the padded volume.
    # NOTE(review): attenuation, pixel size, radius and PSF are configured only below,
    # after this projection. The projection therefore uses whatever settings are in
    # effect here (presumably tomolab's defaults). The reconstruction uses the values
    # set below. Confirm this asymmetry is intended.
    measurement = spect.project(recon_t)
    spect.set_measurement(measurement.data)

    if mu_map is not None:
        # Transform the attenuation map exactly like recon (same axes and padding).
        mu_map_t = np.transpose(mu_map, (2, 1, 0)).astype(np.float32)
        mu_map_t = np.pad(mu_map_t, [(0, 0), (0, 0), (padding_lower, padding_upper)])
        spect.set_attenuation(mu_map_t)

    spect.set_pixel_size(4.8, 4.8)  # presumably in mm; confirm
    spect.set_radius(200.0)  # presumably in mm; confirm
    spect.set_psf(fwhm0_mm=5.0, depth_dependence=0.0001)

    # EM reconstruction with randomly chosen subsets (subset_size=16).
    activity = spect.estimate_activity(
        iterations=iterations, subset_size=16, subset_mode="random", method="EM"
    )

    # Take the raw data; transposing with (2, 1, 0) again restores recon's axis order.
    activity = activity.data
    activity = np.transpose(activity, (2, 1, 0))
    activity, _ = align_images(activity, recon)
    return activity

if __name__ == "__main__":
    import argparse
    import time

    from mu_map.dataset.util import load_dcm_img

    parser = argparse.ArgumentParser(
        description="Reconstruct a SPECT image with attenuation correction from a "
        "non-attenuation-corrected reconstruction and a mu map",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--recon_nac",
        type=str,
        default="./data/second/images/0001-stress-recon_nac_nsc.dcm",
        help="DICOM file of the reconstruction without attenuation correction",
    )
    parser.add_argument(
        "--mu_map",
        type=str,
        default="./data/second/images/0001-stress-mu_map.dcm",
        help="DICOM file of the attenuation map",
    )
    parser.add_argument(
        "--show",
        action="store_true",
        help="display both reconstructions slice by slice (press 'q' to quit)",
    )
    args = parser.parse_args()

    recon_nac = load_dcm_img(args.recon_nac)
    mu_map = load_dcm_img(args.mu_map)
    recon_nac, mu_map = align_images(recon_nac, mu_map)

    since = time.perf_counter()
    # Pass the aligned mu map so that the "AC" result is actually attenuation corrected.
    recon_ac = reconstruct(recon_nac, mu_map=mu_map)
    took = time.perf_counter() - since
    print(f"Reconstruction took {took:.3f}s")

    print("            |   Max |  Mean |   Min")
    print(f"  Recon nAC | {recon_nac.max():.3f} | {recon_nac.mean():.3f} | {recon_nac.min():.3f}")
    print(f"  Recon  AC | {recon_ac.max():.3f} | {recon_ac.mean():.3f} | {recon_ac.min():.3f}")

    # The interactive viewer is opt-in so that the script also runs headless.
    if args.show:
        import cv2 as cv

        def display(recon_nac, recon_ac, _slice):
            """Render one slice of both volumes side by side, separated by a light gray bar."""
            img1 = to_grayscale(recon_nac[_slice], max_val=recon_nac.max())
            img1 = cv.resize(img1, (1024, 1024))

            img2 = to_grayscale(recon_ac[_slice], max_val=recon_ac.max())
            img2 = cv.resize(img2, (1024, 1024))

            s = np.full((1024, 10), 239, np.uint8)
            return np.hstack((img1, s, img2))

        i = 0
        wname = "Test"
        cv.namedWindow(wname, cv.WINDOW_NORMAL)
        cv.resizeWindow(wname, 1600, 900)
        while True:
            cv.imshow(wname, display(recon_nac, recon_ac, i))

            i = (i + 1) % recon_nac.shape[0]

            # Advance one slice roughly every 100 ms; quit with "q".
            key = cv.waitKey(100)
            if key == ord("q"):
                break
        cv.destroyAllWindows()