# Newer
# Older
import torch
from .data.preprocessing import *
# Sanity-check demo: draw Gaussian samples with a known mean (5.0) and std
# (10.0), run them through both the functional (`norm_gaussian`) and module
# (`GaussianNorm`) normalizers from the star import above, and print summary
# statistics before and after each.
# NOTE(review): this file uses a relative import, so it presumably must be run
# as a module (`python -m <package>.<this_module>`), not as a plain file path.


def _report(label, t):
    """Print ``label`` followed by the mean and std of tensor ``t`` (3 d.p.)."""
    # Tensor.std() applies Bessel's correction by default (sample std).
    print(f"{label}: mean={t.mean():.3f} std={t.std():.3f}")


# Fixed seed so repeated runs print the same numbers.
torch.manual_seed(0)

means = torch.full((10, 10, 10), 5.0)
stds = torch.full((10, 10, 10), 10.0)
x = torch.normal(means, stds)

_report("Before", x)
# Label each result so the two normalizers' outputs can be told apart.
y = norm_gaussian(x)
_report("After (norm_gaussian)", y)
y = GaussianNorm()(x)
_report("After (GaussianNorm)", y)