Skip to content
Snippets Groups Projects
Commit 738b753c authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

update label outliers script

parent 1db7856b
No related branches found
No related tags found
No related merge requests found
......@@ -27,7 +27,7 @@ Controls:
"""
default_rs_dir = "cgan_random_search/"
default_outfile = os.path.join(default_rs_dir, "outliers.csv")
default_outfile = "outliers.csv"
parser = argparse.ArgumentParser(
description="label random search runs as outliers (not-converged) by having a look at the model's prediction on the validation split of the dataset",
......@@ -53,6 +53,12 @@ parser.add_argument(
)
args = parser.parse_args()
args.out = os.path.join(args.random_search_dir, args.out)
wname = "Label Outlier"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
device = torch.device(args.device)
runs = load_data(args.random_search_dir)
......@@ -66,7 +72,8 @@ print(controls)
total = str(len(runs))
for i, run in enumerate(runs):
print(f"Run {str(i+1):>{len(total)}}/{total}", end="\r")
nmae = runs[run]["measures"]["NMAE"].mean()
print(f"Run {str(i+1):>{len(total)}}/{total} - with NMAE {nmae:.5f}", end="\r")
if run in data["run"]:
continue
......@@ -94,13 +101,11 @@ for i, run in enumerate(runs):
model.load_state_dict(weights)
model = model.to(device).eval()
wname = "Label Outlier"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
def action(key):
if key == ord("o"):
print(f"Run {str(i+1):>{len(total)}}/{total} - Outlier!")
print(
f"Run {str(i+1):>{len(total)}}/{total} - with NMAE {nmae:.5f} - Outlier!"
)
data["outlier"][i] = True
return True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment