From 2ae5e9e386657c6ad2171bc4928143cd865c6ecf Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 27 Sep 2022 10:04:49 +0200
Subject: [PATCH] restructuring to separate data processing from torch dataset
 implementations

---
 mu_map/data/preprocessing.py                  | 38 ------------------
 mu_map/dataset/__init__.py                    |  0
 .../{data/datasets.py => dataset/default.py}  |  0
 mu_map/{data => dataset}/mock.py              |  2 +-
 mu_map/dataset/normalization.py               | 40 +++++++++++++++++++
 .../patch_dataset.py => dataset/patches.py}   |  2 +-
 6 files changed, 42 insertions(+), 40 deletions(-)
 delete mode 100644 mu_map/data/preprocessing.py
 create mode 100644 mu_map/dataset/__init__.py
 rename mu_map/{data/datasets.py => dataset/default.py} (100%)
 rename mu_map/{data => dataset}/mock.py (98%)
 create mode 100644 mu_map/dataset/normalization.py
 rename mu_map/{data/patch_dataset.py => dataset/patches.py} (98%)

diff --git a/mu_map/data/preprocessing.py b/mu_map/data/preprocessing.py
deleted file mode 100644
index dc57c6f..0000000
--- a/mu_map/data/preprocessing.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-
-
-def norm_max(tensor: torch.Tensor):
-    return (tensor - tensor.min()) / (tensor.max() - tensor.min())
-
-
-class MaxNorm:
-    def __call__(self, tensor: torch.Tensor):
-        return norm_max(tensor)
-
-
-def norm_mean(tensor: torch.Tensor):
-    return tensor / tensor.mean()
-
-
-class MeanNorm:
-    def __call__(self, tensor: torch.Tensor):
-        return norm_mean(tensor)
-
-
-def norm_gaussian(tensor: torch.Tensor):
-    return (tensor - tensor.mean()) / tensor.std()
-
-
-class GaussianNorm:
-    def __call__(self, tensor: torch.Tensor):
-        return norm_gaussian(tensor)
-
-
-__all__ = [
-    norm_max.__name__,
-    norm_mean.__name__,
-    norm_gaussian.__name__,
-    MaxNorm.__name__,
-    MeanNorm.__name__,
-    GaussianNorm.__name__,
-]
diff --git a/mu_map/dataset/__init__.py b/mu_map/dataset/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mu_map/data/datasets.py b/mu_map/dataset/default.py
similarity index 100%
rename from mu_map/data/datasets.py
rename to mu_map/dataset/default.py
diff --git a/mu_map/data/mock.py b/mu_map/dataset/mock.py
similarity index 98%
rename from mu_map/data/mock.py
rename to mu_map/dataset/mock.py
index 2bb2b9a..de564b0 100644
--- a/mu_map/data/mock.py
+++ b/mu_map/dataset/mock.py
@@ -1,4 +1,4 @@
-from mu_map.data.datasets import MuMapDataset
+from mu_map.dataset.default import MuMapDataset
 
 
 class MuMapMockDataset(MuMapDataset):
diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py
new file mode 100644
index 0000000..85ec9af
--- /dev/null
+++ b/mu_map/dataset/normalization.py
@@ -0,0 +1,40 @@
+from torch import Tensor
+
+from mu_map.dataset.transform import Transform
+
+
+def norm_max(tensor: Tensor) -> Tensor:
+    return (tensor - tensor.min()) / (tensor.max() - tensor.min())
+
+
+class MaxNormTransform(Transform):
+    def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
+        return norm_max(inputs), outputs_expected
+
+
+def norm_mean(tensor: Tensor):
+    return tensor / tensor.mean()
+
+
+class MeanNormTransform(Transform):
+    def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
+        return norm_mean(inputs), outputs_expected
+
+
+def norm_gaussian(tensor: Tensor):
+    return (tensor - tensor.mean()) / tensor.std()
+
+
+class GaussianNormTransform(Transform):
+    def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
+        return norm_gaussian(inputs), outputs_expected
+
+
+__all__ = [
+    norm_max.__name__,
+    norm_mean.__name__,
+    norm_gaussian.__name__,
+    MaxNormTransform.__name__,
+    MeanNormTransform.__name__,
+    GaussianNormTransform.__name__,
+]
diff --git a/mu_map/data/patch_dataset.py b/mu_map/dataset/patches.py
similarity index 98%
rename from mu_map/data/patch_dataset.py
rename to mu_map/dataset/patches.py
index 934ce82..df9f2c9 100644
--- a/mu_map/data/patch_dataset.py
+++ b/mu_map/dataset/patches.py
@@ -4,7 +4,7 @@ import random
 import numpy as np
 import torch
 
-from mu_map.data.datasets import MuMapDataset
+from mu_map.dataset.default import MuMapDataset
 
 
 class MuMapPatchDataset(MuMapDataset):
-- 
GitLab