From 7bf09ddd13a8e4ad6db15eaa5b1e8ae01d83285e Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 15 Dec 2022 10:11:02 +0100
Subject: [PATCH] add missing training lib module

---
 mu_map/training/lib.py | 175 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 175 insertions(+)
 create mode 100644 mu_map/training/lib.py

diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
new file mode 100644
index 0000000..00811cd
--- /dev/null
+++ b/mu_map/training/lib.py
@@ -0,0 +1,175 @@
+from dataclasses import dataclass
+import logging
+import os
+import sys
+from typing import Optional
+
+import torch
+from torch.utils.data import DataLoader
+
+from mu_map.dataset.default import MuMapDataset
+
+@dataclass
+class TrainingParams:
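+    """Bundle of a named model with its optimizer and optional LR scheduler."""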
+    name: str
+    model: torch.nn.Module
+    optimizer: torch.optim.Optimizer
+    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
+
+class AbstractTraining:
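+    """Abstract base class implementing a generic training loop.
+
+    Subclasses register their models as TrainingParams in
+    self.training_params and implement _train_batch and _eval_batch.
+    """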
+
+    def __init__(
+        self,
+        epochs: int,
+        dataset: MuMapDataset,
+        batch_size: int,
+        device: torch.device,
+        snapshot_dir: str,
+        snapshot_epoch: int,
+        logger=None,
+    ):
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.dataset = dataset
+        self.device = device
+
+        self.snapshot_dir = snapshot_dir
+        self.snapshot_epoch = snapshot_epoch
+
+        # Fall back to a module-level logger so logging calls never fail.
+        self.logger = logger if logger is not None else logging.getLogger(__name__)
+
+        self.training_params = []
+
+        # Assumption: the original patch never creates self.data_loaders,
+        # so one DataLoader per split is built from the given dataset here.
+        self.data_loaders = {
+            split: DataLoader(
+                dataset, batch_size=batch_size, shuffle=(split == "train")
+            )
+            for split in ["train", "validation"]
+        }
+
+    def run(self) -> float:
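+        """Run the full training and return the minimal validation loss."""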
+        loss_val_min = sys.maxsize
+        for epoch in range(1, self.epochs + 1):
+            str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
+            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
+
+            loss_train = self._train_epoch()
+            self.logger.info(
+                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
+            )
+            loss_val = self._eval_epoch()
+            self.logger.info(
+                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
+            )
+
+            if loss_val < loss_val_min:
+                loss_val_min = loss_val
+                self.logger.info(
+                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
+                )
+                self.store_snapshot("val_min")
+
+            if epoch % self.snapshot_epoch == 0:
+                self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")
+
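+            # Advance learning-rate schedulers once per epoch.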
+            for param in self.training_params:
+                if param.lr_scheduler is not None:
+                    param.lr_scheduler.step()
+        return loss_val_min
+
+    def _train_epoch(self) -> float:
+        """Train for one epoch and return the mean batch loss."""
+        torch.set_grad_enabled(True)
+        for param in self.training_params:
+            param.model.train()
+
+        loss = 0.0
+        data_loader = self.data_loaders["train"]
+        for i, (inputs, targets) in enumerate(data_loader):
+            print(
+                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
+                end="\r",
+            )
+
+            inputs = inputs.to(self.device)
+            targets = targets.to(self.device)
+
+            for param in self.training_params:
+                param.optimizer.zero_grad()
+
+            loss = loss + self._train_batch(inputs, targets)
+
+            for param in self.training_params:
+                param.optimizer.step()
+        return loss / len(data_loader)
+
+    def _eval_epoch(self) -> float:
+        """Evaluate all registered models on the validation split."""
+        torch.set_grad_enabled(False)
+        for param in self.training_params:
+            param.model.eval()
+
+        loss = 0.0
+        data_loader = self.data_loaders["validation"]
+        for i, (inputs, targets) in enumerate(data_loader):
+            print(
+                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
+                end="\r",
+            )
+
+            inputs = inputs.to(self.device)
+            targets = targets.to(self.device)
+
+            loss = loss + self._eval_batch(inputs, targets)
+        return loss / len(data_loader)
+
+    def store_snapshot(self, prefix: str):
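+        """Store a snapshot of each registered model under the given prefix."""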
+        for param in self.training_params:
+            snapshot_file = os.path.join(self.snapshot_dir, f"{prefix}_{param.name}.pth")
+            self.logger.debug(f"Store snapshot at {snapshot_file}")
+            torch.save(param.model.state_dict(), snapshot_file)
+
+    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
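+        # Subclasses implement this: compute the loss for the batch, call
+        # backward() on it, and return its value as a float.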
+        return 0
+
+    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
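+        # Subclasses implement this: compute and return the validation
+        # loss for the batch as a float.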
+        return 0
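+
+
+# A minimal sketch of a concrete training (hypothetical, not part of this
+# patch): register a single model and implement the two batch hooks.
+#
+#     class L1Training(AbstractTraining):
+#         def __init__(self, model, optimizer, **kwargs):
+#             super().__init__(**kwargs)
+#             self.training_params.append(
+#                 TrainingParams("model", model, optimizer, None)
+#             )
+#
+#         def _train_batch(self, inputs, targets):
+#             params = self.training_params[0]
+#             loss = torch.nn.functional.l1_loss(params.model(inputs), targets)
+#             loss.backward()
+#             return loss.item()
+#
+#         def _eval_batch(self, inputs, targets):
+#             params = self.training_params[0]
+#             outputs = params.model(inputs)
+#             return torch.nn.functional.l1_loss(outputs, targets).item()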
+
-- 
GitLab