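"""Evaluate a trained attenuation-map UNet.

Computes, prints and, optionally, stores NMAE measures for a given model
based on the resulting reconstructions (see the argparse description below).

Example invocation (script name and paths are hypothetical):

    python compute_measures.py --dir_train trainings/run_01 --out measures.csv
"""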
from mu_map.eval.measures import nmae

if __name__ == "__main__":
    import argparse
    import json
    import os

    import numpy as np
    import pandas as pd
    import torch

    from mu_map.data.prepare import headers
    from mu_map.data.remove_bed import add_bed
    from mu_map.dataset.default import MuMapDataset
    from mu_map.dataset.util import load_dcm_img
    from mu_map.dataset.transform import SequenceTransform, PadCropTranform
    from mu_map.models.unet import UNet
    from mu_map.training.random_search import (
        normalization_by_params,
        scatter_correction_by_params,
    )
    from mu_map.util import reconstruct
    parser = argparse.ArgumentParser(
        description="Compute, print, and store measures for a given model based on the resulting reconstructions.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cpu", "cuda"],
        help="the device on which the model is evaluated (cpu or cuda)",
    )
    parser.add_argument(
        "--dir_train",
        type=str,
        required=True,
        help="directory where training results (snapshots, params) are stored",
    )
    parser.add_argument("--out", type=str, help="write results to a CSV file")
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/second/",
        help="directory where the dataset is found",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="validation",
        choices=["train", "test", "validation", "all"],
        help="the split of the dataset to be processed",
    )
    args = parser.parse_args()
    # The dataset implementation treats a missing split name as "use all splits".
    if args.split == "all":
        args.split = None

    torch.set_grad_enabled(False)
    device = torch.device(args.device)

    # Load the hyper-parameters the model was trained with.
    with open(os.path.join(args.dir_train, "params.json"), mode="r") as f:
        params = json.load(f)

    # Restore the generator weights with the minimal validation loss.
    weights = os.path.join(args.dir_train, "snapshots", "val_min_generator.pth")
    model = UNet()
    model.load_state_dict(torch.load(weights, map_location=device))
    model = model.to(device).eval()

    transform_pad_crop = PadCropTranform(dim=3, size=32)
    transform_normalization = SequenceTransform(
        transforms=[
            normalization_by_params(params),
            transform_pad_crop,
        ]
    )

    # Two views of the same dataset: a normalized one as model input, and an
    # unnormalized one (with the bed left in the mu maps) for reconstruction.
    dataset = MuMapDataset(
        args.dataset_dir,
        transform_normalization=transform_normalization,
        split_name=args.split,
        scatter_correction=scatter_correction_by_params(params),
    )
    dataset_with_bed = MuMapDataset(
        args.dataset_dir,
        transform_normalization=transform_pad_crop,
        split_name=args.split,
        bed_contours_file=None,
    )
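
    # Column naming (as read from the code below): NAC = reconstruction
    # without attenuation correction, AC = the attenuation-corrected
    # reconstruction loaded from the dataset, SYN/CT = reconstructions
    # corrected with the synthetic and the CT-based mu map, respectively.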
    values = pd.DataFrame({
        "NMAE_NAC_TO_AC": [],
        "NMAE_SYN_TO_AC": [],
        "NMAE_CT_TO_AC": [],
        "NMAE_NAC_TO_CT": [],
        "NMAE_SYN_TO_CT": [],
    })
    for i, ((recon, _), (recon_nac, mu_map_ct)) in enumerate(zip(dataset, dataset_with_bed)):
        print(
            f"Process input {i + 1:>{len(str(len(dataset)))}}/{len(dataset)}",
            end="\r",
        )
        _row = dataset.table.iloc[i]

        # Predict a synthetic mu map and re-insert the bed, which the model
        # never sees during training.
        mu_map_syn = model(recon.unsqueeze(dim=0).to(device))
        mu_map_syn = mu_map_syn.squeeze().cpu().numpy()
        mu_map_ct = mu_map_ct.squeeze().cpu().numpy()
        mu_map_syn = add_bed(
            mu_map_syn, mu_map_ct, bed_contour=dataset.bed_contours[_row["id"]]
        )

        # Load the attenuation-corrected reconstruction as the reference and
        # pad/crop it to the same shape as the other volumes.
        recon_nac = recon_nac.squeeze().cpu().numpy()
        recon_ac = load_dcm_img(
            os.path.join(dataset.dir_images, _row[headers.file_recon_ac_nsc])
        )
        recon_ac = torch.from_numpy(recon_ac)
        recon_ac, _ = transform_pad_crop(recon_ac, recon_ac)
        recon_ac = recon_ac.cpu().numpy()

        # Reconstruct with the synthetic and with the CT-based mu map.
        recon_ac_syn = reconstruct(
            recon_nac.copy(), mu_map=mu_map_syn.copy(), use_gpu=args.device == "cuda"
        )
        recon_ac_ct = reconstruct(
            recon_nac.copy(), mu_map=mu_map_ct.copy(), use_gpu=args.device == "cuda"
        )

        row = pd.DataFrame({
            "NMAE_NAC_TO_AC": [nmae(recon_nac, recon_ac)],
            "NMAE_SYN_TO_AC": [nmae(recon_ac_syn, recon_ac)],
            "NMAE_CT_TO_AC": [nmae(recon_ac_ct, recon_ac)],
            "NMAE_NAC_TO_CT": [nmae(recon_nac, recon_ac_ct)],
            "NMAE_SYN_TO_CT": [nmae(recon_ac_syn, recon_ac_ct)],
        })
        values = pd.concat((values, row), ignore_index=True)
    print(" " * 100, end="\r")  # clear the progress line
    if args.out:
        values.to_csv(args.out, index=False)

    print("Scores:")
    for measure_name, measure_values in values.items():
        mean = np.mean(measure_values)
        std = np.std(measure_values)
        print(f" - {measure_name:>20}: {mean:.6f}±{std:.6f}")
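
# For reference: nmae is imported from mu_map.eval.measures above. A common
# definition of the normalized mean absolute error, which the imported
# function presumably follows (an assumption, not taken from this code):
#
#     def nmae(prediction: np.ndarray, target: np.ndarray) -> float:
#         mae = np.absolute(prediction - target).mean()
#         return mae / (target.max() - target.min())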