diff --git a/mu_map/test.py b/mu_map/test.py
index fb197707dafb8ef048a3cad6b2d7de1cdc9b0526..1da35dfb4f733be69e8ff23f51d1fbc744d4b9bc 100644
--- a/mu_map/test.py
+++ b/mu_map/test.py
@@ -14,5 +14,13 @@ y = GaussianNorm()(x)
 print(f" After: mean={y.mean():.3f} std={y.std():.3f}")
 
 
-
+import cv2 as cv
+import numpy as np
+
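+# Show a blank 512x512 image and print the code of every pressed key until "q" quits.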
+x = np.zeros((512, 512), np.uint8)
+cv.imshow("X", x)
+key = cv.waitKey(0)
+while key != ord("q"):
+    print(key)
+    key = cv.waitKey(0)
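+
+# Close the preview window once "q" was pressed.
+cv.destroyAllWindows()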
 
diff --git a/mu_map/train.py b/mu_map/train.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..a1143a864214e3f6ba1e2055e8c40058ba42be11 100644
--- a/mu_map/train.py
+++ b/mu_map/train.py
@@ -0,0 +1,216 @@
+import logging
+import logging.handlers
+import os
+import time
+
+import torch
+from torchvision import transforms
+
+# NOTE: CSVDataset, IlluminationChannelDataset and load_model are project
+# helpers; their imports are not shown in this diff.
+
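+# NOTE: the argparse parser definition is not part of this diff; it is assumed
+# to provide the options referenced below (logfile, snapshot_dir, model, epochs, ...).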
+args = parser.parse_args()
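+# Map each channel_mapping entry to a fixed channel index (0, 1, 2).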
+args.channel_mapping = dict(zip(args.channel_mapping, [0, 1, 2]))
+
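+# Log to the console and, if a logfile is given, to that file as well.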
+log_formatter = logging.Formatter(
+    fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p"
+)
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+log_handler_console = logging.StreamHandler()
+log_handler_console.setFormatter(log_formatter)
+logger.addHandler(log_handler_console)
+
+if args.logfile:
+    log_handler_file = logging.handlers.WatchedFileHandler(args.logfile)
+    log_handler_file.setFormatter(log_formatter)
+    logger.addHandler(log_handler_file)
+
+
+# Setup snapshot directory
+if os.path.exists(args.snapshot_dir):
+    if not os.path.isdir(args.snapshot_dir):
+        raise ValueError(
+            "Snapshot directory[%s] is not a directory!" % args.snapshot_dir
+        )
+else:
+    os.mkdir(args.snapshot_dir)
+
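+# Seed PyTorch's RNG with the current time and log the seed so the run can be reproduced.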
+random_seed = int(time.time())
+logger.info(f"Seed RNG: {random_seed}")
+
+torch.manual_seed(random_seed)
+
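+# Instantiate the model on the selected device from args.model and args.weights.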
+device = torch.device(args.device)
+model = load_model(args.model, args.weights, device)
+
+optimizer = model.init_optimizer()
+lr_scheduler = torch.optim.lr_scheduler.StepLR(
+    optimizer, args.decay_epoch, args.decay_rate
+)
+
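+# Only the training split is augmented (random flips). Depending on
+# args.illuminations_to_channel, illuminations are either picked per item
+# (CSVDataset) or mapped to separate input channels (IlluminationChannelDataset).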
+transforms_train = [transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()]
+datasets = {}
+datasets["train"] = (
+    CSVDataset(
+        os.path.join(args.csv_dir, "training.csv"),
+        args.dataset_dir,
+        model,
+        augments=transforms_train,
+        illuminations=args.illuminations,
+        random_illumination=True,
+    )
+    if not args.illuminations_to_channel
+    else IlluminationChannelDataset(
+        os.path.join(args.csv_dir, "training.csv"),
+        args.dataset_dir,
+        model,
+        augments=transforms_train,
+        channel_mapping=args.channel_mapping,
+    )
+)
+datasets["val"] = (
+    CSVDataset(
+        os.path.join(args.csv_dir, "validation.csv"),
+        args.dataset_dir,
+        model,
+        illuminations=args.illuminations,
+        random_illumination=False,
+    )
+    if not args.illuminations_to_channel
+    else IlluminationChannelDataset(
+        os.path.join(args.csv_dir, "validation.csv"),
+        args.dataset_dir,
+        model,
+        channel_mapping=args.channel_mapping,
+    )
+)
+
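+# Wrap the datasets in shuffling DataLoaders; pinned memory speeds up host-to-GPU transfers.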
+data_loaders = {}
+for key in datasets:
+    data_loaders[key] = torch.utils.data.DataLoader(
+        dataset=datasets[key],
+        batch_size=args.batch_size,
+        shuffle=True,
+        pin_memory=True,
+        num_workers=8,
+    )
+
+global_epoch = 0
+
+
+def run_epoch(model, optimizer, data_loader, phase="train"):
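+    """Run one epoch over data_loader and return (mean loss per sample, elapsed seconds)."""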
+    if phase == "train":
+        model.train()
+    else:
+        model.eval()
+
+    running_loss = 0
+    since = time.time()
+    for count, (inputs, labels) in enumerate(data_loader, start=1):
+        print("Iteration {}/{}".format(count, len(data_loader)), end="\r")
+
+        inputs = inputs.to(device)
+        if isinstance(labels, torch.Tensor):
+            labels = labels.to(device)
+        elif isinstance(labels, list):
+            labels = [label.to(device) for label in labels]
+
+        optimizer.zero_grad()
+
+        with torch.set_grad_enabled(phase == "train"):
+            output = model(inputs)
+            loss = model.compute_loss(output, labels)
+
+            if phase == "train":
+                loss.backward()
+                optimizer.step()
+
+        running_loss += loss.item() * inputs.size(0)
+
+    return running_loss / len(data_loader.dataset), time.time() - since
+
+
+try:
+    for epoch in range(1, args.epochs + 1):
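+        # Before the first training epoch, record baseline losses on both splits.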
+        if epoch == 1:
+            str_epoch = "Epoch {}/{}".format(0, args.epochs)
+
+            epoch_loss, epoch_time = run_epoch(
+                model, optimizer, data_loaders["val"], phase="val"
+            )
+
+            str_time = "Time {}s".format(round(epoch_time))
+            str_loss = "Loss {:.4f}".format(epoch_loss)
+
+            print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
+            logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)
+
+            epoch_loss, epoch_time = run_epoch(
+                model, optimizer, data_loaders["train"], phase="val"
+            )
+
+            str_time = "Time {}s".format(round(epoch_time))
+            str_loss = "Loss {:.4f}".format(epoch_loss)
+
+            print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
+            logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
+
+        global_epoch = epoch
+        str_epoch = "Epoch {}/{}".format(epoch, args.epochs)
+
+        epoch_loss, epoch_time = run_epoch(model, optimizer, data_loaders["train"])
+
+        str_time = "Time {}s".format(round(epoch_time))
+        str_loss = "Loss {:.4f}".format(epoch_loss)
+
+        print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
+        logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
+
+        # create snapshot
+        if epoch % args.snapshot_epoch == 0:
+            snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
+            snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
+            print("Create Snapshot[{}]".format(snapshot_file))
+            logger.info("Create Snapshot[{}]".format(snapshot_file))
+            torch.save(model.state_dict(), snapshot_file)
+
+        # compute validation loss
+        if epoch % args.val_epoch == 0:
+            print("VAL: " + str_epoch, end="\r")
+
+            epoch_loss, epoch_time = run_epoch(
+                model, optimizer, data_loaders["val"], phase="val"
+            )
+
+            str_time = "Time {}s".format(round(epoch_time))
+            str_loss = "Loss {:.4f}".format(epoch_loss)
+
+            print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
+            logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)
+
+        lr_scheduler.step()
+except KeyboardInterrupt:
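+    # On manual interruption, save the current weights before exiting.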
+    snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
+    snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
+    print("Create Snapshot[{}]".format(snapshot_file))
+    torch.save(model.state_dict(), snapshot_file)
+
diff --git a/mu_map/training/default.py b/mu_map/training/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..15330982e772dcb3298f0a99a6b8e77adebc0a97
--- /dev/null
+++ b/mu_map/training/default.py
@@ -0,0 +1,44 @@
+import torch
+
+
+class Training:
+    def __init__(self, model, data_loader, optimizer, lr_scheduler,
+                 loss, epochs, snapshot_epoch, device):
+        # All training dependencies are injected so the loop stays model agnostic.
+        self.model = model
+        self.data_loader = data_loader
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.loss = loss
+        self.epochs = epochs
+        self.snapshot_epoch = snapshot_epoch
+        self.device = device
+
+    def run(self):
+        for epoch in range(1, self.epochs + 1):
+            self.run_epoch(self.data_loader["train"], phase="train")
+            loss_training = self.run_epoch(self.data_loader["train"], phase="eval")
+            loss_validation = self.run_epoch(self.data_loader["validation"], phase="eval")
+
+            # ToDo: log outputs and time
+            self.lr_scheduler.step()
+
+            # Store a snapshot every `snapshot_epoch` epochs.
+            if epoch % self.snapshot_epoch == 0:
+                self.store_snapshot(epoch)
+
+    def run_epoch(self, data_loader, phase):
+        if phase == "train":
+            self.model.train()
+        else:
+            self.model.eval()
+
+        epoch_loss = 0
+        for inputs, labels in data_loader:
+            inputs = inputs.to(self.device)
+            labels = labels.to(self.device)
+
+            self.optimizer.zero_grad()
+            with torch.set_grad_enabled(phase == "train"):
+                outputs = self.model(inputs)
+                loss = self.loss(outputs, labels)
+
+                if phase == "train":
+                    loss.backward()
+                    self.optimizer.step()
+
+            # Weight each batch loss by its size so the result is the mean loss per sample.
+            epoch_loss += loss.item() * inputs.size(0)
+        return epoch_loss / len(data_loader.dataset)
+
+
+    def store_snapshot(self, epoch):
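+        # ToDo: persist the model's state dict (e.g. via torch.save)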
+        pass