# mu_map/random_search/label_outliers.py
# Introduced by commit eecbb754 (Tamino Huxohl, 2023-01-26):
# "add tool to label random search runs that did not converge"
"""
Interactively label random search runs as outliers (not converged).

For every run directory of a random search, the generator snapshot with the
lowest validation loss is loaded and its predictions on the validation split
are displayed. The user then marks the run as an outlier or as valid via the
keyboard and the labels are stored in a CSV file with the columns ``run`` and
``outlier``.
"""
import argparse
import os

import cv2 as cv
import pandas as pd
import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.transform import SequenceTransform, PadCropTranform
from mu_map.models.unet import UNet
from mu_map.random_search.cgan import load_params
from mu_map.random_search.show_predictions import main

controls = """
Controls:

    q: quit/exit the application
    n: show the next mu map
    p: pause on the current slice or resume if paused

    o: mark run as an outlier
    s: mark run as valid
"""

default_rs_dir = "cgan_random_search/"
default_outfile = os.path.join(default_rs_dir, "outliers.csv")
default_dataset_dir = "data/second/"


def _build_parser() -> argparse.ArgumentParser:
    """Create the command line parser of the labeling tool."""
    parser = argparse.ArgumentParser(
        description="label random search runs as outliers (not-converged) by having a look at the model's prediction on the validation split of the dataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--random_search_dir",
        type=str,
        default=default_rs_dir,
        help="directory where the runs of a random search are stored",
    )
    parser.add_argument(
        "--out",
        type=str,
        default=default_outfile,
        help="the file where outliers are stored - if it already exists, missing runs can be added",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default=default_dataset_dir,
        help="directory of the dataset whose validation split is displayed",
    )
    parser.add_argument(
        "--device",
        # "gpu" is kept for backwards compatibility and translated to "cuda"
        choices=["cpu", "gpu", "cuda"],
        default="gpu" if torch.cuda.is_available() else "cpu",
        help="the device on which the model of the random search run is evaluated",
    )
    return parser


def _resolve_device(name: str) -> torch.device:
    """
    Translate the command line device name into a torch device.

    torch does not know a device type called "gpu", so this alias is mapped
    to "cuda". All other names are passed on unchanged.
    """
    return torch.device("cuda" if name == "gpu" else name)


def _is_run_id(name: str) -> bool:
    """Return whether a directory name can be parsed as an integer run id."""
    try:
        int(name)
    except ValueError:
        return False
    return True


def _list_runs(random_search_dir: str) -> list:
    """
    List the names of all run directories of a random search.

    Only real directories (no symbolic links) whose name is an integer are
    considered to be runs. The names are returned in lexicographic order.
    """
    runs = []
    for name in sorted(os.listdir(random_search_dir)):
        path = os.path.join(random_search_dir, name)
        if not os.path.isdir(path) or os.path.islink(path):
            continue
        if not _is_run_id(name):
            # other directories would make int(name) fail later on
            continue
        runs.append(name)
    return runs


def _load_labels(filename: str) -> dict:
    """
    Load already available labels so that only missing runs are labeled.

    Returns a dict with the keys "run" and "outlier" which both map to
    lists of the same length. The lists are empty if the file does not exist.
    """
    data = {"run": [], "outlier": []}
    if os.path.isfile(filename):
        available_data = pd.read_csv(filename)
        data["run"] = available_data["run"].tolist()
        data["outlier"] = available_data["outlier"].tolist()
    return data


def label_runs(args: argparse.Namespace) -> None:
    """
    Display the predictions of all unlabeled runs and store the labels.

    Parameters
    ----------
    args: argparse.Namespace
        parsed command line arguments providing `random_search_dir`, `out`,
        `dataset_dir` and `device` (see `_build_parser`)
    """
    device = _resolve_device(args.device)
    runs = _list_runs(args.random_search_dir)
    data = _load_labels(args.out)
    # set for constant-time lookup of runs which are already labeled
    labeled = set(data["run"])

    print(controls)

    total = str(len(runs))
    for i, run in enumerate(runs):
        run_id = int(run)
        if run_id in labeled:
            continue

        progress = f"Run {str(i+1):>{len(total)}}/{total}"
        print(progress, end="\r")
        # a run counts as valid until it is explicitly marked as an outlier
        data["run"].append(run_id)
        data["outlier"].append(False)
        labeled.add(run_id)

        _dir = os.path.join(args.random_search_dir, run)

        params = load_params(os.path.join(_dir, "params.json"))

        dataset = MuMapDataset(
            args.dataset_dir,
            transform_normalization=SequenceTransform(
                [params["normalization"], PadCropTranform(dim=3, size=32)]
            ),
            split_name="validation",
            scatter_correction=False,
        )

        model = UNet(features=params["generator_features"])
        # NOTE: torch.load unpickles the file - only load snapshots of
        # trusted random search runs
        weights = torch.load(
            os.path.join(_dir, "snapshots", "val_min_generator.pth"),
            map_location=device,
        )
        model.load_state_dict(weights)
        model = model.to(device).eval()

        wname = "Label Outlier"
        cv.namedWindow(wname, cv.WINDOW_NORMAL)
        cv.resizeWindow(wname, 1600, 900)

        def action(key):
            """
            Handle a pressed key and return True if it labeled the run.

            NOTE(review): the return value is presumably used by the viewer
            to stop displaying the current run - confirm against
            `mu_map.random_search.show_predictions.main`.
            """
            if key == ord("o"):
                print(f"{progress} - Outlier!")
                # The entry of the current run is always the last one. The
                # loop index must not be used because the lists may already
                # contain entries loaded from the output file.
                data["outlier"][-1] = True
                return True

            return key == ord("s")

        main(model, dataset, wname, action, _print=False)

        # store after every run so that labels are not lost if a later run
        # fails or the session is interrupted
        pd.DataFrame(data).to_csv(args.out, index=False)


if __name__ == "__main__":
    label_runs(_build_parser().parse_args())