import torch

from .data.preprocessing import *

# Sanity check for the Gaussian normalizers: draw samples from a normal
# distribution with mean 5 and std 10, then print summary statistics before
# and after normalization. The functional form (norm_gaussian) and the module
# form (GaussianNorm) are both exercised so their outputs can be compared.
# NOTE(review): presumably both normalizers standardize to roughly zero mean /
# unit std — confirm against .data.preprocessing. No RNG seed is set, so the
# printed numbers differ from run to run.
_SHAPE = (10, 10, 10)
means = torch.full(_SHAPE, 5.0)
stds = torch.full_like(means, 10.0)
x = torch.normal(mean=means, std=stds)

# The leading space in " After" lines them up under "Before".
print(f"Before: mean={x.mean():.3f} std={x.std():.3f}")

# Functional API.
y = norm_gaussian(x)
print(f" After: mean={y.mean():.3f} std={y.std():.3f}")

# Module API: build the normalizer first, then apply it to the same input.
gaussian_norm = GaussianNorm()
y = gaussian_norm(x)
print(f" After: mean={y.mean():.3f} std={y.std():.3f}")