from typing import Optional

import numpy as np
import pandas as pd

from mu_map.dataset.default import MuMapDataset
from mu_map.models.unet import UNet


def mse(prediction: np.ndarray, target: np.ndarray) -> float:
    """
    Compute the mean squared error (MSE) between a prediction and
    a target array.

    Parameters
    ----------
    prediction: np.ndarray
    target: np.ndarray

    Returns
    -------
    float
        the mean of the squared element-wise differences
    """
    se = (prediction - target) ** 2
    return se.sum() / se.size
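
# Illustrative sanity check (not part of the module): a constant error of 2,
# squared, gives an MSE of 4 regardless of array shape.
#
#   a = np.full((4, 4), 2.0)
#   b = np.zeros((4, 4))
#   assert mse(a, b) == 4.0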


def nmae(
    prediction: np.ndarray,
    target: np.ndarray,
    vmax: Optional[float] = None,
    vmin: Optional[float] = None,
) -> float:
    """
    Compute the normalized mean absolute error (NMAE) between a prediction
    and a target array.

    Parameters
    ----------
    prediction: np.ndarray
    target: np.ndarray
    vmax: float, optional
        maximum value for normalization, defaults to the maximal value in the target
    vmin: float, optional
        minimum value for normalization, defaults to the minimal value in the target

    Returns
    -------
    float
        the mean absolute error normalized by the value range of the target
    """
    vmax = target.max() if vmax is None else vmax
    vmin = target.min() if vmin is None else vmin

    ae = np.absolute(prediction - target)
    mae = ae.sum() / ae.size
    return mae / (vmax - vmin)
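
# Illustrative sanity check (not part of the module): a constant absolute
# error of 1 over a target spanning [0, 10] gives 1 / (10 - 0) = 0.1.
#
#   target = np.linspace(0.0, 10.0, num=11)
#   assert np.isclose(nmae(target + 1.0, target), 0.1)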


def compute_measures(dataset: MuMapDataset, model: UNet) -> pd.DataFrame:
    """
    Compute measures (MSE, NMAE) for all images in a dataset.

    Parameters
    ----------
    dataset: MuMapDataset
        the dataset containing the reconstructions and mu maps for which the scores are computed
    model: UNet
        the UNet model which is used to predict mu maps from reconstructions

    Returns
    -------
    pd.DataFrame
        a dataframe containing the measures for each image in the dataset
    """
    device = next(model.parameters()).device
    measures = {"NMAE": nmae, "MSE": mse}

    values = pd.DataFrame({name: [] for name in measures})
    for i, (recon, mu_map) in enumerate(dataset):
        _id = dataset.table.iloc[i]["id"]
        print(
            f"Process input {str(i):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r"
        )

        # predict a mu map from the reconstruction and strip the batch dimension
        prediction = model(recon.unsqueeze(dim=0).to(device))
        prediction = prediction.squeeze().cpu().numpy()
        mu_map = mu_map.squeeze().cpu().numpy()

        # evaluate every measure on this prediction and append the result as a row
        row = {name: [measure(prediction, mu_map)] for name, measure in measures.items()}
        row["id"] = _id
        values = pd.concat((values, pd.DataFrame(row)), ignore_index=True)
    print(" " * 100, end="\r")  # clear the progress line

    return values
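
# Hypothetical usage sketch (the weights file and dataset directory are
# placeholders, not files from this repository):
#
#   model = UNet(features=[64, 128, 256, 512]).eval()
#   model.load_state_dict(torch.load("weights.pth", map_location="cpu"))
#   dataset = MuMapDataset("data/")
#   with torch.no_grad():
#       df = compute_measures(dataset, model)
#   print(df[["MSE", "NMAE"]].mean())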


if __name__ == "__main__":
    import argparse

    import torch

    from mu_map.dataset.normalization import norm_by_str, norm_choices
    from mu_map.dataset.transform import SequenceTransform, PadCropTranform

    parser = argparse.ArgumentParser(
        description="Compute, print and store measures for a given model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cpu", "cuda"],
        help="the device on which the model is evaluated (cpu or cuda)",
    )
    parser.add_argument(
        "--weights",
        type=str,
        required=True,
        help="the model weights which should be scored",
    )
    parser.add_argument("--out", type=str, help="write results as a csv file")
    parser.add_argument(
        "--scatter_corrected",
        action="store_true",
        help="use the scatter corrected version of the dataset",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        help="directory where the dataset is found",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="validation",
        choices=["train", "test", "validation", "all"],
        help="the split of the dataset to be processed",
    )
    parser.add_argument(
        "--norm",
        type=str,
        choices=["none", *norm_choices],
        default="mean",
        help="type of normalization applied to the reconstructions",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=32,
        help="pad/crop the third tensor dimension to this value",
    )
    args = parser.parse_args()

    if args.split == "all":
        # assumption: passing None as the split name makes MuMapDataset
        # load every image regardless of its split
        args.split = None

    torch.set_grad_enabled(False)

    device = torch.device(args.device)
    # model = UNet(features=[32, 64, 128, 256, 512])
    model = UNet(features=[64, 128, 256, 512])
    model.load_state_dict(torch.load(args.weights, map_location=device))
    model = model.to(device).eval()

    transform_normalization = SequenceTransform(
        transforms=[
            PadCropTranform(dim=3, size=args.size),
        ]
    )

    dataset = MuMapDataset(
        args.dataset_dir,
        transform_normalization=transform_normalization,
        split_name=args.split,
        scatter_correction=args.scatter_corrected,
    )

    values = compute_measures(dataset, model)
    if args.out:
        values.to_csv(args.out, index=False)

    print("Scores:")
    for measure_name, measure_values in values.items():
        if measure_name == "id":
            continue

        mean = measure_values.mean()
        std = np.std(measure_values)
        print(f"  {measure_name}: {mean:.6f} ± {std:.6f}")
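
# Example invocation (the module path and file names are placeholders):
#
#   python -m mu_map.eval.measures --weights train/snapshots/val_min.pth \
#       --dataset_dir data/ --split test --out measures.csv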