From 6d5bbf9c49d8a98918158d1911c6ac1c3cb72fff Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Mon, 26 Sep 2022 09:30:47 +0200 Subject: [PATCH] test speed and max batch size of UNet model --- mu_map/models/unet.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/mu_map/models/unet.py b/mu_map/models/unet.py index e6e0bbc..9a4d5c8 100644 --- a/mu_map/models/unet.py +++ b/mu_map/models/unet.py @@ -154,10 +154,30 @@ class UNet(nn.Module): if __name__ == "__main__": import torch - net = UNet(features=[64, 128, 256]) + net = UNet(features=[64, 128, 256, 512]) print(net) - _inputs = torch.rand((1, 1, 64, 64, 64)) + _inputs = torch.rand((1, 1, 64, 128, 128)) _outputs = net(_inputs) print(f"Transform {_inputs.shape} to {_outputs.shape}") + + import time + + device = torch.device("cuda") + net = net.to(device) + iterations = 100 + + for batch_size in range(1, 17): + since = time.time() + for i in range(iterations): + print(f"{str(batch_size):>2}/17 - {str(i+1):>3}/{iterations}", end="\r") + _inputs = torch.rand((batch_size, 1, 64, 128, 128)) + _inputs = _inputs.to(device) + + _outputs = net(_inputs) + _took = time.time() - since + print(f"Batches of size {batch_size} take {_took:.3f}s on average") + + + -- GitLab