diff --git a/mu_map/models/unet.py b/mu_map/models/unet.py index 2d0f6a536784f242e3724860cfe29023af6e5fa0..3836e936314bd73658247db3067ccac23b786d1d 100644 --- a/mu_map/models/unet.py +++ b/mu_map/models/unet.py @@ -169,11 +169,12 @@ if __name__ == "__main__": net = net.to(device) iterations = 100 - for batch_size in range(1, 17): + for batch_size in range(128, 129): since = time.time() for i in range(iterations): print(f"{str(batch_size):>2}/17 - {str(i+1):>3}/{iterations}", end="\r") - _inputs = torch.rand((batch_size, 1, 64, 128, 128)) + # _inputs = torch.rand((batch_size, 1, 64, 128, 128)) + _inputs = torch.rand((batch_size, 1, 32, 32, 32)) _inputs = _inputs.to(device) _outputs = net(_inputs)