train.py

    import logging
    import logging.handlers
    import os
    import time

    import torch
    from torchvision import transforms

    # NOTE: load_model, CSVDataset, IlluminationChannelDataset and the argparse
    # `parser` are project-specific and assumed to be defined or imported earlier
    # in the file; their definitions are not shown in this excerpt.
    
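    # Parse command-line arguments and map each entry of --channel_mapping to a channel index (0, 1, 2).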
    args = parser.parse_args()
    args.channel_mapping = dict(zip(args.channel_mapping, [0, 1, 2]))
    
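    # Configure the root logger: always log to the console, and additionally to a file if --logfile is given.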
    log_formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p"
    )
    
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    
    log_handler_console = logging.StreamHandler()
    log_handler_console.setFormatter(log_formatter)
    logger.addHandler(log_handler_console)
    
    if args.logfile:
        log_handler_file = logging.handlers.WatchedFileHandler(args.logfile)
        log_handler_file.setFormatter(log_formatter)
        logger.addHandler(log_handler_file)
    
    
    # Setup snapshot directory
    if os.path.exists(args.snapshot_dir):
        if not os.path.isdir(args.snapshot_dir):
            raise ValueError(
                "Snapshot directory[%s] is not a directory!" % args.snapshot_dir
            )
    else:
        os.mkdir(args.snapshot_dir)
    
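    # Seed PyTorch's RNG from the current time; the seed is logged so a run can be reproduced.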
    random_seed = int(time.time())
    logger.info(f"Seed RNG: {random_seed}")
    
    torch.manual_seed(random_seed)
    
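    # Load the model onto the requested device; the optimizer comes from the model itself,
    # and the learning rate is decayed by decay_rate every decay_epoch epochs.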
    device = torch.device(args.device)
    model = load_model(args.model, args.weights, device)
    
    optimizer = model.init_optimizer()
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, args.decay_epoch, args.decay_rate
    )
    
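    # Build the training and validation datasets. Training data is augmented with random flips.
    # With --illuminations_to_channel set, an IlluminationChannelDataset using the channel_mapping
    # is constructed instead of the CSVDataset with its per-sample illumination options.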
    transforms_train = [transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()]
    datasets = {}
    datasets["train"] = (
        CSVDataset(
            os.path.join(args.csv_dir, "training.csv"),
            args.dataset_dir,
            model,
            augments=transforms_train,
            illuminations=args.illuminations,
            random_illumination=True,
        )
        if not args.illuminations_to_channel
        else IlluminationChannelDataset(
            os.path.join(args.csv_dir, "training.csv"),
            args.dataset_dir,
            model,
            augments=transforms_train,
            channel_mapping=args.channel_mapping,
        )
    )
    datasets["val"] = (
        CSVDataset(
            os.path.join(args.csv_dir, "validation.csv"),
            args.dataset_dir,
            model,
            illuminations=args.illuminations,
            random_illumination=False,
        )
        if not args.illuminations_to_channel
        else IlluminationChannelDataset(
            os.path.join(args.csv_dir, "validation.csv"),
            args.dataset_dir,
            model,
            channel_mapping=args.channel_mapping,
        )
    )
    
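    # Wrap both datasets in shuffling DataLoaders with pinned memory and 8 worker processes.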
    data_loaders = {}
    for key in datasets:
        data_loaders[key] = torch.utils.data.DataLoader(
            dataset=datasets[key],
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=8,
        )
    
    global_epoch = 0
    
    
    def run_epoch(model, optimizer, data_loader, phase="train"):
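        """Run a single pass of `model` over `data_loader`.

        In the "train" phase gradients are computed and `optimizer` is stepped; otherwise the
        model is put into eval mode and gradients are disabled. Returns the loss averaged over
        all samples and the elapsed time in seconds.
        """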
        if phase == "train":
            model.train()
        else:
            model.eval()
    
        running_loss = 0
        since = time.time()
        for count, (inputs, labels) in enumerate(data_loader, start=1):
            print("Iteration {}/{}".format(count, len(data_loader)), end="\r")
    
            inputs = inputs.to(device)
            # Labels may be a single tensor or a list of tensors.
            if isinstance(labels, torch.Tensor):
                labels = labels.to(device)
            elif isinstance(labels, list):
                labels = [label.to(device) for label in labels]
    
            optimizer.zero_grad()
    
            with torch.set_grad_enabled(phase == "train"):
                output = model(inputs)
                loss = model.compute_loss(output, labels)
    
                if phase == "train":
                    loss.backward()
                    optimizer.step()
    
            running_loss += loss.item() * inputs.size(0)
    
        return running_loss / len(data_loader.dataset), time.time() - since
    
    
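    # Main training loop: one training pass per epoch, a snapshot every snapshot_epoch epochs and
    # a validation pass every val_epoch epochs. A KeyboardInterrupt saves a final snapshot.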
    try:
        for epoch in range(1, args.epochs + 1):
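            # Before any training, record baseline losses on the validation and training splits.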
            if epoch == 1:
                str_epoch = "Epoch {}/{}".format(0, args.epochs)
    
                epoch_loss, epoch_time = run_epoch(
                    model, optimizer, data_loaders["val"], phase="val"
                )
    
                str_time = "Time {}s".format(round(epoch_time))
                str_loss = "Loss {:.4f}".format(epoch_loss)
    
                print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
                logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)
    
                epoch_loss, epoch_time = run_epoch(
                    model, optimizer, data_loaders["train"], phase="val"
                )
    
                str_time = "Time {}s".format(round(epoch_time))
                str_loss = "Loss {:.4f}".format(epoch_loss)
    
                print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
                logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
    
            global_epoch = epoch
            str_epoch = "Epoch {}/{}".format(epoch, args.epochs)
    
            epoch_loss, epoch_time = run_epoch(model, optimizer, data_loaders["train"])
    
            str_time = "Time {}s".format(round(epoch_time))
            str_loss = "Loss {:.4f}".format(epoch_loss)
    
            print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
            logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
    
            # create snapshot
            if epoch % args.snapshot_epoch == 0:
                snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
                snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
                print("Create Snapshot[{}]".format(snapshot_file))
                logger.info("Create Snapshot[{}]".format(snapshot_file))
                torch.save(model.state_dict(), snapshot_file)
    
            # compute validation loss
            if epoch % args.val_epoch == 0:
                print("VAL: " + str_epoch, end="\r")
    
                epoch_loss, epoch_time = run_epoch(
                    model, optimizer, data_loaders["val"], phase="val"
                )
    
                str_time = "Time {}s".format(round(epoch_time))
                str_loss = "Loss {:.4f}".format(epoch_loss)
    
                print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
                logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)
    
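            # Advance the step-wise learning-rate schedule once per epoch.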
            lr_scheduler.step()
    except KeyboardInterrupt:
        snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
        snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
        print("Create Snapshot[{}]".format(snapshot_file))
        logger.info("Create Snapshot[{}]".format(snapshot_file))
        torch.save(model.state_dict(), snapshot_file)