Skip to content
Snippets Groups Projects
Commit 6d5bbf9c authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

test speed and max batch size of UNet model

parent 6d14f212
No related branches found
No related tags found
No related merge requests found
...@@ -154,10 +154,30 @@ class UNet(nn.Module): ...@@ -154,10 +154,30 @@ class UNet(nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
import torch import torch
net = UNet(features=[64, 128, 256]) net = UNet(features=[64, 128, 256, 512])
print(net) print(net)
_inputs = torch.rand((1, 1, 64, 64, 64)) _inputs = torch.rand((1, 1, 64, 128, 128))
_outputs = net(_inputs) _outputs = net(_inputs)
print(f"Transform {_inputs.shape} to {_outputs.shape}") print(f"Transform {_inputs.shape} to {_outputs.shape}")
import time
device = torch.device("cuda")
net = net.to(device)
iterations = 100
for batch_size in range(1, 17):
since = time.time()
for i in range(iterations):
print(f"{str(batch_size):>2}/17 - {str(i+1):>3}/{iterations}", end="\r")
_inputs = torch.rand((batch_size, 1, 64, 128, 128))
_inputs = _inputs.to(device)
_outputs = net(_inputs)
_took = time.time() - since
print(f"Batches of size {batch_size} take {_took:.3f}s on average")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment