diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py index c17bb846fb4089652292b29152b0dd49147090cb..b0ae6c71e326e05d24ff09f7104941686e9d3cc4 100644 --- a/mu_map/vis/loss_curve.py +++ b/mu_map/vis/loss_curve.py @@ -29,8 +29,8 @@ def parse_loss(logs, loss_type): return np.array(list(epochs)), np.array(list(losses)) -def plot_loss(logfile, loss_types, ax, normalize=False, from_epoch=0, to_epoch=None): - logs = parse_file(logfile) +def plot_loss(logfile, loss_types, ax, normalize=False, from_epoch=0, to_epoch=None, start_idx_message=3): + logs = parse_file(logfile, start_idx_message=start_idx_message) logs = list(filter(lambda logline: logline.loglevel == "INFO", logs)) for i, loss_type in enumerate(loss_types):