From 6442ed6870eb73159692078c41f158acbab4ea88 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 27 Sep 2022 15:16:35 +0200
Subject: [PATCH] splits are now parsed as dicts

---
 mu_map/data/split.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mu_map/data/split.py b/mu_map/data/split.py
index b21dda7..a685e07 100644
--- a/mu_map/data/split.py
+++ b/mu_map/data/split.py
@@ -1,6 +1,6 @@
 import pandas as pd
 
-from typing import List
+from typing import Dict, List
 
 
 def parse_split_str(_str: str, delimitier: str = "/") -> List[float]:
@@ -23,7 +23,7 @@ def parse_split_str(_str: str, delimitier: str = "/") -> List[float]:
     return split
 
 
-def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]:
+def split_csv(data: pd.DataFrame, split_csv: str) -> Dict[str, pd.DataFrame]:
     """
     Split a data frames based on a file defining a split.
 
@@ -40,7 +40,9 @@ def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]:
         lambda patient_ids: data[data["patient_id"].isin(patient_ids)],
         split_patient_ids,
     )
-    return list(splits)
+    splits = zip(split_names, splits)
+    splits = dict(splits)
+    return splits
 
 
 if __name__ == "__main__":
-- 
GitLab