From 32eabd9bc28ffe624573a882b404478356886797 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 13 Dec 2022 16:51:52 +0100
Subject: [PATCH] add bland altman plot to perfusion evaluation

---
 mu_map/polar_map/eval_perfusion.py | 65 ++++++++++++++++++++++++++++++
 1 file changed, 65 insertions(+)

diff --git a/mu_map/polar_map/eval_perfusion.py b/mu_map/polar_map/eval_perfusion.py
index fedeece..59a8a21 100644
--- a/mu_map/polar_map/eval_perfusion.py
+++ b/mu_map/polar_map/eval_perfusion.py
@@ -1,9 +1,46 @@
+from typing import List, Optional
+
+import matplotlib as mlp
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 
 from mu_map.polar_map.prepare import headers
 
+SIZE_DEFAULT = 12
+plt.rc("font", family="Arial")  # controls default font
+plt.rc("font", weight="normal")  # controls default font
+plt.rc("font", size=SIZE_DEFAULT)  # controls default text sizes
+plt.rc("axes", titlesize=16)  # fontsize of the axes title
+
+
+# COLORS=["#a6cee3", "#1f78b4", "#b2df8a"]
+COLORS=["#66c2a5", "#fc8d62", "#8da0cb"]
+COLORS=COLORS[::-1]
+
+def bland_altman(
+    data1: np.ndarray, data2: np.ndarray, ax: Optional[mlp.axes.Axes] = None
+):
+    ax = plt.subplot() if ax is None else ax
+
+    mean = np.mean([data1, data2], axis=0)
+    diff = data1 - data2  # Difference between data1 and data2
+    md = np.mean(diff)  # Mean of the difference
+    sd = np.std(diff, axis=0)  # Standard deviation of the difference
+
+    ax.axhline(md, color="#fc8d59", linestyle="-")
+    for x in [-1.96, 1.96]:
+        ax.axhline(md + x * sd, color="black", linestyle="--", alpha=0.8)
+    ax.scatter(mean, diff, color="#91bfdb", s=25, alpha=0.7, edgecolors="black", linewidths=0.3)
+
+    ax.spines["left"].set_visible(False)
+    ax.spines["right"].set_visible(False)
+    ax.spines["top"].set_visible(False)
+
+    ax.grid(axis="both", alpha=0.5, linestyle="dotted")
+
+
+
 data = pd.read_csv("data/polar_maps/perfusion.csv")
 
 baseline = data[data[headers.ac] & (data[headers.type] == "symbia")]
@@ -19,6 +56,34 @@ _correction_none = correction_none[keys_segments].values
 _correction_syn = correction_syn[keys_segments].values
 _correction_ct = correction_ct[keys_segments].values
 
+fig, axs = plt.subplots(1, 2, figsize=(12, 6))
+bland_altman(_correction_none.flatten(), _correction_ct.flatten(), axs[0])
+bland_altman(_correction_syn.flatten(), _correction_ct.flatten(), axs[1])
+axs[0].set_title("No Attenuation Correction")
+axs[1].set_title("Synthetic Attenuation Correction")
+
+# fig, axs = plt.subplots(1, 3, figsize=(18, 6))
+# bland_altman(_baseline.flatten(), _correction_none.flatten(), ax=axs[0])
+# bland_altman(_baseline.flatten(), _correction_syn.flatten(), ax=axs[1])
+# bland_altman(_baseline.flatten(), _correction_ct.flatten(), ax=axs[2])
+
+
+def normalize_axes(axes: List[mlp.axes.Axes]):
+    x_min = min(map(lambda ax: ax.get_xlim()[0], axes))
+    x_max = max(map(lambda ax: ax.get_xlim()[1], axes))
+    y_min = min(map(lambda ax: ax.get_ylim()[0], axes))
+    y_max = max(map(lambda ax: ax.get_ylim()[1], axes))
+    for ax in axs:
+        ax.set_xlim((x_min, x_max))
+        ax.set_ylim((y_min, y_max))
+
+
+normalize_axes(axs)
+# plt.hist(_correction_syn.flatten() - _correction_ct.flatten())
+# plt.hist(_correction_none.flatten() - _correction_ct.flatten())
+plt.show()
+exit(0)
+
 
 def absolute_percent_error(prediction: np.ndarray, target: np.ndarray) -> float:
     mean_p = prediction.mean(axis=0)
-- 
GitLab