Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
M
mu-map
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Tamino Huxohl
mu-map
Commits
03163d47
Commit
03163d47
authored
2 years ago
by
Tamino Huxohl
Browse files
Options
Downloads
Patches
Plain Diff
add capability to correclty pare cGAN random search params
parent
f9676c4d
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
mu_map/random_search/cgan.py
+109
-1
109 additions, 1 deletion
mu_map/random_search/cgan.py
with
109 additions
and
1 deletion
mu_map/random_search/cgan.py
+
109
−
1
View file @
03163d47
from
logging
import
Logger
from
logging
import
Logger
import
json
import
os
import
os
from
typing
import
Any
,
Dic
t
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
Lis
t
,
Optional
import
pandas
as
pd
import
pandas
as
pd
import
torch
import
torch
...
@@ -10,6 +11,7 @@ from mu_map.dataset.normalization import (
...
@@ -10,6 +11,7 @@ from mu_map.dataset.normalization import (
GaussianNormTransform
,
GaussianNormTransform
,
MaxNormTransform
,
MaxNormTransform
,
MeanNormTransform
,
MeanNormTransform
,
norm_by_str
,
)
)
from
mu_map.dataset.patches
import
MuMapPatchDataset
from
mu_map.dataset.patches
import
MuMapPatchDataset
from
mu_map.dataset.transform
import
PadCropTranform
,
SequenceTransform
from
mu_map.dataset.transform
import
PadCropTranform
,
SequenceTransform
...
@@ -240,6 +242,112 @@ class cGANRandomSearch(RandomSearch):
...
@@ -240,6 +242,112 @@ class cGANRandomSearch(RandomSearch):
return
params
return
params
class
ParamJSONDecoder
(
json
.
JSONDecoder
):
"""
A custom JSON decoder to the parameters of a cGAN random search run.
"""
def
__init__
(
self
):
super
().
__init__
()
self
.
int_fields
=
[
"
patch_size
"
,
"
patch_offset
"
,
"
patch_number
"
,
"
epochs
"
,
"
batch_size
"
,
"
lr_decay_epoch
"
,
]
self
.
float_fields
=
[
"
lr
"
,
"
lr_decay_factor
"
,
"
weight_crit_dist
"
,
"
weight_crit_adv
"
,
]
self
.
bool_fiels
=
[
"
scatter_correction
"
,
"
shuffle
"
,
"
lr_decay
"
]
self
.
int_list_fields
=
[
"
discriminator_conv_features
"
,
"
generator_features
"
]
def
decode
(
self
,
s
:
str
)
->
Dict
[
str
,
Any
]:
"""
Decode a JSON string into a dict of parameters.
"""
params
=
super
().
decode
(
s
)
self
.
parse_fields
(
params
,
int
,
*
self
.
int_fields
)
self
.
parse_fields
(
params
,
float
,
*
self
.
float_fields
)
self
.
parse_fields
(
params
,
lambda
v
:
v
==
"
True
"
,
*
self
.
bool_fiels
)
self
.
parse_fields
(
params
,
lambda
v
:
self
.
parse_as_list
(
v
,
int
),
*
self
.
int_list_fields
)
self
.
parse_fields
(
params
,
WeightedLoss
.
from_str
,
"
criterion_dist
"
)
self
.
parse_fields
(
params
,
norm_by_str
,
"
normalization
"
)
params
[
"
pad_crop
"
]
=
(
None
if
params
[
"
pad_crop
"
]
==
"
None
"
else
PadCropTranform
(
dim
=
3
,
size
=
32
)
)
return
params
def
parse_fields
(
self
,
params
:
Dict
[
str
,
Any
],
func
:
Callable
[
str
,
Any
],
*
fields
:
str
):
"""
Parse fields in a dict with a specified function.
This function makes sure that the fields exist and are currently string.
Parameters
----------
params: Dict[str, Any]
the dict whose values are parsed
func: Callable[str, Any]
the function used for parsing
*fields: str
the fields to parse
"""
fields
=
filter
(
lambda
field
:
field
in
params
.
keys
(),
fields
)
fields
=
filter
(
lambda
field
:
type
(
params
[
field
])
==
str
,
fields
)
for
field
in
fields
:
params
[
field
]
=
func
(
params
[
field
])
def
parse_as_list
(
self
,
s
:
str
,
func
:
Callable
[
str
,
Any
])
->
List
[
Any
]:
"""
Parse a field as a list.
Parameters
----------
s: str
the string to be parsed as a list
func:
the parsing function for list elements
Returns
-------
List[Any]
"""
s
=
s
[
1
:
-
1
]
# remove brackets
values
=
s
.
split
(
"
,
"
)
values
=
map
(
lambda
x
:
x
.
strip
(),
values
)
values
=
map
(
func
,
values
)
return
list
(
values
)
def
load_params
(
filename
:
str
)
->
Dict
[
str
,
Any
]:
"""
Load parameters of a cGAN random search from a file.
Parameters
----------
filename: str
the file to be read
Returns
-------
Dict[str, Any]
"""
with
open
(
filename
,
mode
=
"
r
"
)
as
f
:
return
json
.
load
(
f
,
cls
=
ParamJSONDecoder
)
if
__name__
==
"
__main__
"
:
if
__name__
==
"
__main__
"
:
import
argparse
import
argparse
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment