Skip to content
Snippets Groups Projects
remove_bed.py 8.45 KiB
import json
from typing import Dict, List, Union

import cv2 as cv
import numpy as np

from mu_map.dataset.util import align_images


# Name of the JSON file (inside the dataset directory) in which drawn bed contours are stored.
DEFAULT_BED_CONTOURS_FILENAME: str = "bed_contours.json"


def load_contours(
    filename: str, as_ndarry: bool = True
) -> Union[Dict[int, np.ndarray], Dict[str, List[List[int]]]]:
    """
    Load contours from a json file.

    The file holds a JSON object whose keys are image ids (stored as strings,
    because JSON object keys are always strings) and whose values are lists of
    ``[x, y]`` points describing the contour of that image.

    :param filename: filename of a json file containing contours
    :param as_ndarry: if True (default), convert the keys to ``int`` and every
        contour to an integer numpy array; if False, return the parsed JSON
        unchanged (string keys, nested-list values)
    :return: a dict mapping ids to contours
    """
    with open(filename, mode="r") as f:
        contours = json.load(f)

    if not as_ndarry:
        return contours

    return {
        int(image_id): np.array(contour).astype(int)
        for image_id, contour in contours.items()
    }


def remove_bed(mu_map: np.ndarray, bed_contour: np.ndarray) -> np.ndarray:
    """
    Remove the bed defined by a contour from all slices.

    The input array is left untouched; a copy is returned in which every pixel
    inside (and on) the contour is set to zero in every slice.

    :param mu_map: the mu_map from which the bed is removed, indexed as
        ``[slice, row, column]``
    :param bed_contour: the contour describing where the bed is found; the same
        contour is applied to every slice
    :return: the mu_map with the bed removed
    """
    _mu_map = mu_map.copy()

    # The bed contour is the same for every slice, so rasterize it only once
    # into a uint8 mask and zero the masked pixels of all slices at once.
    # (Previously the result was written back into the caller's `mu_map`.)
    bed_mask = np.zeros(_mu_map.shape[1:], dtype=np.uint8)
    bed_mask = cv.drawContours(bed_mask, [bed_contour], -1, 255, -1)
    _mu_map[:, bed_mask > 0] = 0.0
    return _mu_map


def add_bed(without_bed: np.ndarray, with_bed: np.ndarray, bed_contour: np.ndarray):
    """
    Add the bed to every slice of a mu_map.

    Both maps are first aligned with ``align_images``. Then, in every slice,
    the pixels inside (and on) the contour of the aligned ``without_bed`` map
    are replaced by the corresponding pixels of the aligned ``with_bed`` map.

    :param without_bed: the mu_map without the bed
    :param with_bed: the mu_map with the bed
    :param bed_contour: the contour defining the location of the bed
    :return: the (aligned) mu_map without the bed, with the bed pixels copied in
    """
    with_bed, without_bed = align_images(with_bed, without_bed)

    # The contour is identical for every slice, so rasterize it once instead of
    # re-drawing it inside the loop.
    bed_mask = np.zeros(without_bed.shape[1:], dtype=np.uint8)
    bed_mask = cv.drawContours(bed_mask, [bed_contour], -1, 255, -1) > 0

    for _slice in range(with_bed.shape[0]):
        without_bed[_slice] = np.where(
            bed_mask, with_bed[_slice], without_bed[_slice]
        )

    return without_bed


if __name__ == "__main__":
    import argparse
    from enum import Enum
    import os
    import sys

    from mu_map.data.prepare import headers
    from mu_map.dataset.default import MuMapDataset
    from mu_map.util import to_grayscale, COLOR_BLACK, COLOR_WHITE

    def scale_points(points: List[List[int]], scale: float):
        """
        Utility function to scale all points in a list of points in place.

        Every coordinate is multiplied by `scale` and rounded to an int.
        """
        for point in points:
            for j, coordinate in enumerate(point):
                point[j] = round(coordinate * scale)

    def save_contours(contours: Dict[str, List[List[int]]], filename: str):
        """
        Write all contours as JSON (with sorted keys) to `filename`.
        """
        with open(filename, mode="w") as f:
            json.dump(contours, f, sort_keys=True)

    parser = argparse.ArgumentParser(
        description="draw and save contours to exclude the bed from mu maps",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory containing the dataset"
    )
    parser.add_argument(
        "--revise_ids",
        type=int,
        nargs="*",
        help="only revise the contours of the images with the specified ids",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=DEFAULT_BED_CONTOURS_FILENAME,
        help="default file in dataset dir where the drawn contours are stored",
    )
    args = parser.parse_args()
    args.output_file = os.path.join(args.dataset_dir, args.output_file)

    controls = """
    Controls:
    Left click to add points to the current contour.

    q: exit
    d: delete last point
    c: delete all points
    f: forward (next slice)
    b: backward (previous slice)
    n: save contour and go to the next image
    v: change the visual mode between drawing contours and hiding the area within the contour
    """
    print(controls)
    print()

    # set bed contours file to None so that existing contours are not used
    dataset = MuMapDataset(args.dataset_dir, bed_contours_file=None)

    bed_contours = {}
    if os.path.isfile(args.output_file):
        try:
            bed_contours = load_contours(args.output_file, as_ndarry=False)
            if not isinstance(bed_contours, dict):
                raise ValueError("contours file does not contain a JSON object")
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors; other
            # errors (e.g. missing permissions) are not masked as "corrupted"
            print(f"JSON file {args.output_file} is corrupted! Create a new one.")
            bed_contours = {}
            save_contours(bed_contours, args.output_file)

    class VisualMode(Enum):
        DRAW_CONTOURS = 1
        HIDE_BED = 2

    # points of the contour currently being edited (in window coordinates);
    # the list is only ever mutated in place
    points = []

    def mouse_callback(event, x, y, flags, param):
        # add a point whenever the left mouse button is released
        if event == cv.EVENT_LBUTTONUP:
            points.append([x, y])

    # create a window for display
    window_name = "Bed Removal"
    window_size = 1024
    cv.namedWindow(window_name, cv.WINDOW_NORMAL)
    cv.resizeWindow(window_name, window_size, window_size)
    cv.setMouseCallback(window_name, mouse_callback)

    ids = list(dataset.table[headers.id])
    if args.revise_ids:
        ids = args.revise_ids

    # field widths used to right-align the progress output
    count_width = len(str(len(ids)))
    id_width = len(str(max(ids, default=0)))

    for _i, _id in enumerate(ids):
        # check before loading so that skipped images are not read at all
        if str(_id) in bed_contours and not args.revise_ids:
            print(f"Skip {_id} because file already contains these contours")
            continue

        _, mu_map = dataset.getitem_by_id(_id)

        if args.revise_ids and str(_id) in bed_contours:
            # copy every point: scale_points mutates the inner lists in place and
            # must not alter the contour stored in bed_contours
            points.extend([list(point) for point in bed_contours[str(_id)]])

        print(
            f"Image {str(_i + 1):>{count_width}}/{len(ids)}, ID: {_id:>{id_width}}",
            end="\r",
        )
        # start at the center slice (the bed location is constant over all slices)
        mu_map = mu_map.squeeze().numpy()
        _slice = mu_map.shape[0] // 2

        # NOTE: a single factor is used for x and y, i.e. slices are assumed square
        scale = window_size / mu_map.shape[1]
        scale_points(points, scale)

        # set initial visual mode
        visual_mode = VisualMode.DRAW_CONTOURS
        while True:
            # compute image to display
            to_show = mu_map[_slice]
            to_show = to_grayscale(to_show, min_val=mu_map.min(), max_val=mu_map.max())
            to_show = cv.resize(to_show, (window_size, window_size))

            if visual_mode == VisualMode.DRAW_CONTOURS:
                # draw lines between all points
                for p1, p2 in zip(points[:-1], points[1:]):
                    to_show = cv.line(to_show, p1, p2, color=COLOR_WHITE, thickness=2)
                # close the contour
                if len(points) > 0:
                    to_show = cv.line(
                        to_show, points[0], points[-1], color=COLOR_WHITE, thickness=2
                    )

                # draw all points as circles
                for point in points:
                    to_show = cv.circle(
                        to_show, point, radius=4, color=COLOR_BLACK, thickness=-1
                    )
                    to_show = cv.circle(
                        to_show, point, radius=4, color=COLOR_WHITE, thickness=1
                    )
            elif points:
                # eliminate area inside the contour (nothing to fill without points)
                _points = np.array(points).astype(int)
                to_show = cv.drawContours(
                    to_show, [_points], -1, COLOR_BLACK, thickness=-1
                )

            # visualize image and handle inputs
            cv.imshow(window_name, to_show)
            key = cv.waitKey(100)
            if key == ord("q"):
                # write current contours to output file
                if len(points) > 0:
                    scale_points(points, 1.0 / scale)
                    bed_contours[str(_id)] = points.copy()
                    save_contours(bed_contours, args.output_file)
                cv.destroyAllWindows()
                sys.exit(0)
            elif key == ord("d"):
                if points:
                    points.pop()
            elif key == ord("c"):
                points.clear()
            elif key == ord("n"):
                break
            elif key == ord("f"):
                _slice = (_slice + 1) % mu_map.shape[0]
            elif key == ord("b"):
                _slice = (_slice - 1) % mu_map.shape[0]
            elif key == ord("v"):
                visual_mode = (
                    VisualMode.DRAW_CONTOURS
                    if visual_mode == VisualMode.HIDE_BED
                    else VisualMode.HIDE_BED
                )

        # save current contour in dict (back in image coordinates)
        scale_points(points, 1.0 / scale)
        bed_contours[str(_id)] = points.copy()
        points.clear()

        # write contours to output file
        save_contours(bed_contours, args.output_file)

    cv.destroyAllWindows()