from datetime import datetime, timedelta
from typing import Tuple

import numpy as np
import pydicom


# DICOM images only allow pixel data stored in short integer format, so the
# Siemens scanner software multiplies the floating point values by a factor
# before storing them, to retain as much precision as possible.
# That factor is stored in this private DICOM tag (0033,1038).
DCM_TAG_PIXEL_SCALE_FACTOR = 0x00331038

# Maximum value that can be stored in an unsigned 16-bit integer.
UINT16_MAX = 2**16 - 1

# DICOM images always contain UIDs to indicate their uniqueness.
# Thus, when a DICOM image is updated, its UIDs have to be changed; new UIDs
# are created below the following prefix (UID root).
UID_PREFIX = "1.2.826.0.1.3680043.2.521."


def load_dcm(filename: str) -> Tuple[pydicom.dataset.FileDataset, np.ndarray]:
    """
    Read a DICOM file and undo the integer scaling applied by the Siemens
    SPECT/CT scanner.

    The stored pixel values are divided by the factor found in the private tag
    ``DCM_TAG_PIXEL_SCALE_FACTOR``. A ``KeyError`` propagates if that tag is
    absent from the dataset.

    :param filename: path of the DICOM file to read
    :return: tuple of the DICOM dataset and the rescaled image as a numpy array
    """
    dataset = pydicom.dcmread(filename)
    # Decode the pixel data first, then look up the scanner's scale factor.
    raw_pixels = dataset.pixel_array
    factor = dataset[DCM_TAG_PIXEL_SCALE_FACTOR].value
    return dataset, raw_pixels / factor


def load_dcm_img(filename: str) -> np.ndarray:
    """
    Read a DICOM file and return only its pixel data, rescaled by the Siemens
    SPECT/CT scale factor; the DICOM header is discarded.

    :param filename: path of the DICOM file to read
    :return: the rescaled image as a numpy array
    """
    return load_dcm(filename)[1]


def scale_image(
    image: np.ndarray, initial_scale: float = 10000000
) -> Tuple[np.ndarray, float]:
    """
    For memory efficiency, the Siemens SPECT/CT does not store images as floating point
    numbers, but as unsigned 16-bit integers. To retain as much precision as possible,
    the floating point values are multiplied by a factor of 10^x, where x is chosen so
    that the scaled values stay within the uint16 range. This function replicates that
    process.

    Scaled values are rounded to the nearest integer (not truncated), and negative
    values, which uint16 cannot represent, are clipped to 0.

    :param image: an image in floating point format
    :param initial_scale: the initial scale, which is divided by 10 until the scaled
        maximum is no larger than the maximum uint16 number
    :return: the image scaled and converted to uint16 as well as the used scaling factor
    """
    # Upper bound of the target dtype (same value as the module-level UINT16_MAX).
    uint16_max = np.iinfo(np.uint16).max
    # The maximum is loop-invariant, so compute it only once.
    peak = image.max()

    scale = initial_scale
    while scale * peak > uint16_max:
        scale = scale / 10

    # Round instead of truncating: astype() truncates toward zero, which biases
    # every value down by up to one unit. Clip so negative inputs saturate at 0
    # rather than wrapping around in the float -> unsigned cast.
    scaled = np.clip(np.rint(image * scale), 0, uint16_max)
    return scaled.astype(np.uint16), scale


def update_dcm(
    dcm: pydicom.dataset.FileDataset, image: np.ndarray
) -> pydicom.dataset.FileDataset:
    """
    Update the image data in a DICOM file. This function scales the image, converts
    it to unsigned 16-bit integers and updates the pixel data in the DICOM file.
    Additionally, related tags in the DICOM header (number of frames/slices, slice
    vector, image dimensions, windowing, largest pixel value and the Siemens scale
    factor) are updated accordingly.
    Note that this function modifies the given DICOM file. If you want to keep the old
    one, you should copy it first.

    :param dcm: the DICOM file to be updated
    :param image: the image put into the DICOM file, ordered (frames, rows, columns)
        like the ``pixel_array`` of a multi-frame DICOM
    :return: the updated DICOM file
    """
    image, scale = scale_image(image)

    num_frames = image.shape[0]
    # Multi-frame pixel arrays are ordered (frames, rows, columns): axis 1 is the
    # number of rows and axis 2 the number of columns.
    num_rows = image.shape[1]
    num_columns = image.shape[2]
    # Computed once and converted to a plain Python int so that pydicom stores a
    # native value rather than a numpy scalar.
    max_value = int(image.max())

    dcm.NumberOfFrames = num_frames
    dcm.NumberOfSlices = num_frames
    dcm.SliceVector = list(range(1, num_frames + 1))
    dcm.Rows = num_rows
    dcm.Columns = num_columns
    # NOTE(review): writes raw native-order uint16 bytes; assumes the dataset is
    # uncompressed with 16 bits allocated and unsigned pixels — not checked here.
    dcm.PixelData = image.tobytes()
    dcm.WindowWidth = max_value
    dcm.WindowCenter = max_value / 2
    dcm.LargestImagePixelValue = max_value
    dcm[DCM_TAG_PIXEL_SCALE_FACTOR].value = scale
    return dcm


def change_uid(dcm: pydicom.dataset.FileDataset) -> pydicom.dataset.FileDataset:
    """
    Assign new UIDs (SeriesInstanceUID and SOPInstanceUID) to a DICOM header so
    that it becomes its own unique file.

    The UIDs are generated with ``pydicom.uid.generate_uid`` below ``UID_PREFIX``.
    They are derived from a random UUID rather than the current time, so rapid
    successive calls do not produce colliding UIDs. If the file meta header
    contains a MediaStorageSOPInstanceUID, it is updated to match the new
    SOPInstanceUID.

    :param dcm: the DICOM file to be updated (modified in place)
    :return: the DICOM file with updated UIDs
    """
    from pydicom.uid import generate_uid

    dcm.SeriesInstanceUID = generate_uid(prefix=UID_PREFIX)
    dcm.SOPInstanceUID = generate_uid(prefix=UID_PREFIX)

    # The file meta header must reference the same SOP instance as the dataset.
    file_meta = getattr(dcm, "file_meta", None)
    if file_meta is not None and "MediaStorageSOPInstanceUID" in file_meta:
        file_meta.MediaStorageSOPInstanceUID = dcm.SOPInstanceUID
    return dcm