From 6d5bbf9c49d8a98918158d1911c6ac1c3cb72fff Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Mon, 26 Sep 2022 09:30:47 +0200
Subject: [PATCH] test speed and max batch size of UNet model

---
 mu_map/models/unet.py | 24 ++++++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/mu_map/models/unet.py b/mu_map/models/unet.py
index e6e0bbc..9a4d5c8 100644
--- a/mu_map/models/unet.py
+++ b/mu_map/models/unet.py
@@ -154,10 +154,30 @@ class UNet(nn.Module):
 if __name__ == "__main__":
     import torch
 
-    net = UNet(features=[64, 128, 256])
+    net = UNet(features=[64, 128, 256, 512])
     print(net)
 
-    _inputs = torch.rand((1, 1, 64, 64, 64))
+    _inputs = torch.rand((1, 1, 64, 128, 128))
     _outputs = net(_inputs)
 
     print(f"Transform {_inputs.shape} to {_outputs.shape}")
+
+    import time
+    
+    device = torch.device("cuda")
+    net = net.to(device)
+    iterations = 100
+    
+    for batch_size in range(1, 17):
+        since = time.time()
+        for i in range(iterations):
+            print(f"{str(batch_size):>2}/16 - {str(i+1):>3}/{iterations}", end="\r")
+            _inputs = torch.rand((batch_size, 1, 64, 128, 128))
+            _inputs = _inputs.to(device)
+
+            _outputs = net(_inputs)
+        _took = time.time() - since
+        print(f"Batches of size {batch_size} take {_took / iterations:.3f}s on average")
+    
+
+
-- 
GitLab