from dataclasses import dataclass
from typing import Tuple

import cv2 as cv
import numpy as np
import pytesseract


@dataclass
class PolarMapSegment:
    """Where the number of one polar map segment is shown in the image.

    Instances are listed in ``SEGMENTS``; ``rect`` is used to cut the number out
    of a polar map image before running number recognition on it.
    """

    # segment number; ``SEGMENTS`` uses 1-17
    id: int
    # (top, left, bottom, right) in image pixels; the image is cropped as
    # ``image[top:bottom, left:right]``, so bottom/right are exclusive
    rect: Tuple[int, int, int, int]
    # segment name within its ring, e.g. "anterior" or "inferolateral"
    name: str
    # ring of the segment; ``SEGMENTS`` uses "basal", "mid", "apical" and "apex"
    location: str


# 17-segment myocardial segmentation and nomenclature (AHA), see
# https://www.ahajournals.org/doi/full/10.1161/hc0402.102975
#
# Each ``rect`` is (top, left, bottom, right) in the pixel grid of the polar map
# image as loaded from disk and marks where that segment's number is shown.
# NOTE(review): coordinates are hard-coded for one image layout/resolution —
# confirm that all polar map images share it.
SEGMENTS = [
    PolarMapSegment( id=1,  location="basal",      name="anterior", rect=( 14, 190,  41, 217)),
    PolarMapSegment( id=2,  location="basal",  name="anteroseptal", rect=(101,  36, 128,  63)),
    PolarMapSegment( id=3,  location="basal",  name="inferoseptal", rect=(277,  36, 304,  63)),
    PolarMapSegment( id=4,  location="basal",      name="inferior", rect=(365, 190, 392, 217)),
    # NOTE(review): this box is 29 px wide while all others are 27 px — intended?
    PolarMapSegment( id=5,  location="basal", name="inferolateral", rect=(277, 341, 304, 370)),
    PolarMapSegment( id=6,  location="basal", name="anterolateral", rect=(101, 340, 128, 367)),
    PolarMapSegment( id=7,    location="mid",      name="anterior", rect=( 64, 190,  91, 217)),
    PolarMapSegment( id=8,    location="mid",  name="anteroseptal", rect=(128,  80, 155, 107)),
    PolarMapSegment( id=9,    location="mid",  name="inferoseptal", rect=(251,  80, 278, 107)),
    PolarMapSegment(id=10,    location="mid",      name="inferior", rect=(313, 188, 340, 215)),
    PolarMapSegment(id=11,    location="mid", name="inferolateral", rect=(251, 298, 278, 325)),
    PolarMapSegment(id=12,    location="mid", name="anterolateral", rect=(128, 298, 155, 325)),
    PolarMapSegment(id=13, location="apical",      name="anterior", rect=(115, 190, 142, 217)),
    PolarMapSegment(id=14, location="apical",        name="septal", rect=(190, 113, 217, 140)),
    PolarMapSegment(id=15, location="apical",      name="inferior", rect=(265, 188, 292, 215)),
    PolarMapSegment(id=16, location="apical",       name="lateral", rect=(190, 265, 217, 292)),
    PolarMapSegment(id=17,   location="apex",          name="apex", rect=(190, 190, 217, 217)),
]


# Inclusive lower/upper HSV bounds for ``cv.inRange``. OpenCV's 8-bit HSV stores
# hue as 0-179 (degrees / 2), so hue 40-80 selects green tones; requiring
# saturation and value >= 100 rejects grayish and dark pixels.
hsv_green_lower, hsv_green_upper = (
    np.array([40, 100, 100]),
    np.array([80, 255, 255]),
)


# Tesseract options for reading a segment number: restrict the output to the
# digits 0-9 and treat the cutout as a single line of text (page segmentation
# mode 7).
_TESSERACT_CONFIG = "-c tessedit_char_whitelist=0123456789 --psm 7"


def _preprocess_number_image(img_number: np.ndarray, resolution: int) -> np.ndarray:
    """Prepare a BGR number cutout for OCR.

    The cutout is rescaled to ``resolution`` x ``resolution``, pixels inside the
    HSV range ``hsv_green_lower``..``hsv_green_upper`` are kept as foreground,
    small holes are closed and isolated specks removed, and the mask is inverted
    so that the number ends up black on a white background.

    :param img_number: BGR cutout of a polar map image
    :param resolution: side length in pixels of the rescaled cutout
    :return: single-channel binary image (values 0 and 255)
    """
    img_number = cv.resize(img_number, (resolution, resolution))
    img_number = cv.cvtColor(img_number, cv.COLOR_BGR2HSV)
    img_number = cv.inRange(img_number, hsv_green_lower, hsv_green_upper)
    # closing fills small gaps in the foreground, opening then removes
    # isolated foreground pixels
    img_number = cv.morphologyEx(img_number, cv.MORPH_CLOSE, np.ones((4, 4), np.uint8))
    img_number = cv.morphologyEx(img_number, cv.MORPH_OPEN, np.ones((2, 2), np.uint8))
    # invert the 0/255 mask: foreground -> 0 (black), background -> 255 (white)
    _, img_number = cv.threshold(img_number, 0, 255, cv.THRESH_BINARY_INV)
    return img_number


def _recognize_number(img_number: np.ndarray) -> str:
    """Return tesseract's reading of ``img_number`` without surrounding whitespace.

    The result may be empty or otherwise not a valid integer; callers must not
    assume that it parses.
    """
    return pytesseract.image_to_string(img_number, config=_TESSERACT_CONFIG).strip()


def _show_segment(
    polar_map: np.ndarray,
    rect: Tuple[int, int, int, int],
    img_number: np.ndarray,
    size: int = 512,
) -> None:
    """Display the polar map with ``rect`` outlined next to the processed cutout.

    :param polar_map: BGR polar map image (not modified)
    :param rect: (top, left, bottom, right) box of the segment number
    :param img_number: single-channel processed cutout
    :param size: side length in pixels of each displayed panel
    """
    top, left, bottom, right = rect
    preview = cv.rectangle(polar_map.copy(), (left, top), (right, bottom), (255, 255, 255), 1)
    preview = cv.resize(preview, (size, size))
    number_bgr = cv.resize(cv.cvtColor(img_number, cv.COLOR_GRAY2BGR), (size, size))
    # light gray spacer column between the two panels
    spacer = np.full((size, 10, 3), 239, np.uint8)
    cv.imshow("Polar Map - Segment", np.hstack((preview, spacer, number_bgr)))
    # let HighGUI process events so the window is drawn before input() blocks
    cv.waitKey(50)


def _ask_number(suggestion: str) -> int:
    """Let the user confirm or correct a recognized number.

    Entering ``y`` accepts ``suggestion``, ``q`` quits the program, and any other
    input is parsed as the number itself. The prompt repeats until an integer
    is obtained.

    :param suggestion: OCR result shown to the user
    :return: the confirmed or corrected number
    :raises SystemExit: if the user enters ``q``
    """
    while True:
        answer = input(f"Number is {suggestion} (y/q/other): ")
        if answer == "q":
            # NOTE: nothing has been written yet, so all entered values are lost
            raise SystemExit(0)
        # on "y" the OCR text itself must parse, so report *that* if it fails
        candidate = suggestion if answer == "y" else answer
        try:
            return int(candidate)
        except ValueError:
            print(f"Cannot parse {candidate!r} as a number. Please enter a valid number.")


if __name__ == "__main__":
    import argparse
    import os

    import pandas as pd

    from mu_map.data.prepare_polar_maps import headers

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--polar_map_dir",
        type=str,
        required=True,
        help="directory containing the polar map images and meta csv; outputs are written here too",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="directory under <polar_map_dir> where images of polar maps are stored",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="polar_maps.csv",
        help="file under <polar_map_dir> where meta information is stored",
    )
    parser.add_argument(
        "--perfusion_csv",
        type=str,
        default="perfusion.csv",
        help="csv file output by this script under <polar_map_dir>",
    )
    parser.add_argument(
        "--number_res",
        type=int,
        default=128,
        help="numbers cutouts are rescaled to this resolution for better number recognition",
    )
    args = parser.parse_args()

    args.images_dir = os.path.join(args.polar_map_dir, args.images_dir)
    args.csv = os.path.join(args.polar_map_dir, args.csv)
    args.perfusion_csv = os.path.join(args.polar_map_dir, args.perfusion_csv)

    data = pd.read_csv(args.csv)
    rows = []
    for _, row in data.iterrows():
        _file = os.path.join(args.images_dir, row[headers.file])
        polar_map = cv.imread(_file)
        if polar_map is None:
            # cv.imread signals missing/unreadable files by returning None
            raise OSError(f"Cannot read polar map image {_file}")

        # collect all segment values first and append them to the row once
        segment_values = {}
        for segment in SEGMENTS:
            top, left, bottom, right = segment.rect
            img_number = _preprocess_number_image(
                polar_map[top:bottom, left:right], args.number_res
            )
            str_number = _recognize_number(img_number)
            _show_segment(polar_map, segment.rect, img_number)
            segment_values[f"segment_{segment.id}"] = _ask_number(str_number)

        rows.append(pd.concat([row, pd.Series(segment_values)]))

    cv.destroyAllWindows()
    pd.DataFrame(rows).to_csv(args.perfusion_csv)