From 6442ed6870eb73159692078c41f158acbab4ea88 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 27 Sep 2022 15:16:35 +0200 Subject: [PATCH] splits are now parsed as dicts --- mu_map/data/split.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mu_map/data/split.py b/mu_map/data/split.py index b21dda7..a685e07 100644 --- a/mu_map/data/split.py +++ b/mu_map/data/split.py @@ -1,6 +1,6 @@ import pandas as pd -from typing import List +from typing import Dict, List def parse_split_str(_str: str, delimitier: str = "/") -> List[float]: @@ -23,7 +23,7 @@ def parse_split_str(_str: str, delimitier: str = "/") -> List[float]: return split -def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]: +def split_csv(data: pd.DataFrame, split_csv: str) -> Dict[str, pd.DataFrame]: """ Split a data frames based on a file defining a split. @@ -40,7 +40,9 @@ def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]: lambda patient_ids: data[data["patient_id"].isin(patient_ids)], split_patient_ids, ) - return list(splits) + splits = zip(split_names, splits) + splits = dict(splits) + return splits if __name__ == "__main__": -- GitLab