From e44371f231fda698ddb93e8f3dd3f82043129bd5 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 27 Sep 2022 13:30:37 +0200 Subject: [PATCH] add capability to parse logfiles and generate loss curves from that --- mu_map/logging.py | 39 +++++++++++++++++++++--- mu_map/vis/loss_curve.py | 66 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 5 deletions(-) create mode 100644 mu_map/vis/loss_curve.py diff --git a/mu_map/logging.py b/mu_map/logging.py index 62f447d..fb43420 100644 --- a/mu_map/logging.py +++ b/mu_map/logging.py @@ -1,15 +1,16 @@ import argparse -import datetime +from dataclasses import dataclass +from datetime import datetime import logging from logging import Formatter, getLogger, StreamHandler from logging.handlers import WatchedFileHandler import os import shutil -from typing import Dict, Optional - +from typing import Dict, Optional, List +date_format="%m/%d/%Y %I:%M:%S" FORMATTER = Formatter( - fmt="%(asctime)s - %(levelname)7s - %(message)s", datefmt="%m/%d/%Y %I:%M:%S" + fmt="%(asctime)s - %(levelname)7s - %(message)s", datefmt=date_format ) @@ -35,7 +36,7 @@ def add_logging_args(parser: argparse.ArgumentParser, defaults: Dict[str, str]): def timestamp_filename(filename: str): - timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + timestamp = datetime.now().strftime("%Y-%m-%d-%H:%M:%S") basename, ext = os.path.splitext(filename) return f"{basename}_{timestamp}{ext}" @@ -71,6 +72,34 @@ def get_logger_by_args(args): return get_logger(args.logfile, args.loglevel) +@dataclass +class LogLine: + time: datetime + loglevel: str + message: str + + def __repr__(self): + return f"{self.time.strftime(date_format)} - {self.loglevel:>7} - {self.message}" + + +def parse_line(logline): + _split = logline.strip().split("-") + assert len(_split) >= 3, f"A logged line should consists of a least three elements with the format [TIME - LOGLEVEL - MESSAGE] but got [{logline.strip()}]" + + time_str = _split[0].strip() + time = datetime.strptime(time_str, date_format) + + loglevel = _split[1].strip() + + message = "-".join(_split[2:]).strip() + return LogLine(time=time, loglevel=loglevel, message=message) + +def parse_file(logfile: str) -> List[LogLine]: + with open(logfile, mode="r") as f: + lines = f.readlines() + lines = map(parse_line, lines) + return list(lines) + if __name__ == "__main__": parser = argparse.ArgumentParser() add_logging_args(parser, defaults={"--loglevel": "DEBUG", "--logfile": "tmp.log"}) diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py new file mode 100644 index 0000000..336b1ca --- /dev/null +++ b/mu_map/vis/loss_curve.py @@ -0,0 +1,66 @@ +import argparse + +import matplotlib.pyplot as plt +import numpy as np + +from mu_map.logging import parse_file + +SIZE_DEFAULT = 12 +plt.rc("font", family="Roboto") # controls default font +plt.rc("font", weight="normal") # controls default font +plt.rc("font", size=SIZE_DEFAULT) # controls default text sizes +plt.rc("axes", titlesize=18) # fontsize of the axes title + +# https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3lk +COLORS = ["#ef8a62", "#67a9cf"] + +parser = argparse.ArgumentParser(description="TODO", formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument("logfile", type=str, help="TODO") +parser.add_argument("--normalize", action="store_true", help="TODO") +args = parser.parse_args() + +logs = parse_file(args.logfile) +logs = list(filter(lambda logline: logline.loglevel == "INFO", logs)) + +def parse_loss(logs, phase): + _logs = map(lambda logline: logline.message, logs) + _logs = filter(lambda log: phase in log, _logs) + _logs = filter(lambda log: "Loss" in log, _logs) + _logs = list(_logs) + + losses = map(lambda log: log.split("-")[-1].strip(), _logs) + losses = map(lambda log: log.split(":")[-1].strip(), losses) + losses = map(float, losses) + + epochs = map(lambda log: log.split("-")[0].strip(), _logs) + epochs = list(epochs) + epochs = map(lambda log: log.split(" ")[-1], epochs) + epochs = map(lambda log: log.split("/")[0], epochs) + epochs = map(int, epochs) + + return np.array(list(epochs)), np.array(list(losses)) + +phases = ["TRAIN", "VAL"] +labels = ["Training", "Validation"] + +fig, ax = plt.subplots() +for phase, label, color in zip(phases, labels, COLORS): + epochs, loss = parse_loss(logs, phase) + + if args.normalize: + loss = loss / loss.max() + + ax.plot(epochs, loss, label=label, color=color) + ax.scatter(epochs, loss, s=15, color=color) + +ax.spines["left"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["top"].set_visible(False) + +ax.grid(axis="y", alpha=0.7) +ax.legend() +ax.set_xlabel("Epoch") +ax.set_ylabel("Loss") +plt.tight_layout() +plt.show() + -- GitLab