import json
from typing import Dict

import numpy as np


# Default file name for stored bed contours; the labelling script joins it onto
# the dataset directory (it is the default of its ``--output_file`` option).
DEFAULT_BED_CONTOURS_FILENAME = "bed_contours.json"


def load_contours(filename: str) -> Dict[int, np.ndarray]:
    """
    Read bed contours from a json file.

    The file holds a json object whose keys are image ids and whose values are
    the contour points of that image. Because json object keys are always
    strings, the ids are converted back to ``int`` here. Each contour is
    converted to an integer numpy array (fractional coordinates are truncated).

    :param filename: path of the json file containing the contours
    :return: a dict mapping each image id to its contour
    """
    with open(filename, mode="r") as json_file:
        raw_contours = json.load(json_file)

    return {
        int(image_id): np.array(points).astype(int)
        for image_id, points in raw_contours.items()
    }


if __name__ == "__main__":
    # Interactive tool: for every mu map of a dataset, the user clicks a polygon
    # around the bed. All polygons are written to a json file (keyed by image id)
    # that can be read back with ``load_contours``.
    import argparse
    from enum import Enum
    import os
    import sys

    import cv2 as cv

    from mu_map.data.datasets import MuMapDataset
    from mu_map.util import to_grayscale, COLOR_BLACK, COLOR_WHITE

    parser = argparse.ArgumentParser(
        description="draw and save contours to exclude the bed from mu maps",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory containing the dataset"
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=DEFAULT_BED_CONTOURS_FILENAME,
        help="default file in dataset dir where the drawn contours are stored",
    )
    args = parser.parse_args()
    # the output file is always placed relative to the dataset directory
    args.output_file = os.path.join(args.dataset_dir, args.output_file)

    controls = """
    Controls:
    Left click to add points to the current contour.

    q: exit
    d: delete last point
    n: save contour and go to the next image
    v: change the visual mode between drawing contours and hiding the area within
    """
    print(controls)
    print()

    # set bed contours file to None so that existing contours are not used
    dataset = MuMapDataset(args.dataset_dir, bed_contours_file=None)

    # TODO: implement that existing contours are loaded so that labeling can be continued?
    bed_contours = {}

    class VisualMode(Enum):
        """How the current contour is rendered on top of the mu map."""

        DRAW_CONTOURS = 1  # draw the polygon outline and its points
        HIDE_BED = 2  # fill the polygon in black to preview the masked image

    window_name = "Bed Removal"
    n_images = len(dataset)
    counter_width = len(str(n_images))  # pad the progress counter to a fixed width
    for i, (_, mu_map) in enumerate(dataset):
        print(f"Image {str(i + 1):>{counter_width}}/{n_images}", end="\r")
        # select the center slice for display (the bed location is constant over all slices)
        mu_map = mu_map[mu_map.shape[0] // 2]

        # points of the current contour; the mouse callback appends to this list,
        # so it is only ever mutated in place below
        points = []

        def mouse_callback(event, x, y, flags, param):
            if event == cv.EVENT_LBUTTONUP:
                points.append((x, y))

        # create a window for display
        cv.namedWindow(window_name, cv.WINDOW_NORMAL)
        cv.resizeWindow(window_name, 1024, 1024)
        cv.setMouseCallback(window_name, mouse_callback)

        # set initial visual mode
        visual_mode = VisualMode.DRAW_CONTOURS
        while True:
            # compute image to display
            to_show = to_grayscale(mu_map)

            if visual_mode == VisualMode.DRAW_CONTOURS:
                # draw lines between all consecutive points
                for p1, p2 in zip(points[:-1], points[1:]):
                    to_show = cv.line(to_show, p1, p2, color=COLOR_WHITE, thickness=1)
                # close the contour
                if points:
                    to_show = cv.line(
                        to_show, points[0], points[-1], color=COLOR_WHITE, thickness=1
                    )

                # draw all points as circles (black fill with a white outline)
                for point in points:
                    to_show = cv.circle(
                        to_show, point, radius=2, color=COLOR_BLACK, thickness=-1
                    )
                    to_show = cv.circle(
                        to_show, point, radius=2, color=COLOR_WHITE, thickness=1
                    )
            elif points:
                # eliminate area inside the contour; OpenCV contour routines
                # expect int32 points, and an empty contour cannot be drawn
                contour = np.array(points, dtype=np.int32)
                to_show = cv.drawContours(
                    to_show, [contour], -1, COLOR_BLACK, thickness=-1
                )

            # visualize image and handle inputs
            cv.imshow(window_name, to_show)
            # keep only the low byte: some backends set modifier bits in the key code
            key = cv.waitKey(100) & 0xFF
            if key == ord("q"):
                cv.destroyAllWindows()
                sys.exit(0)
            elif key == ord("d"):
                if points:
                    points.pop()
            elif key == ord("n"):
                break
            elif key == ord("v"):
                visual_mode = (
                    VisualMode.DRAW_CONTOURS
                    if visual_mode == VisualMode.HIDE_BED
                    else VisualMode.HIDE_BED
                )

        # remove current window
        cv.destroyWindow(window_name)

        # save current contour in dict
        bed_contours[int(dataset.table.loc[i, "id"])] = points

    # terminate the carriage-return progress line
    print()

    # write contours to output file (int keys become strings in json;
    # load_contours converts them back)
    with open(args.output_file, mode="w") as f:
        f.write(json.dumps(bed_contours, indent=2, sort_keys=True))