from typing import Tuple

import numpy as np
import pydicom

"""
Since DICOM images only allow images stored in short integer format,
the Siemens scanner software multiplies values by a factor before storing
so that no precision is lost.
The scale can be found in this private DICOM tag.
"""
# Private tag (group 0x0033, element 0x1038), packed into the single-integer
# tag form used to index the pydicom dataset.
DCM_TAG_PIXEL_SCALE_FACTOR = (0x0033 << 16) | 0x1038


def load_dcm_img(filename: str) -> np.ndarray:
    """
    Load a DICOM image as a numpy array and apply the pixel scaling of the Siemens
    SPECT/CT scanner.

    The stored pixel values are divided by the scale factor read from the private
    tag ``DCM_TAG_PIXEL_SCALE_FACTOR``.

    :param filename: filename of the DICOM image
    :return: the scaled image as a floating point numpy array
    :raises KeyError: if the file does not contain the pixel scale factor tag
    """
    dataset = pydicom.dcmread(filename)

    if DCM_TAG_PIXEL_SCALE_FACTOR not in dataset:
        # Same exception type pydicom would raise, but with a useful message.
        raise KeyError(
            "Pixel scale factor tag {:#010x} not found in {}".format(
                DCM_TAG_PIXEL_SCALE_FACTOR, filename
            )
        )

    # NOTE(review): depending on the VR pydicom assigns to this private tag, the
    # value may be decoded as a number or as a string; float() normalises both.
    scale_factor = float(dataset[DCM_TAG_PIXEL_SCALE_FACTOR].value)
    return dataset.pixel_array / scale_factor


def align_images(
    image_1: np.ndarray, image_2: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Center-align two images along the first axis (z-axis).

    Whichever image has more slices is cropped to the slice count of the other
    one, keeping its central slices. The image with fewer slices is returned
    unchanged. When both have the same number of slices, image_1 is sliced over
    its full extent and image_2 is returned unchanged.

    :param image_1: first image, slices indexed along the first axis
    :param image_2: second image, slices indexed along the first axis
    :return: both images aligned, in the order they were passed in
    """
    depth_1 = image_1.shape[0]
    depth_2 = image_2.shape[0]

    if depth_2 > depth_1:
        # Shift so that the central slice of image_1 lines up with the central
        # slice of image_2, then take as many slices as image_1 has.
        offset = depth_2 // 2 - depth_1 // 2
        return image_1, image_2[offset : offset + depth_1]

    offset = depth_1 // 2 - depth_2 // 2
    return image_1[offset : offset + depth_2], image_2