From eecbb75450f33e7ff3cc285894065180fcffb02f Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 26 Jan 2023 15:50:59 +0100
Subject: [PATCH] add tool to label random search runs that die not converge

---
 mu_map/random_search/label_outliers.py | 116 +++++++++++++++++++++++++
 1 file changed, 116 insertions(+)
 create mode 100644 mu_map/random_search/label_outliers.py

diff --git a/mu_map/random_search/label_outliers.py b/mu_map/random_search/label_outliers.py
new file mode 100644
index 0000000..4803bf8
--- /dev/null
+++ b/mu_map/random_search/label_outliers.py
@@ -0,0 +1,116 @@
+import argparse
+import os
+
+import cv2 as cv
+import pandas as pd
+import torch
+
+from mu_map.dataset.default import MuMapDataset
+from mu_map.dataset.transform import SequenceTransform, PadCropTranform
+from mu_map.models.unet import UNet
+from mu_map.random_search.cgan import load_params
+from mu_map.random_search.show_predictions import main
+
+controls = """
+Controls:
+
+    q: quit/exit the application
+    n: show the next mu map
+    p: pause on the current slice or resume if paused
+
+    o: mark run as an outlier
+    s: mark run as valid
+"""
+
+default_rs_dir = "cgan_random_search/"
+default_outfile = os.path.join(default_rs_dir, "outliers.csv")
+
+parser = argparse.ArgumentParser(
+    description="label random search runs as outliers (not-converged) by having a look at the model's prediction on the validation split of the dataset",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+parser.add_argument(
+    "--random_search_dir",
+    type=str,
+    default=default_rs_dir,
+    help="directory where the runs of a random search are stored",
+)
+parser.add_argument(
+    "--out",
+    type=str,
+    default=default_outfile,
+    help="the file where outliers are stored - if it already exists, missing runs can be added",
+)
+parser.add_argument(
+    "--device",
+    choices=["cpu", "gpu"],
+    default="gpu" if torch.cuda.is_available() else "cpu",
+    help="the device on which the model of the random search run is evaluated",
+)
+args = parser.parse_args()
+
+device = torch.device(args.device)
+
+runs = sorted(os.listdir(args.random_search_dir))
+runs = map(lambda f: os.path.join(args.random_search_dir, f), runs)
+runs = filter(lambda f: os.path.isdir(f), runs)
+runs = filter(lambda f: not os.path.islink(f), runs)
+runs = map(lambda f: os.path.basename(f), runs)
+runs = list(runs)
+
+data = {"run": [], "outlier": []}
+if os.path.isfile(args.out):
+    available_data = pd.read_csv(args.out)
+    data["run"] = list(available_data["run"])
+    data["outlier"] = list(available_data["outlier"])
+
+print(controls)
+
+total = str(len(runs))
+for i, run in enumerate(runs):
+    if int(run) in data["run"]:
+        continue
+
+    print(f"Run {str(i+1):>{len(total)}}/{total}", end="\r")
+    data["run"].append(int(run))
+    data["outlier"].append(False)
+
+    _dir = os.path.join(args.random_search_dir, run)
+
+    params = load_params(os.path.join(_dir, "params.json"))
+
+    dataset = MuMapDataset(
+        "data/second/",
+        transform_normalization=SequenceTransform(
+            [params["normalization"], PadCropTranform(dim=3, size=32)]
+        ),
+        split_name="validation",
+        scatter_correction=False,
+    )
+
+    model = UNet(features=params["generator_features"])
+    weights = torch.load(
+        os.path.join(_dir, "snapshots", "val_min_generator.pth"),
+        map_location=device,
+    )
+    model.load_state_dict(weights)
+    model = model.to(device).eval()
+
+    wname = "Label Outlier"
+    cv.namedWindow(wname, cv.WINDOW_NORMAL)
+    cv.resizeWindow(wname, 1600, 900)
+
+    def action(key):
+        if key == ord("o"):
+            print(f"Run {str(i+1):>{len(total)}}/{total} - Outlier!")
+            data["outlier"][i] = True
+            return True
+
+        if key == ord("s"):
+            return True
+
+        return False
+
+    main(model, dataset, wname, action, _print=False)
+
+    pd.DataFrame(data).to_csv(args.out, index=False)
-- 
GitLab