import torch

from .data.preprocessing import *

# Sanity check for the Gaussian normalisation helpers imported from
# .data.preprocessing: draw a (10, 10, 10) sample whose statistics are far
# from (0, 1), then print mean/std before and after each normaliser.
means = torch.full((10, 10, 10), 5.0)  # every element drawn with mean 5.0
stds = torch.full((10, 10, 10), 10.0)  # ...and standard deviation 10.0
x = torch.normal(means, stds)

print(f"Before: mean={x.mean():.3f} std={x.std():.3f}")

# Exercise the functional helper first, then the module-style wrapper.
# GaussianNorm is only instantiated when its turn comes (via the lambda), so
# the call order matches calling each one by hand. After the loop, y holds
# the GaussianNorm result.
# NOTE(review): both are presumably meant to standardise to mean ~0 / std ~1;
# their definitions are not visible here — confirm in .data.preprocessing.
for normalize in (norm_gaussian, lambda t: GaussianNorm()(t)):
    y = normalize(x)
    print(f" After: mean={y.mean():.3f} std={y.std():.3f}")


import cv2 as cv
import numpy as np

# Interactive key-code probe: show a blank 512x512 window titled "X" and print
# the code cv.waitKey returns for every key pressed, until 'q' is pressed.
x = np.zeros((512, 512), np.uint8)  # black single-channel 8-bit image
cv.imshow("X", x)
try:
    while True:
        # delay=0 makes waitKey block until a key event arrives.
        key = cv.waitKey(0)
        if key == ord("q"):
            break
        # With an infinite delay, -1 means no key was delivered. On some
        # HighGUI backends this is what waitKey returns once the window has
        # been closed; looping again would print -1 forever, so stop instead.
        if key == -1:
            break
        print(key)
finally:
    # Always release the HighGUI window, even if the loop is interrupted
    # (e.g. Ctrl+C); the original never destroyed it.
    cv.destroyAllWindows()