From 2ae5e9e386657c6ad2171bc4928143cd865c6ecf Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 27 Sep 2022 10:04:49 +0200 Subject: [PATCH] restructuring to separate data processing from torch dataset implementations --- mu_map/data/preprocessing.py | 38 ------------------ mu_map/dataset/__init__.py | 0 .../{data/datasets.py => dataset/default.py} | 0 mu_map/{data => dataset}/mock.py | 2 +- mu_map/dataset/normalization.py | 40 +++++++++++++++++++ .../patch_dataset.py => dataset/patches.py} | 2 +- 6 files changed, 42 insertions(+), 40 deletions(-) delete mode 100644 mu_map/data/preprocessing.py create mode 100644 mu_map/dataset/__init__.py rename mu_map/{data/datasets.py => dataset/default.py} (100%) rename mu_map/{data => dataset}/mock.py (98%) create mode 100644 mu_map/dataset/normalization.py rename mu_map/{data/patch_dataset.py => dataset/patches.py} (98%) diff --git a/mu_map/data/preprocessing.py b/mu_map/data/preprocessing.py deleted file mode 100644 index dc57c6f..0000000 --- a/mu_map/data/preprocessing.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch - - -def norm_max(tensor: torch.Tensor): - return (tensor - tensor.min()) / (tensor.max() - tensor.min()) - - -class MaxNorm: - def __call__(self, tensor: torch.Tensor): - return norm_max(tensor) - - -def norm_mean(tensor: torch.Tensor): - return tensor / tensor.mean() - - -class MeanNorm: - def __call__(self, tensor: torch.Tensor): - return norm_mean(tensor) - - -def norm_gaussian(tensor: torch.Tensor): - return (tensor - tensor.mean()) / tensor.std() - - -class GaussianNorm: - def __call__(self, tensor: torch.Tensor): - return norm_gaussian(tensor) - - -__all__ = [ - norm_max.__name__, - norm_mean.__name__, - norm_gaussian.__name__, - MaxNorm.__name__, - MeanNorm.__name__, - GaussianNorm.__name__, -] diff --git a/mu_map/dataset/__init__.py b/mu_map/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mu_map/data/datasets.py b/mu_map/dataset/default.py similarity index 100% rename from mu_map/data/datasets.py rename to mu_map/dataset/default.py diff --git a/mu_map/data/mock.py b/mu_map/dataset/mock.py similarity index 98% rename from mu_map/data/mock.py rename to mu_map/dataset/mock.py index 2bb2b9a..de564b0 100644 --- a/mu_map/data/mock.py +++ b/mu_map/dataset/mock.py @@ -1,4 +1,4 @@ -from mu_map.data.datasets import MuMapDataset +from mu_map.dataset.default import MuMapDataset class MuMapMockDataset(MuMapDataset): diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py new file mode 100644 index 0000000..85ec9af --- /dev/null +++ b/mu_map/dataset/normalization.py @@ -0,0 +1,40 @@ +from torch import Tensor + +from mu_map.dataset.transform import Transform + + +def norm_max(tensor: Tensor) -> Tensor: + return (tensor - tensor.min()) / (tensor.max() - tensor.min()) + + +class MaxNormTransform(Transform): + def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]: + return norm_max(inputs), outputs_expected + + +def norm_mean(tensor: Tensor): + return tensor / tensor.mean() + + +class MeanNormTransform(Transform): + def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]: + return norm_mean(inputs), outputs_expected + + +def norm_gaussian(tensor: Tensor): + return (tensor - tensor.mean()) / tensor.std() + + +class GaussianNormTransform(Transform): + def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]: + return norm_gaussian(inputs), outputs_expected + + +__all__ = [ + norm_max.__name__, + norm_mean.__name__, + norm_gaussian.__name__, + MaxNormTransform.__name__, + MeanNormTransform.__name__, + GaussianNormTransform.__name__, +] diff --git a/mu_map/data/patch_dataset.py b/mu_map/dataset/patches.py similarity index 98% rename from mu_map/data/patch_dataset.py rename to mu_map/dataset/patches.py index 934ce82..df9f2c9 100644 --- a/mu_map/data/patch_dataset.py +++ b/mu_map/dataset/patches.py @@ -4,7 +4,7 @@ import random import numpy as np import torch -from mu_map.data.datasets import MuMapDataset +from mu_map.dataset.default import MuMapDataset class MuMapPatchDataset(MuMapDataset): -- GitLab