Skip to content
Snippets Groups Projects
Commit eecbb754 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

add tool to label random search runs that die not converge

parent 828569e1
No related branches found
No related tags found
No related merge requests found
import argparse
import os
import cv2 as cv
import pandas as pd
import torch
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.transform import SequenceTransform, PadCropTranform
from mu_map.models.unet import UNet
from mu_map.random_search.cgan import load_params
from mu_map.random_search.show_predictions import main
controls = """
Controls:
q: quit/exit the application
n: show the next mu map
p: pause on the current slice or resume if paused
o: mark run as an outlier
s: mark run as valid
"""
default_rs_dir = "cgan_random_search/"
default_outfile = os.path.join(default_rs_dir, "outliers.csv")
parser = argparse.ArgumentParser(
description="label random search runs as outliers (not-converged) by having a look at the model's prediction on the validation split of the dataset",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--random_search_dir",
type=str,
default=default_rs_dir,
help="directory where the runs of a random search are stored",
)
parser.add_argument(
"--out",
type=str,
default=default_outfile,
help="the file where outliers are stored - if it already exists, missing runs can be added",
)
parser.add_argument(
"--device",
choices=["cpu", "gpu"],
default="gpu" if torch.cuda.is_available() else "cpu",
help="the device on which the model of the random search run is evaluated",
)
args = parser.parse_args()
device = torch.device(args.device)
runs = sorted(os.listdir(args.random_search_dir))
runs = map(lambda f: os.path.join(args.random_search_dir, f), runs)
runs = filter(lambda f: os.path.isdir(f), runs)
runs = filter(lambda f: not os.path.islink(f), runs)
runs = map(lambda f: os.path.basename(f), runs)
runs = list(runs)
data = {"run": [], "outlier": []}
if os.path.isfile(args.out):
available_data = pd.read_csv(args.out)
data["run"] = list(available_data["run"])
data["outlier"] = list(available_data["outlier"])
print(controls)
total = str(len(runs))
for i, run in enumerate(runs):
if int(run) in data["run"]:
continue
print(f"Run {str(i+1):>{len(total)}}/{total}", end="\r")
data["run"].append(int(run))
data["outlier"].append(False)
_dir = os.path.join(args.random_search_dir, run)
params = load_params(os.path.join(_dir, "params.json"))
dataset = MuMapDataset(
"data/second/",
transform_normalization=SequenceTransform(
[params["normalization"], PadCropTranform(dim=3, size=32)]
),
split_name="validation",
scatter_correction=False,
)
model = UNet(features=params["generator_features"])
weights = torch.load(
os.path.join(_dir, "snapshots", "val_min_generator.pth"),
map_location=device,
)
model.load_state_dict(weights)
model = model.to(device).eval()
wname = "Label Outlier"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
def action(key):
if key == ord("o"):
print(f"Run {str(i+1):>{len(total)}}/{total} - Outlier!")
data["outlier"][i] = True
return True
if key == ord("s"):
return True
return False
main(model, dataset, wname, action, _print=False)
pd.DataFrame(data).to_csv(args.out, index=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment