diff --git a/mu_map/test.py b/mu_map/test.py
index fb197707dafb8ef048a3cad6b2d7de1cdc9b0526..1da35dfb4f733be69e8ff23f51d1fbc744d4b9bc 100644
--- a/mu_map/test.py
+++ b/mu_map/test.py
@@ -14,5 +14,13 @@
 y = GaussianNorm()(x)
 print(f" After: mean={y.mean():.3f} std={y.std():.3f}")
-
+import cv2 as cv
+import numpy as np
+
+x = np.zeros((512, 512), np.uint8)
+cv.imshow("X", x)
+key = cv.waitKey(0)
+while key != ord("q"):
+    print(key)
+    key = cv.waitKey(0)
diff --git a/mu_map/train.py b/mu_map/train.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..a1143a864214e3f6ba1e2055e8c40058ba42be11 100644
--- a/mu_map/train.py
+++ b/mu_map/train.py
@@ -0,0 +1,216 @@
+import logging
+import logging.handlers
+
+log_formatter = logging.Formatter(
+    fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p"
+)
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+log_handler_console = logging.StreamHandler()
+log_handler_console.setFormatter(log_formatter)
+logger.addHandler(log_handler_console)
+
+logfile = "train.log"
+log_handler_file = logging.handlers.WatchedFileHandler(logfile)
+log_handler_file.setFormatter(log_formatter)
+logger.addHandler(log_handler_file)
+
+logger.info("This is a test!")
+logger.warning("The training loss is over 9000!")
+logger.error("This is an error")
+logger.info("The end")
+
+args = parser.parse_args()
+args.channel_mapping = dict(zip(args.channel_mapping, [0, 1, 2]))
+
+log_formatter = logging.Formatter(
+    fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p"
+)
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+log_handler_console = logging.StreamHandler()
+log_handler_console.setFormatter(log_formatter)
+logger.addHandler(log_handler_console)
+
+if args.logfile:
+    log_handler_file = logging.handlers.WatchedFileHandler(args.logfile)
+    log_handler_file.setFormatter(log_formatter)
+    logger.addHandler(log_handler_file)
+
+
+# Setup snapshot directory
+if os.path.exists(args.snapshot_dir):
+    if not os.path.isdir(args.snapshot_dir):
+        raise ValueError(
+            "Snapshot directory[%s] is not a directory!" % args.snapshot_dir
+        )
+else:
+    os.mkdir(args.snapshot_dir)
+
+random_seed = int(time.time())
+logger.info(f"Seed RNG: {random_seed}")
+
+torch.manual_seed(random_seed)
+
+device = torch.device(args.device)
+model = load_model(args.model, args.weights, device)
+
+params = model.parameters()
+optimizer = model.init_optimizer()
+lr_scheduler = torch.optim.lr_scheduler.StepLR(
+    optimizer, args.decay_epoch, args.decay_rate
+)
+
+transforms_train = [transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()]
+datasets = {}
+datasets["train"] = (
+    CSVDataset(
+        os.path.join(args.csv_dir, "training.csv"),
+        args.dataset_dir,
+        model,
+        augments=transforms_train,
+        illuminations=args.illuminations,
+        random_illumination=True,
+    )
+    if not args.illuminations_to_channel
+    else IlluminationChannelDataset(
+        os.path.join(args.csv_dir, "training.csv"),
+        args.dataset_dir,
+        model,
+        augments=transforms_train,
+        channel_mapping=args.channel_mapping,
+    )
+)
+datasets["val"] = (
+    CSVDataset(
+        os.path.join(args.csv_dir, "validation.csv"),
+        args.dataset_dir,
+        model,
+        illuminations=args.illuminations,
+        random_illumination=False,
+    )
+    if not args.illuminations_to_channel
+    else IlluminationChannelDataset(
+        os.path.join(args.csv_dir, "validation.csv"),
+        args.dataset_dir,
+        model,
+        channel_mapping=args.channel_mapping,
+    )
+)
+
+data_loaders = {}
+for key in datasets:
+    data_loaders[key] = torch.utils.data.DataLoader(
+        dataset=datasets[key],
+        batch_size=args.batch_size,
+        shuffle=True,
+        pin_memory=True,
+        num_workers=8,
+    )
+
+global_epoch = 0
+
+
+def run_epoch(model, optimizer, data_loader, phase="train"):
+    if phase == "train":
+        model.train()
+    else:
+        model.eval()
+
+    running_loss = 0
+    since = time.time()
+    count = 1
+    for inputs, labels in data_loader:
+        print("Iteration {}/{}".format(count, len(data_loader)) + "\r", end="")
+        count += 1
+
+        inputs = inputs.to(device)
+        if type(labels) == torch.Tensor:
+            labels = labels.to(device)
+        elif type(labels) == list:
+            labels = [label.to(device) for label in labels]
+
+        optimizer.zero_grad()
+
+        with torch.set_grad_enabled(phase == "train"):
+            output = model(inputs)
+            loss = model.compute_loss(output, labels)
+
+            if phase == "train":
+                loss.backward()
+                optimizer.step()
+
+        running_loss += loss.item() * inputs.size(0)
+
+    return running_loss / len(data_loader.dataset), time.time() - since
+
+
+try:
+    for epoch in range(1, args.epochs + 1):
+        if epoch == 1:
+            str_epoch = "Epoch {}/{}".format(0, args.epochs)
+
+            epoch_loss, epoch_time = run_epoch(
+                model, optimizer, data_loaders["val"], phase="val"
+            )
+
+            str_time = "Time {}s".format(round(epoch_time))
+            str_loss = "Loss {:.4f}".format(epoch_loss)
+
+            print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
+            logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)
+
+            epoch_loss, epoch_time = run_epoch(
+                model, optimizer, data_loaders["train"], phase="val"
+            )
+
+            str_time = "Time {}s".format(round(epoch_time))
+            str_loss = "Loss {:.4f}".format(epoch_loss)
+
+            print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
+            logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
+
+        global_epoch = epoch
+        str_epoch = "Epoch {}/{}".format(epoch, args.epochs)
+
+        epoch_loss, epoch_time = run_epoch(model, optimizer, data_loaders["train"])
+
+        str_time = "Time {}s".format(round(epoch_time))
+        str_loss = "Loss {:.4f}".format(epoch_loss)
+
+        print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
+        logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
+
+        # create snapshot
+        if epoch % args.snapshot_epoch == 0:
+            snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
+            snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
+            print("Create Snapshot[{}]".format(snapshot_file))
+            logger.info("Create Snapshot[{}]".format(snapshot_file))
+            torch.save(model.state_dict(), snapshot_file)
+
+        # compute validation loss
+        if epoch % args.val_epoch == 0:
+            print("VAL: " + str_epoch, end="\r")
+
+            epoch_loss, epoch_time = run_epoch(
+                model, optimizer, data_loaders["val"], phase="val"
+            )
+
+            str_time = "Time {}s".format(round(epoch_time))
+            str_loss = "Loss {:.4f}".format(epoch_loss)
+
+            print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
+            logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)
+
+        lr_scheduler.step()
+except KeyboardInterrupt:
+    snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
+    snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
+    print("Create Snapshot[{}]".format(snapshot_file))
+    torch.save(model.state_dict(), snapshot_file)
+
diff --git a/mu_map/training/default.py b/mu_map/training/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..15330982e772dcb3298f0a99a6b8e77adebc0a97
--- /dev/null
+++ b/mu_map/training/default.py
@@ -0,0 +1,44 @@
+import torch
+
+
+class Training():
+
+    def __init__(self, epochs):
+        self.epochs = epochs
+
+    def run(self):
+        for epoch in range(1, self.epochs + 1):
+            self.run_epoch(self.data_loader["train"], phase="train")
+            loss_training = self.run_epoch(self.data_loader["train"], phase="eval")
+            loss_validation = self.run_epoch(self.data_loader["validation"], phase="eval")
+
+            # ToDo: log outputs and time
+            self.lr_scheduler.step()
+
+            if epoch % self.snapshot_epoch == 0:
+                self.store_snapshot(epoch)
+
+    def run_epoch(self, data_loader, phase):
+        # switch the model between training and evaluation mode
+        self.model.train() if phase == "train" else self.model.eval()
+
+        epoch_loss = 0
+        for inputs, labels in data_loader:
+            inputs = inputs.to(self.device)
+            labels = labels.to(self.device)
+
+            self.optimizer.zero_grad()
+            with torch.set_grad_enabled(phase == "train"):
+                outputs = self.model(inputs)
+                loss = self.loss(outputs, labels)
+
+                if phase == "train":
+                    loss.backward()
+                    self.optimizer.step()
+
+            epoch_loss += loss.item() * inputs.size(0)
+        return epoch_loss / len(data_loader.dataset)
+
+    def store_snapshot(self, epoch):
+        pass
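
Note: below is a minimal, hypothetical usage sketch for the new mu_map.training.default.Training class, not part of the diff. The committed __init__ only accepts `epochs`, so assigning model, loss, optimizer, lr_scheduler, device, snapshot_epoch and data_loader directly on the instance is an assumption about how the class is meant to be configured, and the linear model and random tensors are placeholders, not the project's real network or data.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from mu_map.training.default import Training

# Hypothetical wiring: these attributes are assumptions, not an interface
# defined in the diff.
trainer = Training(epochs=4)
trainer.device = torch.device("cpu")
trainer.model = nn.Linear(16, 1)  # placeholder model
trainer.loss = nn.MSELoss()
trainer.optimizer = torch.optim.Adam(trainer.model.parameters(), lr=1e-3)
trainer.lr_scheduler = torch.optim.lr_scheduler.StepLR(trainer.optimizer, step_size=2)
trainer.snapshot_epoch = 2

# Tiny synthetic dataset so the sketch runs end-to-end.
xs, ys = torch.randn(64, 16), torch.randn(64, 1)
loader = DataLoader(TensorDataset(xs, ys), batch_size=8)
trainer.data_loader = {"train": loader, "validation": loader}

trainer.run()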