"""Quick sanity check of the Gaussian normalisers from ``data.preprocessing``.

Samples a tensor from a normal distribution and prints its mean/std. It then
prints the mean/std of the result of ``norm_gaussian`` (function) and of
``GaussianNorm()`` (callable object) applied to the same sample.

This module uses a relative import, so it must be run as part of its package,
e.g. ``python -m <package>.<this_module>``, not as a standalone file path.
"""
import torch

# NOTE(review): only ``norm_gaussian`` and ``GaussianNorm`` are used below;
# consider importing them explicitly instead of using a star import.
from .data.preprocessing import *


def main(shape=(10, 10, 10), mean=5.0, std=10.0, seed=None):
    """Sample a Gaussian tensor and print its statistics before/after normalising.

    Args:
        shape: Shape of the sampled tensor.
        mean: Mean of the distribution every element is drawn from.
        std: Standard deviation of that distribution.
        seed: If not ``None``, draw from a dedicated ``torch.Generator``
            seeded with this value, so runs are reproducible without touching
            the global RNG. If ``None`` (default), the global RNG is used.
    """
    # float() so integer arguments still give floating tensors, which
    # torch.normal requires.
    means = torch.full(shape, float(mean))
    stds = torch.full(shape, float(std))

    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)
    x = torch.normal(means, stds, generator=generator)
    print(f"Before: mean={x.mean():.3f} std={x.std():.3f}")

    # Functional interface.
    y = norm_gaussian(x)
    print(f" After: mean={y.mean():.3f} std={y.std():.3f}")

    # Object interface. Presumably equivalent to norm_gaussian; compare
    # with the line above (confirm in data.preprocessing).
    y = GaussianNorm()(x)
    print(f" After: mean={y.mean():.3f} std={y.std():.3f}")


if __name__ == "__main__":
    main()