diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py
index 8262b5e4d99018f1e92988765435eda08c2e173a..71585f315c3d731e3d0ea521da938a358b401880 100644
--- a/mu_map/vis/loss_curve.py
+++ b/mu_map/vis/loss_curve.py
@@ -15,13 +15,21 @@ plt.rc("axes", titlesize=18)  # fontsize of the axes title
 COLORS = ["#ef8a62", "#67a9cf"]
 
 parser = argparse.ArgumentParser(
-    description="plot the losses written to a logfile", formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    description="plot the losses written to a logfile",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 parser.add_argument(
     "logfile",
     type=str,
     help="the logfile from which the losses (training and validation) are parsed",
 )
+parser.add_argument(
+    "--loss_types",
+    type=str,
+    nargs="+",
+    default=["train", "validation"],
+    help="the types of losses the log is searched for",
+)
 parser.add_argument(
     "--normalize",
     action="store_true",
@@ -46,9 +54,9 @@ logs = parse_file(args.logfile)
 logs = list(filter(lambda logline: logline.loglevel == "INFO", logs))
 
 
-def parse_loss(logs, phase):
+def parse_loss(logs, loss_type):
     _logs = map(lambda logline: logline.message, logs)
-    _logs = filter(lambda log: phase in log, _logs)
+    _logs = filter(lambda log: loss_type in log, _logs)
     _logs = filter(lambda log: "Loss" in log, _logs)
     _logs = list(_logs)
 
@@ -65,16 +73,15 @@ def parse_loss(logs, phase):
     return np.array(list(epochs)), np.array(list(losses))
 
 
-phases = ["TRAIN", "VAL"]
-labels = ["Training", "Validation"]
-
 fig, ax = plt.subplots()
-for phase, label, color in zip(phases, labels, COLORS):
-    epochs, loss = parse_loss(logs, phase)
+for i, loss_type in enumerate(args.loss_types):
+    epochs, loss = parse_loss(logs, loss_type)
 
     if args.normalize:
         loss = loss / loss.max()
 
+    label = loss_type[0].upper() + loss_type[1:]
+    color = COLORS[i % len(COLORS)]
     ax.plot(epochs, loss, label=label, color=color)
     ax.scatter(epochs, loss, s=15, color=color)