diff --git a/mu_map/data/split.py b/mu_map/data/split.py index b21dda72dc67fa462f3409aedf8108a3a205ea44..a685e07a15497ef492180fa2e08f36cb69a210c9 100644 --- a/mu_map/data/split.py +++ b/mu_map/data/split.py @@ -1,6 +1,6 @@ import pandas as pd -from typing import List +from typing import Dict, List def parse_split_str(_str: str, delimitier: str = "/") -> List[float]: @@ -23,7 +23,7 @@ def parse_split_str(_str: str, delimitier: str = "/") -> List[float]: return split -def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]: +def split_csv(data: pd.DataFrame, split_csv: str) -> Dict[str, pd.DataFrame]: """ Split a data frames based on a file defining a split. @@ -40,7 +40,9 @@ def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]: lambda patient_ids: data[data["patient_id"].isin(patient_ids)], split_patient_ids, ) - return list(splits) + splits = zip(split_names, splits) + splits = dict(splits) + return splits if __name__ == "__main__":