diff --git a/mu_map/eval/dataset_stats.py b/mu_map/eval/dataset_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..6f10adf89f076233d4220eed9186d76a8ae00cda --- /dev/null +++ b/mu_map/eval/dataset_stats.py @@ -0,0 +1,70 @@ +import numpy as np + +from mu_map.data.prepare import headers +from mu_map.data.split import split_csv + + +def body_mass_index(weight: np.ndarray, size: np.ndarray) -> np.ndarray: + return weight / (size**2) + + +if __name__ == "__main__": + import argparse + import pandas as pd + + parser = argparse.ArgumentParser( + description="Evaluate dataset patient stats.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--meta_csv", type=str, default="data/second/meta.csv") + parser.add_argument("--split_csv", type=str, default="data/second/split.csv") + args = parser.parse_args() + print(args) + + data = pd.read_csv(args.meta_csv) + splits = split_csv(data, args.split_csv) + + for split_name, split in splits.items(): + print(f"{split_name[0].upper()}{split_name[1:]}:") + + n_studies = len(split) + print(f" - Studies: {n_studies}") + n_stress = len(split[split[headers.protocol] == "Stress"]) + n_rest = len(split[split[headers.protocol] == "Rest"]) + print(f" - Stress: {n_stress}") + print(f" - Rest: {n_rest}") + n_patients = len(split[headers.patient_id].unique()) + _group = split.groupby(headers.patient_id).count()[headers.id] + n_studies_one = (_group == 1).sum() + n_studies_two = (_group == 2).sum() + n_studies_three = (_group == 3).sum() + print( + f" - Patients: {n_patients} [{n_studies_one}, {n_studies_two}, {n_studies_three}]" + ) + _split = split.drop_duplicates(headers.patient_id) + n_males = len(_split[_split[headers.sex] == "M"]) + n_females = len(_split[_split[headers.sex] == "F"]) + print(f" - M: {n_males}") + print(f" - F: {n_females}") + + age = split[headers.age] + height = split[headers.size] + weight = split[headers.weight] + bmi = body_mass_index(weight, height) + + age = age.astype(int) + weight = weight.astype(int) + height = np.round(height * 100).astype(int) + bmi = np.round(bmi, 2) + + for stat, label in [ + (age, "Age"), + (weight, "Weight"), + (height, "Height"), + (bmi, "BMI"), + ]: + _min = f"{str(stat.min()):>5}" + _max = f"{str(stat.max()):>5}" + _mean = f"{stat.mean():.1f}" + _std = f"{stat.std():.1f}" + print(f" - {label:>10}: [{_min}, {_max}] - {_mean:>5}±{_std:>5}")