"""
Dataset utility methods.
Currently, the only available method aligns two images at
their central slice.
"""
from typing import Tuple

import numpy as np


def align_images(
    image_1: np.ndarray, image_2: np.ndarray, axis: int = 0
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Center-align two images along one axis by cropping the longer one.

    Whichever image has more slices along ``axis`` (the first axis, i.e. the
    z-axis, by default) is cropped to the number of slices of the other image,
    keeping the slices around its central slice. The image with fewer slices
    is returned unchanged. If both have the same number of slices, ``image_1``
    is returned as a full-length slice of itself.

    Parameters
    ----------
    image_1: np.ndarray
        the first image
    image_2: np.ndarray
        the second image
    axis: int, optional
        the axis along which the images are aligned; defaults to 0 (z-axis)

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        both images aligned, in the order they were passed in
    """
    n_1 = image_1.shape[axis]
    n_2 = image_2.shape[axis]

    # Only the image with more slices is cropped; ties crop image_1 (no-op).
    if n_1 >= n_2:
        return _crop_to_center(image_1, n_2, axis), image_2
    return image_1, _crop_to_center(image_2, n_1, axis)


def _crop_to_center(image: np.ndarray, length: int, axis: int) -> np.ndarray:
    """Return ``length`` slices of ``image`` along ``axis`` centred on its central slice."""
    # Both centres use floor division, so the window starts at
    # n//2 - length//2 and spans exactly ``length`` slices; since
    # length <= n this always stays within [0, n].
    start = image.shape[axis] // 2 - length // 2
    index = [slice(None)] * image.ndim
    index[axis] = slice(start, start + length)
    return image[tuple(index)]