Commit 93f4da3e authored by Tamino Huxohl

remove outdated train script

parent 934fcc7c
import logging
import logging.handlers

# Log with timestamps to both the console and a watched log file.
log_formatter = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p"
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

log_handler_console = logging.StreamHandler()
log_handler_console.setFormatter(log_formatter)
logger.addHandler(log_handler_console)

logfile = "train.log"
log_handler_file = logging.handlers.WatchedFileHandler(logfile)
log_handler_file.setFormatter(log_formatter)
logger.addHandler(log_handler_file)

logger.info("This is a test!")
logger.warning("The training loss is over 9000!")
logger.error("This is an error")
logger.info("The end")
import logging
import logging.handlers
import os
import time

import torch
from torchvision import transforms  # assumed source of the Random*Flip transforms

# Project-specific names used below (parser, load_model, CSVDataset,
# IlluminationChannelDataset) are defined elsewhere in the repository.
args = parser.parse_args()
# Map the entries of channel_mapping onto the channel indices 0, 1 and 2.
args.channel_mapping = dict(zip(args.channel_mapping, [0, 1, 2]))

# Log to the console and, if a logfile is given, to a watched file as well.
log_formatter = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p"
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

log_handler_console = logging.StreamHandler()
log_handler_console.setFormatter(log_formatter)
logger.addHandler(log_handler_console)

if args.logfile:
    log_handler_file = logging.handlers.WatchedFileHandler(args.logfile)
    log_handler_file.setFormatter(log_formatter)
    logger.addHandler(log_handler_file)
# Setup snapshot directory
if os.path.exists(args.snapshot_dir):
    if not os.path.isdir(args.snapshot_dir):
        raise ValueError(
            "Snapshot directory[%s] is not a directory!" % args.snapshot_dir
        )
else:
    os.mkdir(args.snapshot_dir)
# Seed the RNG with the current time and log the seed so a run can be reproduced.
random_seed = int(time.time())
logger.info(f"Seed RNG: {random_seed}")
torch.manual_seed(random_seed)

device = torch.device(args.device)
model = load_model(args.model, args.weights, device)
params = model.parameters()

optimizer = model.init_optimizer()
# Decay the learning rate by decay_rate every decay_epoch epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, args.decay_epoch, args.decay_rate
)
transforms_train = [transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()]

# Use the IlluminationChannelDataset when illuminations are mapped onto image
# channels, and the plain CSVDataset otherwise.
datasets = {}
datasets["train"] = (
    CSVDataset(
        os.path.join(args.csv_dir, "training.csv"),
        args.dataset_dir,
        model,
        augments=transforms_train,
        illuminations=args.illuminations,
        random_illumination=True,
    )
    if not args.illuminations_to_channel
    else IlluminationChannelDataset(
        os.path.join(args.csv_dir, "training.csv"),
        args.dataset_dir,
        model,
        augments=transforms_train,
        channel_mapping=args.channel_mapping,
    )
)
datasets["val"] = (
    CSVDataset(
        os.path.join(args.csv_dir, "validation.csv"),
        args.dataset_dir,
        model,
        illuminations=args.illuminations,
        random_illumination=False,
    )
    if not args.illuminations_to_channel
    else IlluminationChannelDataset(
        os.path.join(args.csv_dir, "validation.csv"),
        args.dataset_dir,
        model,
        channel_mapping=args.channel_mapping,
    )
)
data_loaders = {}
for key in datasets:
    data_loaders[key] = torch.utils.data.DataLoader(
        dataset=datasets[key],
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=8,
    )
global_epoch = 0


def run_epoch(model, optimizer, data_loader, phase="train"):
    """Run a single epoch and return the average loss and the elapsed time."""
    if phase == "train":
        model.train()
    else:
        model.eval()

    running_loss = 0
    since = time.time()

    count = 1
    for inputs, labels in data_loader:
        print("Iteration {}/{}".format(count, len(data_loader)) + "\r", end="")
        count += 1

        inputs = inputs.to(device)
        if isinstance(labels, torch.Tensor):
            labels = labels.to(device)
        elif isinstance(labels, list):
            labels = [label.to(device) for label in labels]

        optimizer.zero_grad()
        # Only track gradients during training.
        with torch.set_grad_enabled(phase == "train"):
            output = model(inputs)
            loss = model.compute_loss(output, labels)
            if phase == "train":
                loss.backward()
                optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    return running_loss / len(data_loader.dataset), time.time() - since
try:
    for epoch in range(1, args.epochs + 1):
        # Before the first training epoch, record the loss of the untrained
        # model on both the validation and the training set.
        if epoch == 1:
            str_epoch = "Epoch {}/{}".format(0, args.epochs)

            epoch_loss, epoch_time = run_epoch(
                model, optimizer, data_loaders["val"], phase="val"
            )
            str_time = "Time {}s".format(round(epoch_time))
            str_loss = "Loss {:.4f}".format(epoch_loss)
            print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
            logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)

            epoch_loss, epoch_time = run_epoch(
                model, optimizer, data_loaders["train"], phase="val"
            )
            str_time = "Time {}s".format(round(epoch_time))
            str_loss = "Loss {:.4f}".format(epoch_loss)
            print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
            logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)

        global_epoch = epoch
        str_epoch = "Epoch {}/{}".format(epoch, args.epochs)

        epoch_loss, epoch_time = run_epoch(model, optimizer, data_loaders["train"])
        str_time = "Time {}s".format(round(epoch_time))
        str_loss = "Loss {:.4f}".format(epoch_loss)
        print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
        logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)

        # create snapshot
        if epoch % args.snapshot_epoch == 0:
            snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
            snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
            print("Create Snapshot[{}]".format(snapshot_file))
            logger.info("Create Snapshot[{}]".format(snapshot_file))
            torch.save(model.state_dict(), snapshot_file)

        # compute validation loss
        if epoch % args.val_epoch == 0:
            print("VAL: " + str_epoch, end="\r")
            epoch_loss, epoch_time = run_epoch(
                model, optimizer, data_loaders["val"], phase="val"
            )
            str_time = "Time {}s".format(round(epoch_time))
            str_loss = "Loss {:.4f}".format(epoch_loss)
            print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
            logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)

        lr_scheduler.step()
except KeyboardInterrupt:
    # Save a final snapshot when training is interrupted manually.
    snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
    snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
    print("Create Snapshot[{}]".format(snapshot_file))
    torch.save(model.state_dict(), snapshot_file)
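
For reference, the snapshots written by this script are plain PyTorch state dicts, so they can be inspected or restored with the standard torch API. The file name below is a hypothetical example; rebuilding the model itself would go through the project's load_model, which is not shown here.

import torch

# Minimal sketch: a snapshot produced by the loop above is a state dict.
# "snapshots/010.pth" is a hypothetical path used only for illustration.
state_dict = torch.load("snapshots/010.pth", map_location="cpu")
print(sorted(state_dict.keys()))  # list the stored parameter tensors
# model.load_state_dict(state_dict)  # with a model built via the project's load_model()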