Skip to content
Snippets Groups Projects
remove_bed.py 8.94 KiB
Newer Older
  • Learn to ignore specific revisions
  • """
    Module containing a GUI to label contours of a bed to be extracted from
    attenuation maps.
    Additionally, it contains utility functions to load bed contours and to
    remove or add the bed.
    """
    
    from typing import Dict, List
    
    from mu_map.dataset.util import align_images
    
    
    
    # default name of the JSON file (joined with the dataset directory by the
    # labeling script) in which the drawn bed contours are stored
    DEFAULT_BED_CONTOURS_FILENAME = "bed_contours.json"
    
    
    
    def load_contours(filename: str, as_ndarry: bool = True) -> Dict[int, np.ndarray]:
        """
        Load contours from a JSON file.

        The structure of the file is a dict where the key is the id of the
        according image and the value is a list of [x, y] points describing
        the contour.

        Parameters
        ----------
        filename: str
            filename of a JSON file containing contours
        as_ndarry: bool
            directly parse contours as numpy arrays

        Returns
        -------
        Dict
            if `as_ndarry` is True, a dict mapping integer ids to contours as
            integer numpy arrays; otherwise the raw JSON dict, which maps
            string ids to contours as lists of points
        """
        with open(filename, mode="r") as f:
            contours = json.load(f)

        if not as_ndarry:
            return contours

        # JSON object keys are always strings, so convert them back to int ids
        return {
            int(_id): np.array(contour).astype(int)
            for _id, contour in contours.items()
        }
    
    def remove_bed(mu_map: np.ndarray, bed_contour: np.ndarray) -> np.ndarray:
        """
        Remove the bed defined by a contour from all slices.

        The input array is left unchanged; the area inside the contour is set
        to zero in every slice of a copy, which is returned.

        Parameters
        ----------
        mu_map: np.ndarray
            the mu_map from which the bed is removed
        bed_contour: np.ndarray
            the contour describing where the bed is found

        Returns
        -------
        np.ndarray
            the mu_map with the bed removed
        """
        _mu_map = mu_map.copy()
        for i in range(_mu_map.shape[0]):
            # fill (thickness -1) the contour area with 0.0 and store the
            # returned slice in the copy -- never write back into the input
            _mu_map[i] = cv.drawContours(_mu_map[i], [bed_contour], -1, 0.0, -1)
        return _mu_map
    
    
    def add_bed(
        without_bed: np.ndarray, with_bed: np.ndarray, bed_contour: np.ndarray
    ) -> np.ndarray:
        """
        Add the bed to every slice of a mu_map.

        Pixels inside `bed_contour` are taken from `with_bed`, all other
        pixels from `without_bed`. Both images are first passed through
        `align_images` (presumably to bring them to a common shape -- confirm).
        The (aligned) `without_bed` array is copied, so it is not modified.

        Parameters
        ----------
        without_bed: np.ndarray
            the mu_map without the bed
        with_bed: np.ndarray
            the mu_map with the bed
        bed_contour: np.ndarray
            the contour defining the location of the bed

        Returns
        -------
        np.ndarray
            the mu_map with the bed added
        """
        with_bed, without_bed = align_images(with_bed, without_bed)
        # copy so that an array handed through align_images is never mutated
        result = without_bed.copy()

        # the contour is identical for every slice, so rasterize its mask once
        cnt_img = np.zeros(result.shape[1:], dtype=np.uint8)
        cnt_img = cv.drawContours(cnt_img, [bed_contour], -1, 255, -1)
        mask = cnt_img > 0

        for _slice in range(with_bed.shape[0]):
            # boolean assignment casts the values to the dtype of `result`
            result[_slice][mask] = with_bed[_slice][mask]

        return result
    
    
    
    
        def scale_points(points: List[List[int]], scale: float):
            """
            Multiply every coordinate of every point by `scale` and round the
            result. The point lists are modified in place; nothing is returned.
            """
            for point in points:
                point[:] = [round(coordinate * scale) for coordinate in point]
    
    
        import argparse
        from enum import Enum
        import os
    
    
        from mu_map.data.prepare import headers
        from mu_map.dataset.default import MuMapDataset
    
        from mu_map.util import to_grayscale, COLOR_BLACK, COLOR_WHITE
    
        parser = argparse.ArgumentParser(
            description="draw and save contours to exclude the bed from mu maps",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "dataset_dir", type=str, help="the directory containing the dataset"
        )
        parser.add_argument(
            "--revise_ids",
            type=int,
            nargs="*",
            help="only revise the contours of the images with the specified ids",
        )
        parser.add_argument(
            "--output_file",
            type=str,
            default=DEFAULT_BED_CONTOURS_FILENAME,
            help="file in the dataset dir where the drawn contours are stored",
        )
        args = parser.parse_args()
        # a relative output file is resolved against the dataset directory
        args.output_file = os.path.join(args.dataset_dir, args.output_file)
    
        controls = """
        Controls:
        Left click to add points to the current contour.

        q: exit
        d: delete last point

        c: delete all points
        f: forward (next slice)
        b: backward (previous slice)

        n: save contour and go to the next image
        v: change the visual mode between drawing contours and hiding the area within
        """
        print(controls)
        print()
    
        # set bed contours file to None so that existing contours are not used
        dataset = MuMapDataset(args.dataset_dir, bed_contours_file=None)

        bed_contours = {}
        if os.path.isfile(args.output_file):
            try:
                bed_contours = load_contours(args.output_file, as_ndarry=False)
                if not isinstance(bed_contours, dict):
                    raise ValueError("contours file does not contain a JSON object")
            except ValueError:
                # json.JSONDecodeError and UnicodeDecodeError are ValueErrors;
                # anything else (e.g. permission errors) should not silently
                # lead to overwriting the file
                print(f"JSON file {args.output_file} is corrupted! Create a new one.")
                bed_contours = {}

                with open(args.output_file, mode="w") as f:
                    f.write(json.dumps(bed_contours, sort_keys=True))
    
    
        class VisualMode(Enum):
            """Display modes of the labeling window, toggled with the "v" key."""

            # draw the contour as lines and point markers on top of the image
            DRAW_CONTOURS = 1
            # black out the area enclosed by the contour
            HIDE_BED = 2
    
    
        # points of the contour currently being drawn; filled by the mouse
        # callback below
        points = []

        def mouse_callback(event, x, y, flags, param):
            """Append the clicked position to `points` on left-button release."""
            if event != cv.EVENT_LBUTTONUP:
                return
            points.append([x, y])
    
        # create a window for display
        window_name = "Bed Contours"
        window_size = 1024
        cv.namedWindow(window_name, cv.WINDOW_NORMAL)
        cv.resizeWindow(window_name, window_size, window_size)
        cv.setMouseCallback(window_name, mouse_callback)

        # label either the explicitly requested ids or every id in the dataset
        ids = args.revise_ids if args.revise_ids else list(dataset.table[headers.id])
    
        import sys

        for _i, _id in enumerate(ids):
            # check for existing contours before loading the image so that
            # skipped ids do not pay for loading their mu map
            if str(_id) in bed_contours and not args.revise_ids:
                print(f"Skip {_id} because file already contains these contours")
                continue

            _, mu_map = dataset.get_item_by_id(_id)

            if args.revise_ids and str(_id) in bed_contours:
                # copy every point: scale_points rescales the inner lists in
                # place, which must not leak into the stored contours
                points.extend([list(point) for point in bed_contours[str(_id)]])

            print(
                f"Image {str(_i + 1):>{len(str(len(ids)))}}/{len(ids)}, ID: {_id:>{len(str(max(ids)))}}",
                end="\r",
            )
            # start displaying the first slice (the bed location is constant
            # over all slices, so any slice can be used for drawing)
            mu_map = mu_map.squeeze().numpy()
            _slice = 0
            # the intensity range is fixed per image, so compute it only once
            min_val, max_val = mu_map.min(), mu_map.max()

            # points are stored in image coordinates but drawn in window
            # coordinates; NOTE(review): assumes square slices because every
            # slice is resized to (window_size, window_size)
            scale = window_size / mu_map.shape[1]
            scale_points(points, scale)

            # set initial visual mode
            visual_mode = VisualMode.DRAW_CONTOURS
            while True:
                # compute image to display
                to_show = to_grayscale(mu_map[_slice], min_val=min_val, max_val=max_val)
                to_show = cv.resize(to_show, (window_size, window_size))

                if visual_mode == VisualMode.DRAW_CONTOURS:
                    # draw lines between all points
                    for p1, p2 in zip(points[:-1], points[1:]):
                        to_show = cv.line(to_show, p1, p2, color=COLOR_WHITE, thickness=2)

                    # close the contour
                    if len(points) > 0:
                        to_show = cv.line(
                            to_show, points[0], points[-1], color=COLOR_WHITE, thickness=2
                        )

                    # draw all points as filled black circles with a white
                    # outline so they stay visible on bright and dark regions
                    for point in points:
                        to_show = cv.circle(
                            to_show, point, radius=4, color=COLOR_BLACK, thickness=-1
                        )
                        to_show = cv.circle(
                            to_show, point, radius=4, color=COLOR_WHITE, thickness=1
                        )
                elif points:
                    # eliminate area inside the contour (nothing to hide when
                    # no contour has been drawn yet)
                    _points = np.array(points).astype(int)
                    to_show = cv.drawContours(
                        to_show, [_points], -1, COLOR_BLACK, thickness=-1
                    )

                # visualize image and handle inputs
                cv.imshow(window_name, to_show)
                key = cv.waitKey(100)
                if key == ord("q"):
                    # write current contours to output file
                    if len(points) > 0:
                        scale_points(points, 1.0 / scale)
                        bed_contours[str(_id)] = points.copy()
                        with open(args.output_file, mode="w") as f:
                            f.write(json.dumps(bed_contours, sort_keys=True))

                    sys.exit(0)
                elif key == ord("n"):
                    # leave the display loop; the contour is saved below
                    break
                elif key == ord("d"):
                    # delete in place so the list shared with mouse_callback
                    # stays the same object (no-op on an empty list)
                    del points[-1:]
                elif key == ord("c"):
                    points.clear()
                elif key == ord("f"):
                    _slice = (_slice + 1) % mu_map.shape[0]
                elif key == ord("b"):
                    _slice = (_slice - 1) % mu_map.shape[0]
                elif key == ord("v"):
                    visual_mode = (
                        VisualMode.DRAW_CONTOURS
                        if visual_mode == VisualMode.HIDE_BED
                        else VisualMode.HIDE_BED
                    )

            # save current contour in dict (converted back to image coordinates)
            scale_points(points, 1.0 / scale)
            bed_contours[str(_id)] = points.copy()
            points.clear()

            # write contours to output file
            with open(args.output_file, mode="w") as f:
                f.write(json.dumps(bed_contours, sort_keys=True))