"""
Module containing a GUI to label contours of a bed to be extracted from
attenuation maps.
Additionally, it contains utility functions to load bed contours and to
remove or add the bed.
"""
import json
from typing import Dict, List

import cv2 as cv
import numpy as np

from mu_map.dataset.util import align_images


DEFAULT_BED_CONTOURS_FILENAME = "bed_contours.json"


def load_contours(filename: str, as_ndarry: bool = True) -> Dict[int, np.ndarray]:
    """
    Load contours from a JSON file.
    The structure of the file is a dict where the key is the id of the according
    image and the value is the contour as a list of points.

    Parameters
    ----------
    filename: str
        filename of a JSON file containing contours
    as_ndarry: bool
        directly parse contours as numpy arrays

    Returns
    -------
    Dict
        a dict mapping ids to contours; if `as_ndarry` is True, the ids are
        converted to int and the contours to integer np.arrays, otherwise the
        raw JSON content is returned (ids stay strings, because JSON object
        keys are always strings, and contours stay nested lists)
    """
    with open(filename, mode="r") as f:
        contours = json.load(f)

    if not as_ndarry:
        return contours

    return {
        int(_id): np.array(contour).astype(int) for _id, contour in contours.items()
    }


def remove_bed(mu_map: np.ndarray, bed_contour: np.ndarray) -> np.ndarray:
    """
    Remove the bed defined by a contour from all slices.

    The area enclosed by the contour is filled with 0 in every slice along
    the first axis. The input array is not modified; a copy is returned.

    Parameters
    ----------
    mu_map: np.ndarray
        the mu_map from which the bed is removed (slices along axis 0)
    bed_contour: np.ndarray
        the contour describing where the bed is found

    Returns
    -------
    np.ndarray
        a copy of the mu_map with the bed removed
    """
    _mu_map = mu_map.copy()
    for i in range(_mu_map.shape[0]):
        # a negative thickness makes OpenCV fill the contour interior;
        # write the result back into the copy, never into the caller's array
        _mu_map[i] = cv.drawContours(_mu_map[i], [bed_contour], -1, 0.0, -1)
    return _mu_map


def add_bed(
    without_bed: np.ndarray, with_bed: np.ndarray, bed_contour: np.ndarray
) -> np.ndarray:
    """
    Add the bed to every slice of a mu_map.

    Both images are first aligned with `align_images`. Afterwards, every pixel
    inside the bed contour is taken from `with_bed`; all other pixels are kept
    from `without_bed`.

    Parameters
    ----------
    without_bed: np.ndarray
        the mu_map without the bed
    with_bed: np.ndarray
        the mu_map with the bed
    bed_contour: np.ndarray
        the contour defining the location of the bed

    Returns
    -------
    np.ndarray
        the mu_map with the bed added
    """
    with_bed, without_bed = align_images(with_bed, without_bed)
    # NOTE(review): the aligned `without_bed` is modified in place below; whether
    # this also alters the caller's array depends on whether align_images
    # returns copies — confirm.

    # The bed contour is identical for every slice, so rasterize the mask once
    # (negative thickness fills the contour interior).
    cnt_img = np.zeros(without_bed.shape[1:], dtype=np.uint8)
    cnt_img = cv.drawContours(cnt_img, [bed_contour], -1, 255, -1)
    bed_mask = cnt_img > 0

    for _slice in range(with_bed.shape[0]):
        without_bed[_slice] = np.where(
            bed_mask, with_bed[_slice], without_bed[_slice]
        )

    return without_bed


if __name__ == "__main__":
    import argparse
    from enum import Enum
    import os
    import sys

    from mu_map.data.prepare import headers
    from mu_map.dataset.default import MuMapDataset
    from mu_map.util import to_grayscale, COLOR_BLACK, COLOR_WHITE

    def scale_points(points: List[List[int]], scale: float):
        """
        Utility function to scale all points in a list of points in place.
        Coordinates are rounded to integers.
        """
        for point in points:
            for j, coord in enumerate(point):
                point[j] = round(coord * scale)

    def write_contours(contours: Dict[str, List[List[int]]], filename: str):
        """
        Write all contours as (key-sorted) JSON to a file.
        """
        with open(filename, mode="w") as f:
            f.write(json.dumps(contours, sort_keys=True))

    parser = argparse.ArgumentParser(
        description="draw and save contours to exclude the bed from mu maps",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory containing the dataset"
    )
    parser.add_argument(
        "--revise_ids",
        type=int,
        nargs="*",
        help="only revise the contours of the images with the specified ids",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=DEFAULT_BED_CONTOURS_FILENAME,
        help="default file in dataset dir where the drawn contours are stored",
    )
    args = parser.parse_args()
    args.output_file = os.path.join(args.dataset_dir, args.output_file)

    controls = """
    Controls:
    Left click to add points to the current contour.

    q: exit
    d: delete last point
    c: delete all points
    f: forward (next slice)
    b: backward (previous slice)
    n: save contour and go to the next image
    v: change the visual mode between drawing contours and hiding the area within
    """
    print(controls)
    print()

    # set bed contours file to None so that existing contours are not used
    dataset = MuMapDataset(args.dataset_dir, bed_contours_file=None)

    if os.path.isfile(args.output_file):
        try:
            bed_contours = load_contours(args.output_file, as_ndarry=False)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors;
            # I/O errors (e.g. missing permissions) are deliberately not caught
            # so that an unreadable file is never silently overwritten
            print(f"JSON file {args.output_file} is corrupted! Create a new one.")
            bed_contours = {}
            write_contours(bed_contours, args.output_file)
    else:
        bed_contours = {}

    class VisualMode(Enum):
        DRAW_CONTOURS = 1
        HIDE_BED = 2

    # save the points of the contour (in window coordinates) in a list and
    # define a mouse callback that appends clicked points
    points = []

    def mouse_callback(event, x, y, flags, param):
        if event == cv.EVENT_LBUTTONUP:
            points.append([x, y])

    # create a window for display
    window_name = "Bed Removal"
    window_size = 1024
    cv.namedWindow(window_name, cv.WINDOW_NORMAL)
    cv.resizeWindow(window_name, window_size, window_size)
    cv.setMouseCallback(window_name, mouse_callback)

    ids = list(dataset.table[headers.id])
    if args.revise_ids:
        ids = args.revise_ids

    # field widths for the progress output
    count_width = len(str(len(ids)))
    id_width = len(str(max(ids))) if ids else 0

    for _i, _id in enumerate(ids):
        if str(_id) in bed_contours and not args.revise_ids:
            print(f"Skip {_id} because file already contains these contours")
            continue

        # only load the image once it is clear that it has to be labeled
        _, mu_map = dataset.get_item_by_id(_id)

        if args.revise_ids and str(_id) in bed_contours:
            # copy each point: scale_points works in place and must not
            # alter the lists stored in bed_contours
            points.extend([list(point) for point in bed_contours[str(_id)]])

        print(
            f"Image {str(_i + 1):>{count_width}}/{len(ids)}, ID: {_id:>{id_width}}",
            end="\r",
        )
        mu_map = mu_map.squeeze().numpy()
        # start at the first slice; the bed location is constant over all
        # slices, so the contour can be drawn on any of them
        _slice = 0

        # contours are stored in image coordinates but drawn in window coordinates
        scale = window_size / mu_map.shape[1]
        scale_points(points, scale)

        # the display range is the same for all slices of an image
        min_val, max_val = mu_map.min(), mu_map.max()

        # set initial visual mode
        visual_mode = VisualMode.DRAW_CONTOURS
        while True:
            # compute image to display
            to_show = to_grayscale(mu_map[_slice], min_val=min_val, max_val=max_val)
            to_show = cv.resize(to_show, (window_size, window_size))

            if visual_mode == VisualMode.DRAW_CONTOURS:
                # draw lines between all points
                for p1, p2 in zip(points[:-1], points[1:]):
                    to_show = cv.line(to_show, p1, p2, color=COLOR_WHITE, thickness=2)
                # close the contour
                if len(points) > 0:
                    to_show = cv.line(
                        to_show, points[0], points[-1], color=COLOR_WHITE, thickness=2
                    )

                # draw all points as circles
                for point in points:
                    to_show = cv.circle(
                        to_show, point, radius=4, color=COLOR_BLACK, thickness=-1
                    )
                    to_show = cv.circle(
                        to_show, point, radius=4, color=COLOR_WHITE, thickness=1
                    )
            elif len(points) > 0:
                # eliminate area inside the contour (nothing to hide without points)
                _points = np.array(points).astype(int)
                to_show = cv.drawContours(
                    to_show, [_points], -1, COLOR_BLACK, thickness=-1
                )

            # visualize image and handle inputs
            cv.imshow(window_name, to_show)
            key = cv.waitKey(100)
            if key == ord("q"):
                # write current contours to output file
                if len(points) > 0:
                    scale_points(points, 1.0 / scale)
                    bed_contours[str(_id)] = points.copy()
                    write_contours(bed_contours, args.output_file)
                cv.destroyAllWindows()
                sys.exit(0)
            elif key == ord("d"):
                # delete the last point in place (no-op for an empty list)
                del points[-1:]
            elif key == ord("c"):
                points.clear()
            elif key == ord("n"):
                break
            elif key == ord("f"):
                _slice = (_slice + 1) % mu_map.shape[0]
            elif key == ord("b"):
                _slice = (_slice - 1) % mu_map.shape[0]
            elif key == ord("v"):
                visual_mode = (
                    VisualMode.DRAW_CONTOURS
                    if visual_mode == VisualMode.HIDE_BED
                    else VisualMode.HIDE_BED
                )

        # save current contour (in image coordinates) in dict
        scale_points(points, 1.0 / scale)
        bed_contours[str(_id)] = points.copy()
        points.clear()

        # write contours to output file after every image so no work is lost
        write_contours(bed_contours, args.output_file)

    cv.destroyAllWindows()