From baf323fd04d3f0d1dc54a0534746081e99f9f4cc Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 22 Dec 2022 14:10:28 +0100
Subject: [PATCH] implement transformer between numpy and torch format

---
 mu_map/dataset/transform.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py
index e2fad43..1d41614 100644
--- a/mu_map/dataset/transform.py
+++ b/mu_map/dataset/transform.py
@@ -1,6 +1,7 @@
 import math
 from typing import List, Tuple
 
+import numpy as np
 import torch
 from torch import Tensor
 
@@ -126,6 +127,20 @@ class PadCropTranform(SequenceTransform):
         )
 
 
+class ToNumpy():
+    """
+    Special transformer that converts torch tensors to numpy array.
+    """
+    def __call__(self, *tensors: Tensor) -> Tuple[np.ndarray, ...]:
+        return tuple(map(lambda tensor: tensor.numpy(), tensors))
+
+class ToTensor():
+    """
+    Special transformer that converts numpy arrays to torch tensors.
+    """
+    def __call__(self, *tensors: np.ndarray) -> Tuple[Tensor, ...]:
+        return tuple(map(lambda tensor: torch.from_numpy(tensor), tensors))
+
 if __name__ == "__main__":
     transform = PadCropTranform(dim=3, size=32)
 
-- 
GitLab