Skip to content
Snippets Groups Projects
Commit 91b8bd38 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files
parents d961d0a5 db4917fa
No related branches found
No related tags found
No related merge requests found
......@@ -155,10 +155,30 @@ class UNet(nn.Module):
if __name__ == "__main__":
import torch
net = UNet(features=[64, 128, 256])
net = UNet(features=[64, 128, 256, 512])
print(net)
_inputs = torch.rand((1, 1, 64, 64, 64))
_inputs = torch.rand((1, 1, 64, 128, 128))
_outputs = net(_inputs)
print(f"Transform {_inputs.shape} to {_outputs.shape}")
import time
device = torch.device("cuda")
net = net.to(device)
iterations = 100
for batch_size in range(1, 17):
since = time.time()
for i in range(iterations):
print(f"{str(batch_size):>2}/17 - {str(i+1):>3}/{iterations}", end="\r")
_inputs = torch.rand((batch_size, 1, 64, 128, 128))
_inputs = _inputs.to(device)
_outputs = net(_inputs)
_took = time.time() - since
print(f"Batches of size {batch_size} take {_took:.3f}s on average")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment