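"""Compute, print and optionally store error measures (MSE and NMAE) comparing
the mu maps predicted by a UNet model with the ground-truth mu maps of a dataset."""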
import numpy as np
import pandas as pd

from mu_map.dataset.default import MuMapDataset
from mu_map.models.unet import UNet
def mse(prediction: np.ndarray, target: np.ndarray) -> float:
    """
    Compute the mean squared error (MSE) between a prediction and
    a target array.

    Parameters
    ----------
    prediction: np.ndarray
    target: np.ndarray
    """
    se = (prediction - target) ** 2
    return se.sum() / se.size

def nmae(
    prediction: np.ndarray, target: np.ndarray, vmax: float = None, vmin: float = None
) -> float:
    """
    Compute the normalized mean absolute error (NMAE) between a prediction
    and a target array.

    Parameters
    ----------
    prediction: np.ndarray
    target: np.ndarray
    vmax: float, optional
        maximum value for normalization, defaults to the maximal value in the target
    vmin: float, optional
        minimum value for normalization, defaults to the minimal value in the target
    """
    vmax = target.max() if vmax is None else vmax
    vmin = target.min() if vmin is None else vmin

    ae = np.absolute(prediction - target)
    mae = ae.sum() / ae.size
    return mae / (vmax - vmin)
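
# Worked example (hypothetical values): if the prediction deviates from the target
# by 0.1 in every voxel and the target spans [0.0, 2.0], then MAE = 0.1 and
# NMAE = 0.1 / (2.0 - 0.0) = 0.05.
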
def compute_measures(dataset: MuMapDataset, model: UNet) -> pd.DataFrame:
    """
    Compute measures (MSE, NMAE) for all images in a dataset.

    Parameters
    ----------
    dataset: MuMapDataset
        the dataset containing the reconstructions and mu maps for which the scores are computed
    model: UNet
        the UNet model which is used to predict mu maps from reconstructions

    Returns
    -------
    pd.DataFrame
        a dataframe containing the measures for each image in the dataset
    """
    device = next(model.parameters()).device
    measures = {"NMAE": nmae, "MSE": mse}

    values = pd.DataFrame(dict(map(lambda x: (x, []), measures.keys())))
    for i, (recon, mu_map) in enumerate(dataset):
        _id = dataset.table.iloc[i]["id"]
        print(
            f"Process input {str(i):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r"
        )

        prediction = model(recon.unsqueeze(dim=0).to(device))
        prediction = prediction.squeeze().cpu().numpy()
        mu_map = mu_map.squeeze().cpu().numpy()

        # compute every measure for this image and append the results as a new row
        row = dict(
            map(lambda item: (item[0], [item[1](prediction, mu_map)]), measures.items())
        )
        row["id"] = _id
        row = pd.DataFrame(row)
        values = pd.concat((values, row), ignore_index=True)
    # clear the progress line
    print(" " * 100, end="\r")

    return values

if __name__ == "__main__":
    import argparse
    import torch

    from mu_map.dataset.normalization import norm_by_str, norm_choices
    from mu_map.dataset.transform import SequenceTransform, PadCropTranform

    parser = argparse.ArgumentParser(
        description="Compute, print and store measures for a given model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cpu", "cuda"],
        help="the device on which the model is evaluated (cpu or cuda)",
    )
    parser.add_argument(
        "--weights",
        type=str,
        required=True,
        help="the model weights which should be scored",
    )
    parser.add_argument("--out", type=str, help="write results as a csv file")
    parser.add_argument("--scatter_corrected", action="store_true")
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/initial/",
        help="directory where the dataset is found",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="validation",
        choices=["train", "test", "validation", "all"],
        help="the split of the dataset to be processed",
    )
    parser.add_argument(
        "--norm",
        type=str,
        choices=["none", *norm_choices],
        default="mean",
        help="type of normalization applied to the reconstructions",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=32,
        help="pad/crop the third tensor dimension to this value",
    )
    args = parser.parse_args()
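
    # Example invocation (illustrative; the script name and file paths are placeholders):
    #   python measures.py --weights weights/model.pth --dataset_dir data/initial/ \
    #       --split validation --out scores.csv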
    if args.split == "all":
        # Assumption: "all" disables split filtering so the entire dataset is evaluated.
        args.split = None

    torch.set_grad_enabled(False)

    device = torch.device(args.device)
    model = UNet(features=[32, 64, 128, 256, 512])
    model.load_state_dict(torch.load(args.weights, map_location=device))
    model = model.to(device).eval()

    transform_normalization = SequenceTransform(
        transforms=[
            PadCropTranform(dim=3, size=args.size),
        ]
    )

    dataset = MuMapDataset(
        args.dataset_dir,
        transform_normalization=transform_normalization,
        split_name=args.split,
        scatter_correction=args.scatter_corrected,
    )

    values = compute_measures(dataset, model)
    if args.out:
        values.to_csv(args.out, index=False)

    print("Scores:")
    for measure_name, measure_values in values.items():
        if measure_name == "id":
            continue

        mean = measure_values.mean()
        std = np.std(measure_values)