diff --git a/mu_map/models/unet.py b/mu_map/models/unet.py
index e6e0bbc2ea46910c8be210563d91b52f0396e161..9a4d5c875b4ecfa2678d358a3504db084dcac1a6 100644
--- a/mu_map/models/unet.py
+++ b/mu_map/models/unet.py
@@ -154,10 +154,30 @@ class UNet(nn.Module):
 if __name__ == "__main__":
     import torch
 
-    net = UNet(features=[64, 128, 256])
+    net = UNet(features=[64, 128, 256, 512])
     print(net)
 
-    _inputs = torch.rand((1, 1, 64, 64, 64))
+    _inputs = torch.rand((1, 1, 64, 128, 128))
     _outputs = net(_inputs)
 
     print(f"Transform {_inputs.shape} to {_outputs.shape}")
+
+    import time
+    
+    device = torch.device("cuda")
+    net = net.to(device)
+    iterations = 100
+    
+    for batch_size in range(1, 17):
+        since = time.time()
+        for i in range(iterations):
+            print(f"{str(batch_size):>2}/17 - {str(i+1):>3}/{iterations}", end="\r")
+            _inputs = torch.rand((batch_size, 1, 64, 128, 128))
+            _inputs = _inputs.to(device)
+
+            _outputs = net(_inputs)
+        _took = time.time() - since
+        print(f"Batches of size {batch_size} take {_took:.3f}s on average")
+    
+
+