diff --git a/mu_map/models/unet.py b/mu_map/models/unet.py index e6e0bbc2ea46910c8be210563d91b52f0396e161..9a4d5c875b4ecfa2678d358a3504db084dcac1a6 100644 --- a/mu_map/models/unet.py +++ b/mu_map/models/unet.py @@ -154,10 +154,30 @@ class UNet(nn.Module): if __name__ == "__main__": import torch - net = UNet(features=[64, 128, 256]) + net = UNet(features=[64, 128, 256, 512]) print(net) - _inputs = torch.rand((1, 1, 64, 64, 64)) + _inputs = torch.rand((1, 1, 64, 128, 128)) _outputs = net(_inputs) print(f"Transform {_inputs.shape} to {_outputs.shape}") + + import time + + device = torch.device("cuda") + net = net.to(device) + iterations = 100 + + for batch_size in range(1, 17): + since = time.time() + for i in range(iterations): + print(f"{str(batch_size):>2}/17 - {str(i+1):>3}/{iterations}", end="\r") + _inputs = torch.rand((batch_size, 1, 64, 128, 128)) + _inputs = _inputs.to(device) + + _outputs = net(_inputs) + _took = time.time() - since + print(f"Batches of size {batch_size} take {_took:.3f}s on average") + + +