Skip to content
Snippets Groups Projects
loss_curve.py 3.79 KiB
Newer Older
  • Learn to ignore specific revisions
  • import argparse
    
    import matplotlib.pyplot as plt
    import numpy as np
    
    from mu_map.logging import parse_file
    
    
    # Two-color palette from ColorBrewer's diverging "RdBu" scheme:
    # https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3
    # Successive loss curves cycle through these colors in order.
    COLORS = [
        "#ef8a62",  # reddish
        "#67a9cf",  # bluish
    ]
    
    
    def parse_loss(logs, loss_type):
        """Extract the epoch numbers and loss values of one loss type from log lines.

        A message is treated as a loss entry when it contains both ``loss_type``
        and ``"Loss"``. Messages are expected to look like
        ``"Epoch   3/100 - Loss train: 0.1234"``:

        * epoch: the last space-separated token before the first ``"-"``, cut
          at ``"/"`` (``"3/100"`` -> ``3``);
        * loss: the text after the last ``":"`` of the last ``" - "`` segment.

        :param logs: iterable of parsed log lines; only their ``message``
            attribute is read
        :param loss_type: substring identifying the loss (e.g. ``"train"``)
        :return: tuple ``(epochs, losses)`` of numpy arrays (integer epochs,
            float losses) in log order; both are empty if nothing matches
        :raises ValueError: if a matching message does not follow the format above
        """
        epochs, losses = [], []
        for logline in logs:
            message = logline.message
            if loss_type not in message or "Loss" not in message:
                continue

            # "Epoch   3/100 - ..." -> "Epoch   3/100" -> "3/100" -> 3
            epoch_field = message.split("-", 1)[0].strip().split(" ")[-1]
            epochs.append(int(epoch_field.split("/")[0]))

            # Split on the spaced separator " - " instead of a bare "-": the
            # value itself may contain "-" (e.g. "1e-05" or "-0.3"), which a
            # bare split would cut apart and misparse.
            loss_field = message.rsplit(" - ", 1)[-1].rsplit(":", 1)[-1]
            losses.append(float(loss_field.strip()))

        return np.array(epochs, dtype=int), np.array(losses, dtype=float)
    
    
    def plot_loss(
        logfile,
        loss_types,
        ax,
        normalize=False,
        from_epoch=0,
        to_epoch=None,
        start_idx_message=3,
    ):
        """Plot one loss curve per entry of ``loss_types`` from a logfile onto ``ax``.

        Only log lines with loglevel ``"INFO"`` are considered. Each loss type is
        extracted with ``parse_loss`` and drawn as a line plus scatter markers;
        afterwards the axes are styled (left/right/top spines hidden, horizontal
        grid, legend, axis labels).

        :param logfile: path of the logfile, read with ``parse_file``
        :param loss_types: loss identifiers to search for (e.g.
            ``["train", "validation"]``); the first character is upper-cased
            for the legend label
        :param ax: matplotlib axes to draw on
        :param normalize: divide each curve by its own maximum value
        :param from_epoch: first epoch (inclusive) to plot
        :param to_epoch: last epoch (inclusive) to plot; ``None`` means no upper
            bound, i.e. every logged epoch of each loss type is plotted
        :param start_idx_message: forwarded unchanged to ``parse_file``
        """
        logs = parse_file(logfile, start_idx_message=start_idx_message)
        logs = [logline for logline in logs if logline.loglevel == "INFO"]

        for i, loss_type in enumerate(loss_types):
            epochs, loss = parse_loss(logs, loss_type)

            # to_epoch is deliberately not reassigned here: every loss type is
            # clipped only by the caller's bound. Treating None as "no upper
            # bound" also avoids max() on an empty array when nothing matched.
            keep = epochs >= from_epoch
            if to_epoch is not None:
                keep &= epochs <= to_epoch
            epochs = epochs[keep]
            loss = loss[keep]

            if normalize and loss.size > 0:
                peak = loss.max()
                # an all-zero curve would otherwise become NaN (0 / 0)
                if peak != 0:
                    loss = loss / peak

            # slicing keeps this safe for an empty loss_type string
            label = loss_type[:1].upper() + loss_type[1:]
            color = COLORS[i % len(COLORS)]
            ax.plot(epochs, loss, label=label, color=color)
            ax.scatter(epochs, loss, s=15, color=color)

        for side in ("left", "right", "top"):
            ax.spines[side].set_visible(False)

        ax.grid(axis="y", alpha=0.7)
        ax.legend()
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
    
    
    if __name__ == "__main__":
        # Global figure styling: Roboto, normal weight, 12pt base text size,
        # and 18pt axes titles.
        SIZE_DEFAULT = 12
        plt.rc("font", family="Roboto", weight="normal", size=SIZE_DEFAULT)
        plt.rc("axes", titlesize=18)

        cli = argparse.ArgumentParser(
            description="plot the losses written to a logfile",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )

        # (flags, options) pairs; list order is the order shown by --help.
        argument_specs = [
            (
                ("logfile",),
                dict(
                    type=str,
                    help="the logfile from which the losses (training and validation) are parsed",
                ),
            ),
            (
                ("--loss_types",),
                dict(
                    type=str,
                    nargs="+",
                    default=["train", "validation"],
                    help="the types of losses the log is searched for",
                ),
            ),
            (
                ("--normalize",),
                dict(
                    action="store_true",
                    help="normalize the loss values (both training and validation losses are normalized separately)",
                ),
            ),
            (
                ("--out", "-o"),
                dict(
                    type=str,
                    default="loss.png",
                    help="the file into which the resulting plot is saved",
                ),
            ),
            (
                ("--verbose", "-v"),
                dict(
                    action="store_true",
                    help="do not only save the figure but also attempt to visualize it (opens a window)",
                ),
            ),
            (
                ("--from_epoch",),
                dict(
                    type=int,
                    default=0,
                    help="start plotting from this epoch",
                ),
            ),
            (
                ("--to_epoch",),
                dict(
                    type=int,
                    help="only plot to this epoch",
                ),
            ),
        ]
        for flags, options in argument_specs:
            cli.add_argument(*flags, **options)

        opts = cli.parse_args()

        _, axis = plt.subplots()
        plot_loss(
            opts.logfile,
            opts.loss_types,
            axis,
            normalize=opts.normalize,
            from_epoch=opts.from_epoch,
            to_epoch=opts.to_epoch,
        )

        # The figure is always written to disk; an interactive window is only
        # opened when --verbose is given.
        plt.tight_layout()
        plt.savefig(opts.out, dpi=300)
        if opts.verbose:
            plt.show()