Commit 44ac5903 authored by Tamino Huxohl

add a script to create a dataset split based on patient ids and provide a utility function to load a split
parent 0cc09b15
from typing import List

import pandas as pd

SPLIT_TRAIN = "train"
SPLIT_VALIDATION = "validation"
SPLIT_TEST = "test"

def parse_split_str(_str: str, delimiter: str = "/") -> List[float]:
    """
    Parse a string into a list of proportions representing the split.
    The string should have the format 70/15/15, where / can be replaced by a specified delimiter.
    The numbers must add up to 100.

    :param _str: the string to be parsed
    :param delimiter: the delimiter used to split the provided string
    :return: a list of floats representing the proportions of the split
    """
    split_as_str = _str.split(delimiter)
    split_as_int = list(map(int, split_as_str))
    if sum(split_as_int) != 100:
        raise ValueError(f"Invalid split {_str}! It does not add up to 100.")
    split = list(map(lambda num: num / 100, split_as_int))
    return split
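
# Illustrative examples (not part of the original commit):
#   parse_split_str("70/15/15")                 -> [0.7, 0.15, 0.15]
#   parse_split_str("80-10-10", delimiter="-")  -> [0.8, 0.1, 0.1]
#   parse_split_str("70/20/20")                 -> ValueError (sums to 110)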

def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]:
    """
    Split a data frame based on a file defining a split.

    :param data: the data frame to be split
    :param split_csv: the filename of the csv file defining the split
    :return: a list of sub-data frames forming the splits
    """
    split_data = pd.read_csv(split_csv)
    split_names = split_data["split"].unique()
    split_patient_ids = map(
        lambda name: split_data[split_data["split"] == name]["patient_id"], split_names
    )
    splits = map(
        lambda patient_ids: data[data["patient_id"].isin(patient_ids)],
        split_patient_ids,
    )
    return list(splits)
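
# Illustrative example (not part of the original commit): assuming split.csv was
# written by the __main__ block below and meta.csv has a patient_id column,
#   train, validation, test = split_csv(pd.read_csv("meta.csv"), "split.csv")
# yields one sub-frame per split name, in the order the names first appear in split.csv.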

if __name__ == "__main__":
    import argparse
    import os
    import random

    import numpy as np

    parser = argparse.ArgumentParser(
        description="split a dataset by patient id",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("dataset_dir", type=str, help="directory of the dataset")
    parser.add_argument(
        "--output_file",
        type=str,
        default="split.csv",
        help="name of the generated file containing the split",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="70/15/15",
        help="the split as train/validation/test",
    )
    args = parser.parse_args()
    args.output_file = os.path.join(args.dataset_dir, args.output_file)

    split = parse_split_str(args.split)

    # shuffle the unique patient ids and partition them according to the split
    data = pd.read_csv(os.path.join(args.dataset_dir, "meta.csv"))
    ids = np.array(data["patient_id"].unique())
    indices = list(range(len(ids)))
    random.shuffle(indices)

    lower = round(len(indices) * split[0])
    upper = round(len(indices) * (split[0] + split[1]))
    indices_train = indices[:lower]
    indices_validation = indices[lower:upper]
    indices_test = indices[upper:]

    # write a csv file mapping every patient id to its split
    data_split = pd.DataFrame(
        {
            "patient_id": [
                *list(ids[indices_train]),
                *list(ids[indices_validation]),
                *list(ids[indices_test]),
            ],
            "split": [
                *([SPLIT_TRAIN] * len(indices_train)),
                *([SPLIT_VALIDATION] * len(indices_validation)),
                *([SPLIT_TEST] * len(indices_test)),
            ],
        }
    )
    data_split.to_csv(args.output_file, index=False)
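
A minimal usage sketch, assuming the script is saved as split.py (the file name is not shown on this page) and the dataset directory contains a meta.csv with a patient_id column; the function and column names come from the code above, all paths and the module name are illustrative:

# 1) create the split file inside the dataset directory:
#      python split.py /path/to/dataset --split 70/15/15
# 2) load the three sub-frames again, in the order train/validation/test:
import pandas as pd

from split import split_csv  # module name assumed

data = pd.read_csv("/path/to/dataset/meta.csv")
train, validation, test = split_csv(data, "/path/to/dataset/split.csv")
print(len(train), len(validation), len(test))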