# Provenance (from the original patch header):
#   Commit:  44ac5903101698b6e982574896c265cb8e6dfd9e
#   Author:  Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
#   Date:    Thu, 1 Sep 2022 13:30:06 +0200
#   Subject: [PATCH] add a script to create dataset split based on patient ids
#            and provide utility function to load a split
#   File:    mu_map/data/split.py (new file)
"""
Create and load dataset splits based on patient ids.

Run as a script, this module reads ``meta.csv`` from a dataset directory,
randomly assigns every unique patient id to train/validation/test and writes
the assignment to a csv file with the columns ``patient_id`` and ``split``.
Imported as a module, it provides :func:`split_csv` to apply such a file to a
data frame.
"""
import pandas as pd

from typing import List

SPLIT_TRAIN = "train"
SPLIT_VALIDATION = "validation"
SPLIT_TEST = "test"


def parse_split_str(_str: str, delimitier: str = "/") -> List[float]:
    """
    Parse a string into a list of proportions representing the split.
    The string should have the format 70/15/15, where / can be replaced by a
    specified delimiter. The numbers must be non-negative integers that add
    up to 100.

    :param _str: the string to be parsed
    :param delimitier: the delimiter used to split the provided string
    :return: a list of floats representing the proportions of the split
    :raises ValueError: if a part is not an integer, a part is negative or
                        the parts do not add up to 100
    """
    try:
        percentages = [int(part) for part in _str.split(delimitier)]
    except ValueError as err:
        raise ValueError(
            f"Invalid split {_str}! All parts must be integers."
        ) from err

    # A negative part can still sum to 100 (e.g. 120/-10/-10) but would
    # produce overlapping or empty slices when the split is applied.
    if any(percentage < 0 for percentage in percentages):
        raise ValueError(f"Invalid split {_str}! Parts must not be negative.")

    if sum(percentages) != 100:
        raise ValueError(f"Invalid split {_str}! It does not add up to 100.")

    return [percentage / 100 for percentage in percentages]


def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]:
    """
    Split a data frame based on a file defining a split.

    The split file needs the columns ``patient_id`` and ``split``. Rows of
    ``data`` are selected by matching their ``patient_id`` column.

    :param data: the data frame to be split
    :param split_csv: the filename of the csv file defining the split
    :return: a list of sub-data frames forming the splits, one per distinct
             split name in the order the names first appear in the file
    """
    split_data = pd.read_csv(split_csv)

    splits = []
    for name in split_data["split"].unique():
        patient_ids = split_data.loc[split_data["split"] == name, "patient_id"]
        splits.append(data[data["patient_id"].isin(patient_ids)])
    return splits


def _assign_splits(shuffled_ids: list, proportions: List[float]) -> pd.DataFrame:
    """
    Assign already shuffled patient ids to train/validation/test.

    :param shuffled_ids: the patient ids in their (random) final order
    :param proportions: three proportions (train, validation, test)
    :return: a data frame with the columns patient_id and split
    """
    count = len(shuffled_ids)
    lower = round(count * proportions[0])
    upper = round(count * (proportions[0] + proportions[1]))

    ids_train = shuffled_ids[:lower]
    ids_validation = shuffled_ids[lower:upper]
    # The test split takes the remainder so that no id is lost to rounding.
    ids_test = shuffled_ids[upper:]

    return pd.DataFrame(
        {
            "patient_id": [*ids_train, *ids_validation, *ids_test],
            "split": [
                *([SPLIT_TRAIN] * len(ids_train)),
                *([SPLIT_VALIDATION] * len(ids_validation)),
                *([SPLIT_TEST] * len(ids_test)),
            ],
        }
    )


def main() -> None:
    """
    Command line entry point: create a split file for a dataset directory.
    """
    # Imported here so that using the module as a library stays lightweight.
    import argparse
    import os
    import random

    parser = argparse.ArgumentParser(
        description="split a dataset by patient id",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("dataset_dir", type=str, help="directory of the dataset")
    parser.add_argument(
        "--output_file",
        type=str,
        default="split.csv",
        help="name of the generated file containing the split",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="70/15/15",
        help="the split as train/validation/test",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="seed for the random shuffle, omit for a non-reproducible split",
    )
    args = parser.parse_args()
    output_file = os.path.join(args.dataset_dir, args.output_file)

    try:
        split = parse_split_str(args.split)
    except ValueError as err:
        parser.error(str(err))
    if len(split) != 3:
        parser.error(
            f"Invalid split {args.split}! Expected train/validation/test."
        )

    data = pd.read_csv(os.path.join(args.dataset_dir, "meta.csv"))

    ids = list(data["patient_id"].unique())
    random.Random(args.seed).shuffle(ids)

    _assign_splits(ids, split).to_csv(output_file, index=False)


if __name__ == "__main__":
    main()