Skip to content
Snippets Groups Projects
Commit 2ae5e9e3 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

restructuring to separate data processing from torch dataset implementations

parent 96e9d37b
No related branches found
No related tags found
No related merge requests found
import torch
def norm_max(tensor: torch.Tensor):
return (tensor - tensor.min()) / (tensor.max() - tensor.min())
class MaxNorm:
def __call__(self, tensor: torch.Tensor):
return norm_max(tensor)
def norm_mean(tensor: torch.Tensor):
return tensor / tensor.mean()
class MeanNorm:
def __call__(self, tensor: torch.Tensor):
return norm_mean(tensor)
def norm_gaussian(tensor: torch.Tensor):
return (tensor - tensor.mean()) / tensor.std()
class GaussianNorm:
def __call__(self, tensor: torch.Tensor):
return norm_gaussian(tensor)
__all__ = [
norm_max.__name__,
norm_mean.__name__,
norm_gaussian.__name__,
MaxNorm.__name__,
MeanNorm.__name__,
GaussianNorm.__name__,
]
File moved
from mu_map.data.datasets import MuMapDataset from mu_map.dataset.default import MuMapDataset
class MuMapMockDataset(MuMapDataset): class MuMapMockDataset(MuMapDataset):
......
from torch import Tensor
from mu_map.dataset.transform import Transform
def norm_max(tensor: Tensor) -> Tensor:
return (tensor - tensor.min()) / (tensor.max() - tensor.min())
class MaxNormTransform(Transform):
def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
return norm_max(inputs), outputs_expected
def norm_mean(tensor: Tensor):
return tensor / tensor.mean()
class MeanNormTransform(Transform):
def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
return norm_mean(inputs), outputs_expected
def norm_gaussian(tensor: Tensor):
return (tensor - tensor.mean()) / tensor.std()
class GaussianNormTransform(Transform):
def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
return norm_gaussian(inputs), outputs_expected
__all__ = [
norm_max.__name__,
norm_mean.__name__,
norm_gaussian.__name__,
MaxNormTransform.__name__,
MeanNormTransform.__name__,
GaussianNormTransform.__name__,
]
...@@ -4,7 +4,7 @@ import random ...@@ -4,7 +4,7 @@ import random
import numpy as np import numpy as np
import torch import torch
from mu_map.data.datasets import MuMapDataset from mu_map.dataset.default import MuMapDataset
class MuMapPatchDataset(MuMapDataset): class MuMapPatchDataset(MuMapDataset):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment