import sys
from typing import Callable, Optional

import cv2 as cv
import numpy as np
import torch

torch.set_grad_enabled(False)

from mu_map.data.prepare import headers
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import (
    GaussianNormTransform,
    MeanNormTransform,
    MaxNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, SequenceTransform
from mu_map.eval.measures import nmae, mse
from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE
from mu_map.vis.slices import join_images
from mu_map.random_search.cgan import load_params


def main(
    model: torch.nn.Module,
    dataset: MuMapDataset,
    wname: str = "Prediction",
    action: Optional[Callable[[int], bool]] = None,
    _print: bool = True,
    run_args: Optional[object] = None,
):
    """
    Visualize the predictions of a model for all reconstructions
    in a dataset.

    For every sample, the NMAE and MSE between prediction and target are
    printed, and the prediction, the target attenuation map and their
    absolute difference are displayed side by side while cycling through
    the slices.

    Key bindings (key codes as returned by ``cv.waitKey``):

    - ``q``: exit the program
    - ``n``: continue with the next sample
    - ``81``: step back one slice (presumably the left arrow key; confirm
      for the OpenCV backend in use)
    - ``p``: toggle pause (no automatic slice advance)
    - ``t``: toggle the slice counter overlay
    - ``s``: save the displayed images to ``prediction.png``,
      ``mu_map.png`` and ``difference.png`` and write ``info.txt``

    Parameters
    ----------
    model: torch.nn.Module
        the model with which predictions are computed
    dataset: MuMapDataset
        the dataset containing reconstructions for which images are computed and target attenuation maps for comparison
    wname: str
        the name of the display window
    action: Callable[[int], bool], optional
        Add control behaviour by providing a callable reacting to key presses.
        It receives the key code returned by ``cv.waitKey``.
        If it returns true, the display is stopped.
    _print: bool
        if measures for predictions should be printed
    run_args: object, optional
        arguments written to ``info.txt`` when images are saved; defaults to
        the module-level ``args`` if it exists (it is set when this file is
        run as a script), otherwise ``None``
    """
    if run_args is None:
        # Only the script entry point defines a global ``args``; looking it up
        # via globals() avoids a NameError when main is imported and reused.
        run_args = globals().get("args")

    default_timeout = 100  # ms between automatically advanced slices
    key_slice_back = 81  # key code used to step back one slice

    # disable print if not wanted
    print_func = print if _print else lambda *_args, **_kwargs: None

    # print header
    print_func(" Id |     NMAE |      MSE")
    print_func("----|----------|---------")

    for i, (recon, mu_map) in enumerate(dataset):
        _id = dataset.table.iloc[i][headers.id]

        # explicit no_grad so .numpy() works even if the module-level
        # grad switch has been re-enabled by a caller
        with torch.no_grad():
            prediction = model(recon.unsqueeze(dim=0)).squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        _nmae = nmae(prediction, mu_map)
        _mse = mse(prediction, mu_map)
        print_func(f"{_id:03d} | {_nmae:.6f} | {_mse:.6f}")

        # display all volumes on the value range of the target map
        min_val = 0
        max_val = mu_map.max()
        prediction = np.clip(prediction, min_val, max_val)
        diff = np.abs(prediction - mu_map)

        volumes = [prediction, mu_map, diff]
        n_slices = mu_map.shape[0]

        _slice = 0
        timeout = default_timeout  # playback restarts for every sample
        show_text = True
        while True:
            images = [
                cv.resize(
                    to_grayscale(volume[_slice], min_val=min_val, max_val=max_val),
                    (512, 512),
                )
                for volume in volumes
            ]

            if show_text:
                txt = f"{str(_slice):{len(str(n_slices))}}/{n_slices}"
                cv.putText(images[0], txt, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, 255, 3)

            # advance now; handlers below step back to keep/revisit a slice
            _slice = (_slice + 1) % n_slices

            cv.imshow(wname, join_images(images))
            key = cv.waitKey(timeout)
            if action is not None and action(key):
                return

            if key == ord("q"):
                sys.exit(0)
            elif key == ord("n"):
                break
            elif key == key_slice_back:
                # -2 undoes the advance above and moves one slice back
                _slice = (_slice - 2) % n_slices
            elif key == ord("p"):
                timeout = 0 if timeout > 0 else default_timeout
            elif key == ord("t"):
                show_text = not show_text
                _slice = (_slice - 1) % n_slices
            elif key == ord("s"):
                # restore the index of the slice currently displayed
                _slice = (_slice - 1) % n_slices
                cv.imwrite("prediction.png", images[0])
                cv.imwrite("mu_map.png", images[1])
                cv.imwrite("difference.png", images[2])
                with open("info.txt", mode="w") as f:
                    f.write(f"Arguments: {run_args}\n")
                    f.write("\n")
                    f.write(f"Id: {_id}\n")
                    f.write(f"Slice: {_slice}\n")


if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="visualize the results of a random search run",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dir", type=str, help="directory containing the data of the random search run"
    )
    parser.add_argument(
        "--split",
        choices=["train", "validation", "test"],
        default="validation",
        help="the split of the dataset used",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/second/",
        help="directory of the dataset used for the predictions",
    )
    args = parser.parse_args()

    # the run's params.json provides the normalization and generator layout
    params = load_params(os.path.join(args.dir, "params.json"))
    dataset = MuMapDataset(
        args.dataset_dir,
        transform_normalization=SequenceTransform(
            [params["normalization"], PadCropTranform(dim=3, size=32)]
        ),
        split_name=args.split,
        scatter_correction=False,
    )

    # load the generator snapshot with the lowest validation loss
    device = torch.device("cpu")
    model = UNet(features=params["generator_features"])
    weights = torch.load(
        os.path.join(args.dir, "snapshots", "val_min_generator.pth"),
        map_location=device,
    )
    model.load_state_dict(weights)
    model = model.to(device).eval()

    wname = "Prediction"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    main(model, dataset, wname=wname)
    # release the display window once all samples have been shown
    cv.destroyAllWindows()