"""Training entry point.

Configures logging, prepares the snapshot directory, builds the datasets and
data loaders, then runs the train/validation loop with periodic snapshots and
a snapshot-on-Ctrl-C safety net.

NOTE(review): `parser`, `load_model`, `transforms`, `CSVDataset`, and
`IlluminationChannelDataset` are not defined in this chunk -- they are assumed
to be imported/defined earlier in the project. Confirm before running.
"""
import logging
import logging.handlers
import os
import time

import torch

args = parser.parse_args()
# Map the user-supplied illumination names onto fixed channel indices 0..2.
args.channel_mapping = dict(zip(args.channel_mapping, [0, 1, 2]))

# --- Logging -----------------------------------------------------------------
# Configure the root logger exactly once.  The original configured it twice
# (the first time with a hard-coded "train.log" and leftover debug messages),
# which attached duplicate console handlers and emitted every record twice.
log_formatter = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s: %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

log_handler_console = logging.StreamHandler()
log_handler_console.setFormatter(log_formatter)
logger.addHandler(log_handler_console)

if args.logfile:
    # WatchedFileHandler reopens the file if it is rotated/moved externally.
    log_handler_file = logging.handlers.WatchedFileHandler(args.logfile)
    log_handler_file.setFormatter(log_formatter)
    logger.addHandler(log_handler_file)

# --- Snapshot directory ------------------------------------------------------
if os.path.exists(args.snapshot_dir):
    if not os.path.isdir(args.snapshot_dir):
        raise ValueError(
            "Snapshot directory[%s] is not a directory!" % args.snapshot_dir
        )
else:
    # makedirs creates intermediate directories and tolerates the directory
    # appearing between the exists() check above and this call.
    os.makedirs(args.snapshot_dir, exist_ok=True)

# --- RNG / model / optimizer -------------------------------------------------
random_seed = int(time.time())
logger.info(f"Seed RNG: {random_seed}")
torch.manual_seed(random_seed)

device = torch.device(args.device)
model = load_model(args.model, args.weights, device)
optimizer = model.init_optimizer()
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, args.decay_epoch, args.decay_rate
)

# --- Datasets / data loaders -------------------------------------------------
transforms_train = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
]

datasets = {}
datasets["train"] = (
    CSVDataset(
        os.path.join(args.csv_dir, "training.csv"),
        args.dataset_dir,
        model,
        augments=transforms_train,
        illuminations=args.illuminations,
        random_illumination=True,
    )
    if not args.illuminations_to_channel
    else IlluminationChannelDataset(
        os.path.join(args.csv_dir, "training.csv"),
        args.dataset_dir,
        model,
        augments=transforms_train,
        channel_mapping=args.channel_mapping,
    )
)
datasets["val"] = (
    CSVDataset(
        os.path.join(args.csv_dir, "validation.csv"),
        args.dataset_dir,
        model,
        illuminations=args.illuminations,
        random_illumination=False,
    )
    if not args.illuminations_to_channel
    else IlluminationChannelDataset(
        os.path.join(args.csv_dir, "validation.csv"),
        args.dataset_dir,
        model,
        channel_mapping=args.channel_mapping,
    )
)

data_loaders = {}
for key in datasets:
    data_loaders[key] = torch.utils.data.DataLoader(
        dataset=datasets[key],
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=8,
    )

global_epoch = 0


def run_epoch(model, optimizer, data_loader, phase="train"):
    """Run one full pass over ``data_loader``.

    Parameters:
        model: module exposing ``compute_loss(output, labels)``.
        optimizer: optimizer; stepped only when ``phase == "train"``.
        data_loader: iterable of ``(inputs, labels)`` batches.
        phase: ``"train"`` enables gradients and weight updates; any other
            value evaluates with gradients disabled.

    Returns:
        Tuple of (mean loss per sample, elapsed seconds).
    """
    if phase == "train":
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    since = time.time()
    num_batches = len(data_loader)
    for count, (inputs, labels) in enumerate(data_loader, start=1):
        print("Iteration {}/{}".format(count, num_batches) + "\r", end="")
        inputs = inputs.to(device)
        # Labels may be a single tensor or a list of tensors (multi-output).
        if isinstance(labels, torch.Tensor):
            labels = labels.to(device)
        elif isinstance(labels, list):
            labels = [label.to(device) for label in labels]
        optimizer.zero_grad()
        with torch.set_grad_enabled(phase == "train"):
            output = model(inputs)
            loss = model.compute_loss(output, labels)
            if phase == "train":
                loss.backward()
                optimizer.step()
        # Weight by batch size so the final mean is per-sample, not per-batch.
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(data_loader.dataset), time.time() - since


def _save_snapshot():
    """Write the current model weights to a zero-padded per-epoch .pth file."""
    snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
    snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
    print("Create Snapshot[{}]".format(snapshot_file))
    logger.info("Create Snapshot[{}]".format(snapshot_file))
    torch.save(model.state_dict(), snapshot_file)


try:
    for epoch in range(1, args.epochs + 1):
        if epoch == 1:
            # Baseline losses before any weight update (reported as epoch 0).
            str_epoch = "Epoch {}/{}".format(0, args.epochs)
            epoch_loss, epoch_time = run_epoch(
                model, optimizer, data_loaders["val"], phase="val"
            )
            str_time = "Time {}s".format(round(epoch_time))
            str_loss = "Loss {:.4f}".format(epoch_loss)
            print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
            logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)
            # Training-set loss in eval mode (no updates) for the baseline.
            epoch_loss, epoch_time = run_epoch(
                model, optimizer, data_loaders["train"], phase="val"
            )
            str_time = "Time {}s".format(round(epoch_time))
            str_loss = "Loss {:.4f}".format(epoch_loss)
            print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
            logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)

        global_epoch = epoch
        str_epoch = "Epoch {}/{}".format(epoch, args.epochs)
        epoch_loss, epoch_time = run_epoch(model, optimizer, data_loaders["train"])
        str_time = "Time {}s".format(round(epoch_time))
        str_loss = "Loss {:.4f}".format(epoch_loss)
        print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
        logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)

        # Periodic snapshot.
        if epoch % args.snapshot_epoch == 0:
            _save_snapshot()

        # Periodic validation.
        if epoch % args.val_epoch == 0:
            print("VAL: " + str_epoch, end="\r")
            epoch_loss, epoch_time = run_epoch(
                model, optimizer, data_loaders["val"], phase="val"
            )
            str_time = "Time {}s".format(round(epoch_time))
            str_loss = "Loss {:.4f}".format(epoch_loss)
            print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
            logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)

        lr_scheduler.step()
except KeyboardInterrupt:
    # Ctrl-C: persist current progress before exiting.
    _save_snapshot()