Commit 0cc09b15 authored by Tamino Huxohl

add code for training

parent d80575c9
@@ -14,5 +14,13 @@ y = GaussianNorm()(x)
print(f" After: mean={y.mean():.3f} std={y.std():.3f}")
import cv2 as cv
import numpy as np

# Display a blank 512x512 image and echo the code of every pressed key until "q" is pressed.
x = np.zeros((512, 512), np.uint8)
cv.imshow("X", x)
key = cv.waitKey(0)
while key != ord("q"):
    print(key)
    key = cv.waitKey(0)
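
# A minimal sketch (not from the commit) of the same viewer loop wrapped in a helper.
# It only relies on the standard OpenCV calls already used above, plus
# cv.destroyAllWindows() for cleanup; the name show_until_quit is hypothetical.
def show_until_quit(name, image, quit_key="q"):
    cv.imshow(name, image)
    key = cv.waitKey(0)
    while key != ord(quit_key):
        print(key)  # echo the code of any other key that was pressed
        key = cv.waitKey(0)
    cv.destroyAllWindows()

# Example: show_until_quit("X", np.zeros((512, 512), np.uint8))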
import logging
import logging.handlers

# Log to both the console and a file; WatchedFileHandler reopens the file if it
# is rotated externally (e.g. by logrotate).
log_formatter = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p"
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

log_handler_console = logging.StreamHandler()
log_handler_console.setFormatter(log_formatter)
logger.addHandler(log_handler_console)

logfile = "train.log"
log_handler_file = logging.handlers.WatchedFileHandler(logfile)
log_handler_file.setFormatter(log_formatter)
logger.addHandler(log_handler_file)

logger.info("This is a test!")
logger.warning("The training loss is over 9000!")
logger.error("This is an error")
logger.info("The end")
args = parser.parse_args()
# Map the given channel names to channel indices, e.g. ["a", "b", "c"] -> {"a": 0, "b": 1, "c": 2}.
args.channel_mapping = dict(zip(args.channel_mapping, [0, 1, 2]))

# Set up logging to the console and, optionally, to a file.
log_formatter = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p"
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_handler_console = logging.StreamHandler()
log_handler_console.setFormatter(log_formatter)
logger.addHandler(log_handler_console)
if args.logfile:
    log_handler_file = logging.handlers.WatchedFileHandler(args.logfile)
    log_handler_file.setFormatter(log_formatter)
    logger.addHandler(log_handler_file)
# Setup snapshot directory
if os.path.exists(args.snapshot_dir):
    if not os.path.isdir(args.snapshot_dir):
        raise ValueError(
            "Snapshot directory[%s] is not a directory!" % args.snapshot_dir
        )
else:
    os.mkdir(args.snapshot_dir)
# Seed the RNG with the current time and log the seed so a run can be reproduced.
random_seed = int(time.time())
logger.info(f"Seed RNG: {random_seed}")
torch.manual_seed(random_seed)

device = torch.device(args.device)
model = load_model(args.model, args.weights, device)
params = model.parameters()
optimizer = model.init_optimizer()
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, args.decay_epoch, args.decay_rate
)
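
# Illustration only (not part of the script): StepLR multiplies the learning
# rate by decay_rate once every decay_epoch calls to lr_scheduler.step(). With
# a toy optimizer, step_size=2 and gamma=0.5, the schedule looks like this:
#     demo_opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.0)
#     demo_sched = torch.optim.lr_scheduler.StepLR(demo_opt, step_size=2, gamma=0.5)
#     for _ in range(6):
#         demo_opt.step()
#         demo_sched.step()
#         print(demo_sched.get_last_lr())  # [1.0], [0.5], [0.5], [0.25], [0.25], [0.125]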
# Random flips are only used for training; validation data stays untouched.
transforms_train = [transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()]

# Build the training and validation datasets. If illuminations are mapped to
# channels, an IlluminationChannelDataset is used instead of a plain CSVDataset.
datasets = {}
datasets["train"] = (
    CSVDataset(
        os.path.join(args.csv_dir, "training.csv"),
        args.dataset_dir,
        model,
        augments=transforms_train,
        illuminations=args.illuminations,
        random_illumination=True,
    )
    if not args.illuminations_to_channel
    else IlluminationChannelDataset(
        os.path.join(args.csv_dir, "training.csv"),
        args.dataset_dir,
        model,
        augments=transforms_train,
        channel_mapping=args.channel_mapping,
    )
)
datasets["val"] = (
    CSVDataset(
        os.path.join(args.csv_dir, "validation.csv"),
        args.dataset_dir,
        model,
        illuminations=args.illuminations,
        random_illumination=False,
    )
    if not args.illuminations_to_channel
    else IlluminationChannelDataset(
        os.path.join(args.csv_dir, "validation.csv"),
        args.dataset_dir,
        model,
        channel_mapping=args.channel_mapping,
    )
)
data_loaders = {}
for key in datasets:
    data_loaders[key] = torch.utils.data.DataLoader(
        dataset=datasets[key],
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=8,
    )
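
# Optional sanity check (a sketch, not part of the original script): pull one
# batch from the training loader and log its shape before starting to train.
#     sample_inputs, sample_labels = next(iter(data_loaders["train"]))
#     logger.info(f"Sample batch input shape: {tuple(sample_inputs.shape)}")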
global_epoch = 0


def run_epoch(model, optimizer, data_loader, phase="train"):
    """Run a single epoch and return the mean loss and the elapsed time."""
    if phase == "train":
        model.train()
    else:
        model.eval()

    running_loss = 0
    since = time.time()
    count = 1
    for inputs, labels in data_loader:
        print("Iteration {}/{}".format(count, len(data_loader)) + "\r", end="")
        count += 1

        inputs = inputs.to(device)
        # Labels may be a single tensor or a list of tensors (multi-output models).
        if isinstance(labels, torch.Tensor):
            labels = labels.to(device)
        elif isinstance(labels, list):
            labels = [label.to(device) for label in labels]

        optimizer.zero_grad()
        # Gradients are only tracked during training.
        with torch.set_grad_enabled(phase == "train"):
            output = model(inputs)
            loss = model.compute_loss(output, labels)
            if phase == "train":
                loss.backward()
                optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(data_loader.dataset), time.time() - since
try:
    for epoch in range(1, args.epochs + 1):
        # Before the first epoch, record the initial loss on both splits.
        if epoch == 1:
            str_epoch = "Epoch {}/{}".format(0, args.epochs)
            epoch_loss, epoch_time = run_epoch(
                model, optimizer, data_loaders["val"], phase="val"
            )
            str_time = "Time {}s".format(round(epoch_time))
            str_loss = "Loss {:.4f}".format(epoch_loss)
            print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
            logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)

            epoch_loss, epoch_time = run_epoch(
                model, optimizer, data_loaders["train"], phase="val"
            )
            str_time = "Time {}s".format(round(epoch_time))
            str_loss = "Loss {:.4f}".format(epoch_loss)
            print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
            logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)

        global_epoch = epoch
        str_epoch = "Epoch {}/{}".format(epoch, args.epochs)

        epoch_loss, epoch_time = run_epoch(model, optimizer, data_loaders["train"])
        str_time = "Time {}s".format(round(epoch_time))
        str_loss = "Loss {:.4f}".format(epoch_loss)
        print("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)
        logger.info("TRAIN: " + str_epoch + " " + str_loss + " " + str_time)

        # create snapshot
        if epoch % args.snapshot_epoch == 0:
            snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
            snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
            print("Create Snapshot[{}]".format(snapshot_file))
            logger.info("Create Snapshot[{}]".format(snapshot_file))
            torch.save(model.state_dict(), snapshot_file)

        # compute validation loss
        if epoch % args.val_epoch == 0:
            print("VAL: " + str_epoch, end="\r")
            epoch_loss, epoch_time = run_epoch(
                model, optimizer, data_loaders["val"], phase="val"
            )
            str_time = "Time {}s".format(round(epoch_time))
            str_loss = "Loss {:.4f}".format(epoch_loss)
            print("VAL: " + str_epoch + " " + str_loss + " " + str_time)
            logger.info("VAL: " + str_epoch + " " + str_loss + " " + str_time)

        lr_scheduler.step()
except KeyboardInterrupt:
    # Store a snapshot of the current weights if training is interrupted manually.
    snapshot_file = f"{global_epoch:0{len(str(args.epochs))}d}.pth"
    snapshot_file = os.path.join(args.snapshot_dir, snapshot_file)
    print("Create Snapshot[{}]".format(snapshot_file))
    torch.save(model.state_dict(), snapshot_file)
class Training:
    """Work-in-progress refactoring of the training loop above into a class."""

    def __init__(self, epochs):
        self.epochs = epochs

    def run(self):
        for epoch in range(1, self.epochs + 1):
            self.run_epoch(self.data_loader["train"], phase="train")
            loss_training = self.run_epoch(self.data_loader["train"], phase="eval")
            loss_validation = self.run_epoch(self.data_loader["validation"], phase="eval")
            # ToDo: log outputs and time
            self.lr_scheduler.step()
            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

    def run_epoch(self, data_loader, phase):
        if phase == "train":
            self.model.train()
        else:
            self.model.eval()
        epoch_loss = 0
        for inputs, labels in data_loader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(inputs)
                loss = self.loss(outputs, labels)
                if phase == "train":
                    loss.backward()
                    self.optimizer.step()
            epoch_loss += loss.item() * inputs.size(0)
        return epoch_loss / len(data_loader.dataset)

    def store_snapshot(self, epoch):
        pass
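
# A hedged usage sketch for the Training class above (not from the commit). The
# constructor currently only receives `epochs`, so the attributes that run() and
# run_epoch() rely on are wired up by hand here, reusing the objects built in the
# script above; every assignment is an assumption about how the class is meant to
# be configured.
#     training = Training(epochs=args.epochs)
#     training.device = device
#     training.model = model
#     training.optimizer = optimizer
#     training.lr_scheduler = lr_scheduler
#     training.loss = model.compute_loss
#     training.snapshot_epoch = args.snapshot_epoch
#     training.data_loader = {"train": data_loaders["train"], "validation": data_loaders["val"]}
#     training.run()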