Skip to content
Snippets Groups Projects
Commit bb21d73a authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

remove outdated top-level scripts

parent a20551d0
No related branches found
No related tags found
No related merge requests found
import os
import time
import numpy as np
import pydicom
import torch
from mu_map.data.prepare import headers
from mu_map.data.remove_bed import add_bed
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import GaussianNormTransform
from mu_map.dataset.transform import SequenceTransform, PadCropTranform
from mu_map.file.dicom import load_dcm, update_dcm, change_uid
from mu_map.models.unet import UNet
torch.set_grad_enabled(False)
device = torch.device("cuda")
split_name = "test"
dir_train = "results/cgan_random_search_02/03"
dir_data = "data/second/"
dir_out = "results/mu_map_syn/"
file_weights = os.path.join(dir_train, "snapshots", "val_min_generator.pth")
model = UNet()
model.load_state_dict(torch.load(file_weights, map_location=device))
model = model.to(device)
model = model.eval()
transform_normalization = SequenceTransform([GaussianNormTransform(), PadCropTranform(dim=3, size=32)])
dataset = MuMapDataset(dir_data, transform_normalization=transform_normalization, split_name=split_name)
dataset_with_bed = MuMapDataset(dir_data, transform_normalization=PadCropTranform(dim=3, size=32), split_name=split_name, bed_contours_file=None)
for i, ((recon_nac, _), (_, mu_map)) in enumerate(zip(dataset, dataset_with_bed)):
row = dataset.table.iloc[i]
_id = row[headers.id]
print(f"Process {_id} ...")
recon_nac = recon_nac.to(device)
mu_map_syn = model(recon_nac.unsqueeze(dim=0))
mu_map = mu_map.squeeze().cpu().numpy()
mu_map_syn = mu_map_syn.squeeze().cpu().numpy()
mu_map_syn = np.where(mu_map_syn < 0.0, 0.0, mu_map_syn)
mu_map_syn = add_bed(mu_map_syn, mu_map, dataset.bed_contours[_id])
file_mu_map = os.path.join(dir_data, "images", row[headers.file_mu_map])
dcm_mu_map = pydicom.dcmread(file_mu_map)
dcm_mu_map_syn = update_dcm(dcm_mu_map, mu_map_syn)
dcm_mu_map_syn = change_uid(dcm_mu_map_syn)
base, ext = os.path.splitext(os.path.basename(file_mu_map))
file_mu_map_syn = os.path.join(dir_out, f"{base}_syn{ext}")
pydicom.dcmwrite(file_mu_map_syn, dcm_mu_map_syn)
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import MeanNormTransform
from mu_map.dataset.transform import PadCropTranform, SequenceTransform
from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE
torch.set_grad_enabled(False)
dataset = MuMapDataset(
"data/second/",
transform_normalization=SequenceTransform([
MeanNormTransform(),
PadCropTranform(dim=3, size=32)
]),
)
model = UNet(in_channels=1)
device = torch.device("cpu")
# weights = torch.load("train_data/snapshots/val_min_Model.pth", map_location=device)
# model.load_state_dict(weights)
model = model.eval()
recon, mu_map = dataset[0]
recon = recon.unsqueeze(dim=0)
output = model(recon)
diff = ((mu_map - output) ** 2).mean()
print(f"Diff: {diff:.5f}")
output = output.squeeze().numpy()
mu_map = mu_map.squeeze().numpy()
assert output.shape[0] == mu_map.shape[0]
wname = "Dataset"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
space = np.full((1024, 10), 239, np.uint8)
def to_display_image(image, _slice, _min=None, _max=None):
_max = _max if _max is not None else image.max()
_min = _min if _min is not None else image.min()
_image = to_grayscale(image[_slice], min_val=_min, max_val=_max)
_image = cv.resize(_image, (1024, 1024), cv.INTER_AREA)
_text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
_image = cv.putText(
_image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
)
return _image
def com(image1, image2, _slice):
image1 = to_display_image(image1, _slice)
image2 = to_display_image(image2, _slice)
space = np.full((image1.shape[0], 10), 239, np.uint8)
return np.hstack((image1, space, image2))
output = np.clip(output, 0, mu_map.max())
i = 0
while True:
x = com(output, mu_map, i)
cv.imshow(wname, x)
key = cv.waitKey(100)
if key == ord("q"):
break
i = (i + 1) % output.shape[0]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment