import numpy as np
import pandas as pd


# Column names of the discard flags in the dataset's meta table.
HEADER_DISC_FIRST = "discard_first"
HEADER_DISC_LAST = "discard_last"


def discard_slices(row: pd.Series, μ_map: np.ndarray) -> np.ndarray:
    """
    Drop the first and/or the last slice of a μ-map as requested by a table row.

    The row has to provide the flags 'discard_first' and 'discard_last'. Each
    flag that evaluates to true removes the corresponding slice along the
    first axis of the μ-map.

    :param row: the row of the meta configuration file of a dataset
    :param μ_map: the μ-map to trim
    :return: the μ-map without the flagged slices (the input itself if no
             flag is set)
    """
    # Handle the leading slice first, then trim the trailing one from that result.
    without_first = μ_map[1:] if row[HEADER_DISC_FIRST] else μ_map
    return without_first[:-1] if row[HEADER_DISC_LAST] else without_first


if __name__ == "__main__":
    # Interactive review tool: shows the first and the last slice of every
    # μ-map of a dataset side by side and lets the user mark them for removal
    # via the keyboard. Unless --view is given, the discard flags of the
    # dataset table are reset first and the table is written back to its CSV
    # file after the last μ-map was reviewed.
    import argparse
    import sys

    import cv2 as cv

    from mu_map.data.datasets import MuMapDataset
    from mu_map.util import to_grayscale, COLOR_WHITE, COLOR_BLACK

    def _put_outlined_text(image, text, origin, scale):
        """
        Draw text onto an image in two passes to create an outline effect.

        :param image: the image to draw on (modified in place)
        :param text: the text to draw
        :param origin: position of the bottom-left corner of the text
        :param scale: the font scale passed to cv.putText
        """
        # A thick COLOR_BLACK pass followed by a thin COLOR_WHITE pass on top.
        for color, thickness in ((COLOR_BLACK, 3), (COLOR_WHITE, 1)):
            cv.putText(
                image, text, origin, cv.FONT_HERSHEY_SIMPLEX, scale, color, thickness
            )

    def _has_set_flags(table, header):
        """
        Check if a column exists in the table and holds any value other than False.

        :param table: the table of the dataset
        :param header: the name of the column to inspect
        :return: True if the column exists and not all of its values equal False
        """
        # `== False` is an element-wise comparison here, not an identity check.
        return header in table and not (table[header] == False).all()  # noqa: E712

    parser = argparse.ArgumentParser(
        description="review all μ-maps in a dataset for broken slices at the start or end",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="directory containing the dataset"
    )
    parser.add_argument(
        "--view", action="store_true", help="only visualize the current state"
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="do not ask for permission to replace discard flags in the table",
    )
    args = parser.parse_args()

    # In view mode the dataset is asked to apply the existing discard flags,
    # so the displayed slices are the ones remaining after discarding.
    dataset = MuMapDataset(dataset_dir=args.dataset_dir, discard_μ_map_slices=args.view)

    edit_mode = not args.view

    # Ask for confirmation before existing flags get overwritten.
    if (
        edit_mode
        and not args.force
        and (
            _has_set_flags(dataset.table, HEADER_DISC_FIRST)
            or _has_set_flags(dataset.table, HEADER_DISC_LAST)
        )
    ):
        print(
            f"This operation is going to set all discard_first and discard_last flags in {dataset.csv_file}. If you only want to see the current state use the --view flag."
        )
        answer = input("Are you okay with setting these flags to false? (y/n): ")
        if answer.lower() != "y":
            # sys.exit instead of the `site` helper `exit`, which is meant
            # for the interactive shell only.
            sys.exit(0)

    if edit_mode:
        # Start from a clean state; the review below only ever sets flags.
        dataset.table[HEADER_DISC_FIRST] = False
        dataset.table[HEADER_DISC_LAST] = False

    wname = "μ-map"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)
    controls = "l = discard last, f = discard first, b = discard both, n = discard none"

    # Columns to set for each key. Every other key (including "n") leaves the
    # row untouched and just advances to the next μ-map.
    headers_by_key = {
        ord("l"): (HEADER_DISC_LAST,),
        ord("f"): (HEADER_DISC_FIRST,),
        ord("b"): (HEADER_DISC_LAST, HEADER_DISC_FIRST),
    }

    for i, (_, μ_map) in enumerate(dataset):
        slice_first = to_grayscale(μ_map[0])
        slice_last = to_grayscale(μ_map[-1])

        slice_first = cv.resize(slice_first, (512, 512), interpolation=cv.INTER_AREA)
        slice_last = cv.resize(slice_last, (512, 512), interpolation=cv.INTER_AREA)

        _put_outlined_text(slice_first, "First", (0, 30), 1)
        _put_outlined_text(slice_last, "Last", (0, 30), 1)

        # Narrow light-gray column separating the two slices.
        space = np.full((slice_first.shape[0], 10), 239, np.uint8)

        to_show = np.hstack((slice_first, space, slice_last))
        _put_outlined_text(to_show, controls, (0, to_show.shape[0] - 12), 0.75)

        cv.imshow(wname, to_show)
        key = cv.waitKey(0)
        if key == ord("q"):
            # Quit immediately; the table is not written back in this case.
            sys.exit(0)

        # NOTE(review): `.loc[i, ...]` addresses the row by index label, which
        # assumes the table's index labels match the enumeration order of the
        # dataset - confirm against MuMapDataset.
        for header in headers_by_key.get(key, ()):
            dataset.table.loc[i, header] = True

    if edit_mode:
        dataset.table.to_csv(dataset.csv_file, index=False)