diff --git a/mu_map/models/unet.py b/mu_map/models/unet.py
index 2d0f6a536784f242e3724860cfe29023af6e5fa0..3836e936314bd73658247db3067ccac23b786d1d 100644
--- a/mu_map/models/unet.py
+++ b/mu_map/models/unet.py
@@ -169,11 +169,12 @@ if __name__ == "__main__":
     net = net.to(device)
     iterations = 100
     
-    for batch_size in range(1, 17):
+    for batch_size in range(128, 129):
         since = time.time()
         for i in range(iterations):
             print(f"{str(batch_size):>2}/17 - {str(i+1):>3}/{iterations}", end="\r")
-            _inputs = torch.rand((batch_size, 1, 64, 128, 128))
+            # _inputs = torch.rand((batch_size, 1, 64, 128, 128))
+            _inputs = torch.rand((batch_size, 1, 32, 32, 32))
             _inputs = _inputs.to(device)
 
             _outputs = net(_inputs)