Skip to content
Snippets Groups Projects
Commit 03163d47 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

add capability to correclty pare cGAN random search params

parent f9676c4d
No related branches found
No related tags found
No related merge requests found
from logging import Logger from logging import Logger
import json
import os import os
from typing import Any, Dict, Optional from typing import Any, Callable, Dict, List, Optional
import pandas as pd import pandas as pd
import torch import torch
...@@ -10,6 +11,7 @@ from mu_map.dataset.normalization import ( ...@@ -10,6 +11,7 @@ from mu_map.dataset.normalization import (
GaussianNormTransform, GaussianNormTransform,
MaxNormTransform, MaxNormTransform,
MeanNormTransform, MeanNormTransform,
norm_by_str,
) )
from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.transform import PadCropTranform, SequenceTransform from mu_map.dataset.transform import PadCropTranform, SequenceTransform
...@@ -240,6 +242,112 @@ class cGANRandomSearch(RandomSearch): ...@@ -240,6 +242,112 @@ class cGANRandomSearch(RandomSearch):
return params return params
class ParamJSONDecoder(json.JSONDecoder):
"""
A custom JSON decoder to the parameters of a cGAN random search run.
"""
def __init__(self):
super().__init__()
self.int_fields = [
"patch_size",
"patch_offset",
"patch_number",
"epochs",
"batch_size",
"lr_decay_epoch",
]
self.float_fields = [
"lr",
"lr_decay_factor",
"weight_crit_dist",
"weight_crit_adv",
]
self.bool_fiels = ["scatter_correction", "shuffle", "lr_decay"]
self.int_list_fields = ["discriminator_conv_features", "generator_features"]
def decode(self, s: str) -> Dict[str, Any]:
"""
Decode a JSON string into a dict of parameters.
"""
params = super().decode(s)
self.parse_fields(params, int, *self.int_fields)
self.parse_fields(params, float, *self.float_fields)
self.parse_fields(params, lambda v: v == "True", *self.bool_fiels)
self.parse_fields(
params, lambda v: self.parse_as_list(v, int), *self.int_list_fields
)
self.parse_fields(params, WeightedLoss.from_str, "criterion_dist")
self.parse_fields(params, norm_by_str, "normalization")
params["pad_crop"] = (
None if params["pad_crop"] == "None" else PadCropTranform(dim=3, size=32)
)
return params
def parse_fields(
self, params: Dict[str, Any], func: Callable[str, Any], *fields: str
):
"""
Parse fields in a dict with a specified function.
This function makes sure that the fields exist and are currently string.
Parameters
----------
params: Dict[str, Any]
the dict whose values are parsed
func: Callable[str, Any]
the function used for parsing
*fields: str
the fields to parse
"""
fields = filter(lambda field: field in params.keys(), fields)
fields = filter(lambda field: type(params[field]) == str, fields)
for field in fields:
params[field] = func(params[field])
def parse_as_list(self, s: str, func: Callable[str, Any]) -> List[Any]:
"""
Parse a field as a list.
Parameters
----------
s: str
the string to be parsed as a list
func:
the parsing function for list elements
Returns
-------
List[Any]
"""
s = s[1:-1] # remove brackets
values = s.split(",")
values = map(lambda x: x.strip(), values)
values = map(func, values)
return list(values)
def load_params(filename: str) -> Dict[str, Any]:
"""
Load parameters of a cGAN random search from a file.
Parameters
----------
filename: str
the file to be read
Returns
-------
Dict[str, Any]
"""
with open(filename, mode="r") as f:
return json.load(f, cls=ParamJSONDecoder)
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment