diff --git a/mu_map/eval/dataset_stats.py b/mu_map/eval/dataset_stats.py
index b58bef7d1d664c5e2be34391456bf32d077f6d37..000361b9d2676d3d2a8d3bf15b766d0a901ed2d8 100644
--- a/mu_map/eval/dataset_stats.py
+++ b/mu_map/eval/dataset_stats.py
@@ -1,4 +1,7 @@
+from typing import Dict
+
 import numpy as np
+import pandas as pd
 
 from mu_map.data.prepare import headers
 from mu_map.data.split import split_csv
@@ -8,9 +11,46 @@ def body_mass_index(weight: np.ndarray, size: np.ndarray) -> np.ndarray:
     return weight / (size**2)
 
 
+def evaluate_protocols(data: pd.DataFrame) -> Dict[str, int]:
+    """
+    Count how many patients were scanned with each sequence of protocols.
+
+    The scans of every patient are ordered by acquisition time and their
+    distinct, lower-cased protocol names are joined with "/", e.g.
+    "stress/rest". The data frame passed in is left unmodified.
+
+    Parameters
+    ----------
+    data: pd.DataFrame
+        meta data containing the columns `headers.patient_id`,
+        `headers.datetime_acquisition` and `headers.protocol`
+
+    Returns
+    -------
+    Dict[str, int]
+        number of patients per protocol sequence; the four expected
+        sequences are always present, any other sequence is added on demand
+    """
+    columns = [headers.patient_id, headers.datetime_acquisition, headers.protocol]
+    # parse the timestamps on a copy so that the caller keeps its raw strings
+    scans = data[columns].copy()
+    scans[headers.datetime_acquisition] = pd.to_datetime(
+        scans[headers.datetime_acquisition]
+    )
+    scans = scans.sort_values(by=headers.datetime_acquisition, kind="stable")
+
+    protocols = {"stress": 0, "rest": 0, "stress/rest": 0, "rest/stress": 0}
+    # groupby keeps the chronological row order inside every patient's group
+    for _, patient_scans in scans.groupby(headers.patient_id, sort=False):
+        sequence = patient_scans[headers.protocol].str.lower().unique()
+        protocol = "/".join(sequence)
+        # count unexpected sequences as well instead of failing with a KeyError
+        protocols[protocol] = protocols.get(protocol, 0) + 1
+    return protocols
+
+
 if __name__ == "__main__":
     import argparse
-    import pandas as pd
 
     parser = argparse.ArgumentParser(
         description="Evaluate dataset patient stats.",
@@ -23,11 +63,15 @@ if __name__ == "__main__":
     print()
 
     data = pd.read_csv(args.meta_csv)
-    splits = split_csv(data, args.split_csv)
-
     _from = data[headers.datetime_acquisition].min().split(" ")[0]
     _to = data[headers.datetime_acquisition].max().split(" ")[0]
     print(f"Scans were performed from {_from} to {_to}")
 
+    protocols = evaluate_protocols(data)
+    print("Protocols:")
+    for protocol, count in protocols.items():
+        print(f" - {protocol:>12}: {str(count):>3}")
+    print()
+    splits = split_csv(data, args.split_csv)
     for split_name, split in splits.items():
         print(f"{split_name[0].upper()}{split_name[1:]}:")