From baf323fd04d3f0d1dc54a0534746081e99f9f4cc Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Thu, 22 Dec 2022 14:10:28 +0100 Subject: [PATCH] implement transformer between numpy and torch format --- mu_map/dataset/transform.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py index e2fad43..1d41614 100644 --- a/mu_map/dataset/transform.py +++ b/mu_map/dataset/transform.py @@ -1,6 +1,7 @@ import math from typing import List, Tuple +import numpy as np import torch from torch import Tensor @@ -126,6 +127,20 @@ class PadCropTranform(SequenceTransform): ) +class ToNumpy(): + """ + Special transformer that converts torch tensors to numpy array. + """ + def __call__(self, *tensors: Tensor) -> Tuple[np.ndarray, ...]: + return tuple(map(lambda tensor: tensor.numpy(), tensors)) + +class ToTensor(): + """ + Special transformer that converts numpy arrays to torch tensors. + """ + def __call__(self, *tensors: np.ndarray) -> Tuple[Tensor, ...]: + return tuple(map(lambda tensor: torch.from_numpy(tensor), tensors)) + if __name__ == "__main__": transform = PadCropTranform(dim=3, size=32) -- GitLab