From 32eabd9bc28ffe624573a882b404478356886797 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 13 Dec 2022 16:51:52 +0100 Subject: [PATCH] add bland altman plot to perfusion evaluation --- mu_map/polar_map/eval_perfusion.py | 65 ++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/mu_map/polar_map/eval_perfusion.py b/mu_map/polar_map/eval_perfusion.py index fedeece..59a8a21 100644 --- a/mu_map/polar_map/eval_perfusion.py +++ b/mu_map/polar_map/eval_perfusion.py @@ -1,9 +1,46 @@ +from typing import List, Optional + +import matplotlib as mlp import matplotlib.pyplot as plt import numpy as np import pandas as pd from mu_map.polar_map.prepare import headers +SIZE_DEFAULT = 12 +plt.rc("font", family="Arial") # controls default font +plt.rc("font", weight="normal") # controls default font +plt.rc("font", size=SIZE_DEFAULT) # controls default text sizes +plt.rc("axes", titlesize=16) # fontsize of the axes title + + +# COLORS=["#a6cee3", "#1f78b4", "#b2df8a"] +COLORS=["#66c2a5", "#fc8d62", "#8da0cb"] +COLORS=COLORS[::-1] + +def bland_altman( + data1: np.ndarray, data2: np.ndarray, ax: Optional[mlp.axes.Axes] = None +): + ax = plt.subplot() if ax is None else ax + + mean = np.mean([data1, data2], axis=0) + diff = data1 - data2 # Difference between data1 and data2 + md = np.mean(diff) # Mean of the difference + sd = np.std(diff, axis=0) # Standard deviation of the difference + + ax.axhline(md, color="#fc8d59", linestyle="-") + for x in [-1.96, 1.96]: + ax.axhline(md + x * sd, color="black", linestyle="--", alpha=0.8) + ax.scatter(mean, diff, color="#91bfdb", s=25, alpha=0.7, edgecolors="black", linewidths=0.3) + + ax.spines["left"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + + ax.grid(axis="both", alpha=0.5, linestyle="dotted") + + + data = pd.read_csv("data/polar_maps/perfusion.csv") baseline = data[data[headers.ac] & (data[headers.type] == "symbia")] @@ -19,6 +56,34 @@ _correction_none = correction_none[keys_segments].values _correction_syn = correction_syn[keys_segments].values _correction_ct = correction_ct[keys_segments].values +fig, axs = plt.subplots(1, 2, figsize=(12, 6)) +bland_altman(_correction_none.flatten(), _correction_ct.flatten(), axs[0]) +bland_altman(_correction_syn.flatten(), _correction_ct.flatten(), axs[1]) +axs[0].set_title("No Attenuation Correction") +axs[1].set_title("Synthetic Attenuation Correction") + +# fig, axs = plt.subplots(1, 3, figsize=(18, 6)) +# bland_altman(_baseline.flatten(), _correction_none.flatten(), ax=axs[0]) +# bland_altman(_baseline.flatten(), _correction_syn.flatten(), ax=axs[1]) +# bland_altman(_baseline.flatten(), _correction_ct.flatten(), ax=axs[2]) + + +def normalize_axes(axes: List[mlp.axes.Axes]): + x_min = min(map(lambda ax: ax.get_xlim()[0], axes)) + x_max = max(map(lambda ax: ax.get_xlim()[1], axes)) + y_min = min(map(lambda ax: ax.get_ylim()[0], axes)) + y_max = max(map(lambda ax: ax.get_ylim()[1], axes)) + for ax in axs: + ax.set_xlim((x_min, x_max)) + ax.set_ylim((y_min, y_max)) + + +normalize_axes(axs) +# plt.hist(_correction_syn.flatten() - _correction_ct.flatten()) +# plt.hist(_correction_none.flatten() - _correction_ct.flatten()) +plt.show() +exit(0) + def absolute_percent_error(prediction: np.ndarray, target: np.ndarray) -> float: mean_p = prediction.mean(axis=0) -- GitLab