Commit f5191931 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

further cleanup

parent b1dc1a16
......@@ -67,7 +67,6 @@ Der Unterschied zwischen den beiden Methoden für die Erstellung von Bildausschn
### Vis
* `loss.py`: Visualisieren den Wert der Lossfunktion die während des Trainings optimiert wird. Dafür wird der Log eines Trainings benötigt.
* `pr_curve.py`: Visualisiere eine PR-Kurve (Precision-Recall-Kurve).
* `rdr_curve.py`: Visualisiere die RDR über verschiedene Schwellwerte.
* `side_by_side.py`: Visualisiere Bilder indem sie nebeneinander dargestellt werden.
......
......@@ -16,6 +16,7 @@ PyWavelets==1.2.0
scikit-image==0.19.1
scipy==1.7.3
six==1.16.0
termcolor==1.1.0
tifffile==2021.11.2
torch==1.10.1
torchvision==0.11.2
......
......@@ -12,6 +12,7 @@ install_requires =
Pillow==9.0.0
scikit-image==0.19.1
scipy==1.7.3
termcolor==1.1.0
torch==1.10.1
torchvision==0.11.2
tqdm==4.62.3
import argparse
import csv
import os
import numpy as np
import matplotlib.pyplot as plt
import self_training.eval.measures as measures
def read_csv_file(csv_file):
keys = ["threshold", "precision", "recall", "f_beta", "iou"]
data = dict([(key, []) for key in keys])
with open(csv_file, "r") as csvfile:
reader = csv.DictReader(csvfile, skipinitialspace=True)
for row in reader:
for key in keys:
data[key].append(float(row[key]))
data = dict([(key, np.array(data[key])) for key in data])
return data
def main(args):
if args.labels is not None:
assert len(args.labels) == len(args.csv_files)
else:
args.labels = args.csv_files
fig = plt.figure()
axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])
for csv_file, label in zip(args.csv_files, args.labels):
print(f"{label}:")
data = read_csv_file(csv_file)
data["precision"] = [0.0, *data["precision"], 1.0]
data["recall"] = [1.0, *data["recall"], 0.0]
max_i = np.argmax(data["f_beta"])
print(f" Max FBeta {data['f_beta'][max_i]:.3f} at threshold {data['threshold'][max_i]}")
max_iou_idx = np.argmax(data["iou"])
print(f" Max IoU {data['iou'][max_i]:.3f} at threshold {data['threshold'][max_i]}")
auc = measures.auc(data["precision"], data["recall"])
print(f" AUC {auc:.3f}")
axes.plot(data["precision"], data["recall"], label=label)
plt.xlabel("Precision")
plt.ylabel("Recall")
axes.set_xlim([0, 1.0])
axes.set_ylim([0, 1.0])
axes.legend()
plt.grid(linestyle="--", alpha=0.5)
if args.title:
axes.set_title(args.title)
plt.show()
def add_subparser(subparsers):
parser = subparsers.add_parser("pr_curve",
description="plot the precision-recall-curve and extract the best FBeta and IoU score from a csv file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("csv_files", nargs="+", type=str,
help="the csv files from which the prevision and recall values are extracted")
parser.add_argument("--labels", nargs="+", type=str,
help="labels for each pr curve in the csv files")
parser.add_argument("--title", type=str,
help="title of the plot")
parser.set_defaults(func=main)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment