Skip to content
Snippets Groups Projects
Commit 9774f924 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

label outliers script now uses load data method defined in util

parent eecbb754
No related branches found
No related tags found
No related merge requests found
...@@ -10,6 +10,7 @@ from mu_map.dataset.transform import SequenceTransform, PadCropTranform ...@@ -10,6 +10,7 @@ from mu_map.dataset.transform import SequenceTransform, PadCropTranform
from mu_map.models.unet import UNet from mu_map.models.unet import UNet
from mu_map.random_search.cgan import load_params from mu_map.random_search.cgan import load_params
from mu_map.random_search.show_predictions import main from mu_map.random_search.show_predictions import main
from mu_map.random_search.eval.util import load_data
controls = """ controls = """
Controls: Controls:
...@@ -50,13 +51,7 @@ parser.add_argument( ...@@ -50,13 +51,7 @@ parser.add_argument(
args = parser.parse_args() args = parser.parse_args()
device = torch.device(args.device) device = torch.device(args.device)
runs = load_data(args.random_search_dir)
runs = sorted(os.listdir(args.random_search_dir))
runs = map(lambda f: os.path.join(args.random_search_dir, f), runs)
runs = filter(lambda f: os.path.isdir(f), runs)
runs = filter(lambda f: not os.path.islink(f), runs)
runs = map(lambda f: os.path.basename(f), runs)
runs = list(runs)
data = {"run": [], "outlier": []} data = {"run": [], "outlier": []}
if os.path.isfile(args.out): if os.path.isfile(args.out):
...@@ -68,16 +63,17 @@ print(controls) ...@@ -68,16 +63,17 @@ print(controls)
total = str(len(runs)) total = str(len(runs))
for i, run in enumerate(runs): for i, run in enumerate(runs):
if int(run) in data["run"]: print(f"Run {str(i+1):>{len(total)}}/{total}", end="\r")
if run in data["run"]:
continue continue
print(f"Run {str(i+1):>{len(total)}}/{total}", end="\r")
data["run"].append(int(run)) data["run"].append(int(run))
data["outlier"].append(False) data["outlier"].append(False)
_dir = os.path.join(args.random_search_dir, run)
params = load_params(os.path.join(_dir, "params.json")) dir_run = os.path.join(args.random_search_dir, runs[run]["dir"])
params = runs[run]["params"]
dataset = MuMapDataset( dataset = MuMapDataset(
"data/second/", "data/second/",
...@@ -90,7 +86,7 @@ for i, run in enumerate(runs): ...@@ -90,7 +86,7 @@ for i, run in enumerate(runs):
model = UNet(features=params["generator_features"]) model = UNet(features=params["generator_features"])
weights = torch.load( weights = torch.load(
os.path.join(_dir, "snapshots", "val_min_generator.pth"), os.path.join(dir_run, "snapshots", "val_min_generator.pth"),
map_location=device, map_location=device,
) )
model.load_state_dict(weights) model.load_state_dict(weights)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment