Skip to content
Snippets Groups Projects
prepare_polar_maps.py 6.88 KiB
Newer Older
  • Learn to ignore specific revisions
  • import argparse
    import os
    from typing import List
    
    from mu_map.file.dicom import dcm_type
    
    
    def get_files_recursive(_dir: str) -> List[str]:
        """
        Collect the paths of all files located below a directory.

        Each entry of the directory is joined with the directory path. Entries
        that are directories are descended into recursively, every other entry
        is recorded as a file.

        :param _dir: the directory whose files are listed
        :return: a list of paths of the files in the directory and its sub-directories
        """
        found: List[str] = []
        for entry in os.listdir(_dir):
            path = os.path.join(_dir, entry)

            # Anything that is not a directory counts as a file.
            if not os.path.isdir(path):
                found.append(path)
                continue

            found += get_files_recursive(path)
        return found
    
    
    def is_scatter_corrected(dcm: dcm_type) -> bool:
        """
        Check the series description of a DICOM header for the "NoSC" marker.

        :param dcm: the DICOM header whose `SeriesDescription` is inspected
        :return: False if "NoSC" occurs in the series description, True otherwise
        """
        return "NoSC" not in dcm.SeriesDescription
    
    
    def is_attenuation_corrected(dcm: dcm_type) -> bool:
        """
        Check the series description of a DICOM header for the "NoAC" marker.

        :param dcm: the DICOM header whose `SeriesDescription` is inspected
        :return: False if "NoAC" occurs in the series description, True otherwise
        """
        return "NoAC" not in dcm.SeriesDescription
    
    
    def get_type(dcm: dcm_type) -> str:
        """
        Derive a type label from the series description of a DICOM header.

        The description is matched case-insensitively against keywords in a
        fixed order; the first keyword found decides the label.

        :param dcm: the DICOM header whose `SeriesDescription` is inspected
        :return: "synthetic" if the description contains "syn", otherwise "ct"
                 if it contains "ct", otherwise "symbia"
        """
        lowered = dcm.SeriesDescription.lower()

        # Order matters: "syn" takes precedence over "ct".
        for keyword, label in (("syn", "synthetic"), ("ct", "ct")):
            if keyword in lowered:
                return label
        return "symbia"
    
    
    # Column names of the polar map CSV file. `sc`, `ac` and `type` are short
    # aliases that hold the same column name as their long counterparts.
    headers = argparse.Namespace(
        id="id",
        scatter_correction="scatter_correction",
        sc="scatter_correction",
        attenuation_correction="attenuation_correction",
        ac="attenuation_correction",
        mu_map_type="mu_map_type",
        type="mu_map_type",
        file="file",
    )
    
    
    if __name__ == "__main__":
        # Script entry point: walk over all DICOM files below --polar_map_dir,
        # crop the rectangle given by --rect out of each image, store the crop
        # as a PNG under <out_dir>/<images_dir> and record one row per stored
        # image in the CSV file <out_dir>/<csv>.
        import cv2 as cv
        import numpy as np
        import pandas as pd

        from mu_map.data.prepare import get_protocol
        from mu_map.data.prepare import headers as meta_headers
        from mu_map.file.dicom import load_dcm, DCM_TAG_PIXEL_SCALE_FACTOR, dcm_type
        from mu_map.logging import add_logging_args, get_logger_by_args

        parser = argparse.ArgumentParser(
            description="Store/Sort in polar maps for further processing",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "--polar_map_dir",
            type=str,
            required=True,
            help="directory where the raw DICOM files of polar maps are stored",
        )
        parser.add_argument(
            "--out_dir", type=str, required=True, help="directory where files are output to"
        )
        parser.add_argument(
            "--images_dir",
            type=str,
            default="images",
            help="directory under <out_dir> where images of polar maps are stored",
        )
        parser.add_argument(
            "--csv",
            type=str,
            default="polar_maps.csv",
            help="file unter <out_dir> where meta information is stored",
        )
        parser.add_argument(
            "--meta_csv",
            type=str,
            required=True,
            help="the csv file containing meta information of the dataset of which the generated polar maps are",
        )
        parser.add_argument(
            "--rect",
            type=int,
            nargs=4,
            default=[21, 598, 425, 1002],
            help="rectangle as [top, left, bottom, right] coordinates where the polar map is found in the DICOM images",
        )
        parser.add_argument(
            "-i",
            "--interactive",
            action="store_true",
            help="start interactive mode where every polar map is displayed and confirmation is queried",
        )
        add_logging_args(parser, defaults={"--logfile": "prepare.log"})
        args = parser.parse_args()

        # All outputs (images, CSV file and logfile) live below <out_dir>.
        args.images_dir = os.path.join(args.out_dir, args.images_dir)
        args.csv = os.path.join(args.out_dir, args.csv)
        args.logfile = os.path.join(args.out_dir, args.logfile)
        args.loglevel = "DEBUG" if args.interactive else args.loglevel

        if args.interactive:
            print(
                """
        Start in interactive mode.
        This mode automatically sets the loglevel to debug so that you can see what is going in.
        Use the following keys on the displayed polar map:
            
            n(ext): go to the next image without saving the polar map
            s(ave): save the polar map and go to the next
            q(uit): quit the application
    
        """
            )

        # Create the output directories before the logger is set up: the logfile
        # path points into <out_dir>, so the directory presumably has to exist
        # when the logger opens it (NOTE(review): confirm against
        # get_logger_by_args). makedirs also copes with nested paths whose
        # parents are missing and with directories that already exist.
        os.makedirs(args.out_dir, exist_ok=True)
        os.makedirs(args.images_dir, exist_ok=True)

        logger = get_logger_by_args(args)
        logger.info(args)

        meta = pd.read_csv(args.meta_csv)
        dcm_files = sorted(get_files_recursive(args.polar_map_dir))

        # Start with an empty table and give every column its own intended
        # dtype, so that the duplicate check below compares like with like.
        data = pd.DataFrame(
            {
                headers.id: [],
                headers.sc: [],
                headers.ac: [],
                headers.type: [],
                headers.file: [],
            }
        )
        data = data.astype(
            {
                headers.id: int,
                headers.sc: bool,
                headers.ac: bool,
                headers.type: str,
                headers.file: str,
            }
        )

        for dcm_file in dcm_files:
            logger.debug(f"Process file {dcm_file}")

            dcm, img = load_dcm(dcm_file)

            # Find the dataset entry for this patient and protocol.
            # NOTE: `.iloc[0]` raises an IndexError if the meta CSV file has no
            # matching row.
            protocol = get_protocol(dcm)
            patient_id = int(dcm.PatientID)
            meta_row = meta[
                (meta[meta_headers.patient_id] == patient_id)
                & (meta[meta_headers.protocol] == protocol)
            ].iloc[0]

            row = {
                headers.id: meta_row[meta_headers.id],
                headers.sc: is_scatter_corrected(dcm),
                headers.ac: is_attenuation_corrected(dcm),
                headers.type: get_type(dcm),
            }

            # Crop the polar map and swap the channel order for OpenCV
            # (imshow/imwrite are fed the result of the RGB -> BGR conversion).
            top, left, bottom, right = args.rect
            polar_map = img[top:bottom, left:right]
            polar_map = cv.cvtColor(polar_map, cv.COLOR_RGB2BGR)

            # The filename encodes id, attenuation/scatter correction and type.
            _ac = "ac" if row[headers.ac] else "nac"
            _sc = "sc" if row[headers.sc] else "nsc"
            _file = f"{row[headers.id]:04d}-{_ac}_{_sc}_{row[headers.type]}.png"
            row[headers.file] = _file

            _file = os.path.join(args.images_dir, _file)

            # Skip files for which a row with the same id, corrections and type
            # has already been stored in this run.
            is_duplicate = (
                (data[headers.id] == row[headers.id])
                & (data[headers.sc] == row[headers.sc])
                & (data[headers.ac] == row[headers.ac])
                & (data[headers.type] == row[headers.type])
            ).any()
            if is_duplicate:
                logger.warning(f"Skip {dcm_file} as it is a duplicate for row: {row}")
                continue

            logger.debug(f"For series {dcm.SeriesDescription} store row {row}")

            # Non-interactive mode stores every polar map; interactive mode
            # stores only after the user confirmed with the `s` key.
            store = not args.interactive
            if args.interactive:
                while True:
                    cv.imshow("Polar Map", cv.resize(polar_map, (512, 512)))
                    key = cv.waitKey(0)

                    if key == ord("q"):
                        # Same effect as exit(0) without relying on the builtin
                        # that is only injected by the `site` module.
                        raise SystemExit(0)
                    elif key == ord("n"):
                        break
                    elif key == ord("s"):
                        store = True
                        break

            if store:
                cv.imwrite(_file, polar_map)

                row = pd.DataFrame(row, index=[0])
                data = pd.concat((data, row), ignore_index=True)

                # The CSV file is rewritten after every stored image, so it is
                # up to date even if the run is quit early.
                data = data.sort_values(by=[headers.id, headers.file])
                data.to_csv(args.csv, index=False)