"""
Module containing utilities for logging.
"""
import argparse
from dataclasses import dataclass
from datetime import datetime
import logging
from logging import Formatter, getLogger, StreamHandler
from logging.handlers import WatchedFileHandler
import os
import shutil
from typing import Dict, Optional, List

# Date/time layout used for the `asctime` field of every log record and for
# parsing it back again in `parse_line`, e.g. "02/20/2023 12:59:53 PM".
date_format = "%m/%d/%Y %I:%M:%S %p"
# Shared formatter: "<time> - <level, right-aligned to 7> - <logger> - <message>".
FORMATTER = Formatter(
    "%(asctime)s - %(levelname)7s - %(name)s - %(message)s",
    date_format,
)


def add_logging_args(
    parser: argparse.ArgumentParser, defaults: Optional[Dict[str, str]] = None
) -> argparse.ArgumentParser:
    """
    Add logging arguments to an argument parser. This includes parameters for a
    filename and the log level.

    Parameters
    ----------
    parser: argparse.ArgumentParser
        the parser to which arguments are added (it is modified in place)
    defaults: Dict[str, str], optional
        default values for the arguments; use the keys `--logfile` and/or
        `--loglevel` to specify them. Missing keys fall back to None for
        `--logfile` and "INFO" for `--loglevel`.

    Returns
    -------
    argparse.ArgumentParser
        the modified parser (the same object that was passed in)
    """
    # treat "no defaults given" the same as an empty mapping
    defaults = defaults or {}
    parser.add_argument(
        "--logfile",
        type=str,
        default=defaults.get("--logfile"),
        help="the file information is logged to",
    )
    parser.add_argument(
        "--loglevel",
        type=str,
        default=defaults.get("--loglevel", "INFO"),
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="the log level applied for printing logs",
    )
    return parser


def timestamp_filename(
    filename: str,
    fmt: str = "%Y-%m-%d-%H:%M:%S",
    now: Optional[datetime] = None,
) -> str:
    """
    Attach a timestamp to a filename.

    The timestamp is inserted between the base name and the extension,
    separated by an underscore.
    E.g., `img.png` becomes `img_2023-02-20-12:59:53.png`.

    Parameters
    ----------
    filename: str
        the filename (may include a directory part)
    fmt: str, optional
        `strftime` format of the timestamp. Note that the default contains
        `:`, which is not a valid filename character on Windows; pass e.g.
        "%Y-%m-%d-%H-%M-%S" there.
    now: datetime, optional
        the point in time to format; defaults to `datetime.now()`

    Returns
    -------
    str
        the filename with the timestamp attached
    """
    moment = datetime.now() if now is None else now
    basename, ext = os.path.splitext(filename)
    return f"{basename}_{moment.strftime(fmt)}{ext}"


def rotate_log_file(filename: str) -> Optional[str]:
    """
    Rotate a log file.

    This means that the given file is moved to a filename
    with a timestamp (see `timestamp_filename`). If that name is already
    taken (the timestamp only has a resolution of seconds, so two rotations
    within the same second produce the same name), a counter is appended
    (`<name>_<timestamp>_1<ext>`, `_2`, ...). This ensures that a previously
    rotated log is never overwritten.

    Parameters
    ----------
    filename: str
        the log file to be rotated

    Returns
    -------
    str or None
        the path the file was moved to, or None if `filename` is not an
        existing regular file (in which case nothing happens)
    """
    if not os.path.isfile(filename):
        return None

    target = timestamp_filename(filename)
    if os.path.exists(target):
        # shutil.move falls back to os.rename, which silently replaces an
        # existing destination file on POSIX -- pick a free name instead
        basename, ext = os.path.splitext(target)
        counter = 1
        while os.path.exists(f"{basename}_{counter}{ext}"):
            counter += 1
        target = f"{basename}_{counter}{ext}"

    shutil.move(filename, target)
    return target


def get_logger(
    logfile: Optional[str] = None,
    loglevel: Optional[str] = None,
    name: Optional[str] = None,
) -> logging.Logger:
    """
    Create a new logger, or return an already configured one.

    Handlers (a stream handler and, optionally, a file handler, both using
    `FORMATTER`) are only attached if `logger.hasHandlers()` is False. Note
    that `hasHandlers` also inspects ancestor loggers up to the root (as long
    as propagation is enabled). Consequently, if e.g. the root logger is
    already configured, a named logger is returned unchanged: no handlers are
    added, `loglevel` is not applied and `logfile` is neither rotated nor
    opened.

    Parameters
    ----------
    logfile: str, optional
        file to which records are additionally written; an existing file with
        this name is rotated first (see `rotate_log_file`)
    loglevel: str, optional
        name of the level to set on the logger (e.g. "DEBUG", "info");
        case-insensitive. If not given, the logger's level is left unchanged.
    name: str, optional
        name of the logger; the root logger is used if not given

    Returns
    -------
    logging.Logger

    Raises
    ------
    ValueError
        if `loglevel` is not a known level name
    """
    # getLogger(None) (and getLogger("")) return the root logger
    logger = getLogger(name)

    if logger.hasHandlers():
        # logger (or an ancestor) is already configured; avoid duplicate output
        return logger

    if loglevel:
        # Logger.setLevel accepts level names and raises ValueError for
        # unknown ones, unlike getattr(logging, ...) which accepts any attribute
        logger.setLevel(loglevel.upper())

    stream_handler = StreamHandler()
    stream_handler.setFormatter(FORMATTER)
    logger.addHandler(stream_handler)

    if logfile:
        rotate_log_file(logfile)
        file_handler = WatchedFileHandler(logfile)
        file_handler.setFormatter(FORMATTER)
        logger.addHandler(file_handler)

    return logger


def get_logger_by_args(
    args: argparse.Namespace, name: Optional[str] = None
) -> logging.Logger:
    """
    Utility function to create a logger from arguments added
    with `add_logging_args`.

    Parameters
    ----------
    args: argparse.Namespace
        parsed arguments providing the `logfile` and `loglevel` attributes
    name: str, optional
        name of the logger to create; the root logger is used if not given

    Returns
    -------
    logging.Logger
        the logger returned by `get_logger`
    """
    return get_logger(logfile=args.logfile, loglevel=args.loglevel, name=name)


@dataclass
class LogLine:
    """
    Structured form of one line of a log: when it was logged, at which
    level, and with which message. The logger name is not stored.
    """

    time: datetime  # point in time of the entry
    loglevel: str  # level name, e.g. "INFO"
    message: str  # the logged message text

    def __repr__(self):
        # same layout as FORMATTER (level right-aligned to 7 chars),
        # but without the logger name
        return " - ".join(
            (self.time.strftime(date_format), f"{self.loglevel:>7}", self.message)
        )


def parse_line(logline: str, start_idx_message: int = 3) -> LogLine:
    """
    Parse a single line of a log into a structured LogLine object.

    Elements in a line are expected to be separated by the `-` character.
    The first element is the date time, the second the log level, the third
    is the logger name and the last is the message.
    Note that the third element is optional and that the message may
    contain additional `-` characters. This can be handled using the
    `start_idx_message` parameter: all elements from this index onwards are
    joined back together (with `-`) to form the message. A logger name that
    itself contains `-` is split as well and needs a larger index.

    Parameters
    ----------
    logline: str
        line of a log
    start_idx_message: int, optional
        index of the split element at which the message starts
        (3 for lines with a logger name, 2 for lines without one)

    Returns
    -------
    LogLine

    Raises
    ------
    ValueError
        if the line has too few elements or its time does not match
        `date_format`
    """
    line = logline.strip()
    parts = line.split("-")
    # time (index 0) and level (index 1) are always read, even if the
    # message is configured to start earlier
    min_parts = max(2, start_idx_message)
    if len(parts) < min_parts:
        # explicit exception instead of `assert`, which is stripped under -O
        raise ValueError(
            f"A logged line should consist of at least {min_parts} elements "
            f"with the format [TIME - LOGLEVEL - ... - MESSAGE] but got [{line}]"
        )

    time = datetime.strptime(parts[0].strip(), date_format)
    loglevel = parts[1].strip()
    # the parts keep their surrounding spaces, so re-joining with "-"
    # restores " - " inside the message; only the ends are stripped
    message = "-".join(parts[start_idx_message:]).strip()
    return LogLine(time=time, loglevel=loglevel, message=message)


def parse_file(
    logfile: str, start_idx_message: int = 3, encoding: Optional[str] = None
) -> List[LogLine]:
    """
    Parse a logfile into a list of `LogLine`.

    Blank lines are skipped. Every other line is parsed on its own with
    `parse_line`. Multi-line messages (such as tracebacks) are therefore not
    supported and will generally make parsing fail.

    Parameters
    ----------
    logfile: str
        path of the log file
    start_idx_message: int, optional
        see `parse_line`
    encoding: str, optional
        encoding used to read the file; if not given, `open`'s platform
        default is used

    Returns
    -------
    list of LogLine
        one entry per non-blank line, in file order
    """
    with open(logfile, mode="r", encoding=encoding) as f:
        return [parse_line(line, start_idx_message) for line in f if line.strip()]


if __name__ == "__main__":
    # demo: emit one message per level to the console (and, by default,
    # to tmp.log after rotating any existing tmp.log)
    parser = argparse.ArgumentParser()
    add_logging_args(parser, defaults={"--logfile": "tmp.log", "--loglevel": "DEBUG"})
    args = parser.parse_args()
    print(args)

    logger = get_logger_by_args(args)
    demo_messages = (
        (logger.info, "This is a test!"),
        (logger.warning, "The training loss is over 9000!"),
        (logger.error, "This is an error"),
        (logger.debug, "The end"),
    )
    for log, text in demo_messages:
        log(text)