import argparse
from dataclasses import dataclass
from datetime import datetime
import logging
from logging import Formatter, getLogger, StreamHandler
from logging.handlers import WatchedFileHandler
import os
import shutil
from typing import Dict, Optional, List

# strftime/strptime pattern used for the timestamp of every log record.
# NOTE(review): "%I" is the 12-hour clock and no "%p" (AM/PM) field follows, so a
# morning and an afternoon timestamp render identically -- confirm this is intended.
date_format = "%m/%d/%Y %I:%M:%S"

# Formatter shared by all handlers created in this module. Records render as
# "<time> - <level name, right-aligned to 7 chars> - <logger name> - <message>".
FORMATTER = Formatter(
    "%(asctime)s - %(levelname)7s - %(name)s - %(message)s",
    date_format,
)


def add_logging_args(parser: argparse.ArgumentParser, defaults: Dict[str, str]):
    """
    Register the ``--logfile`` and ``--loglevel`` options on ``parser``.

    ``defaults`` is looked up by option string ("--logfile" / "--loglevel").
    A missing "--logfile" entry means no default log file, a missing
    "--loglevel" entry means "INFO". The parser passed in is returned.
    """
    logfile_default = defaults.get("--logfile", None)
    loglevel_default = defaults.get("--loglevel", "INFO")

    # One (flag, keyword arguments) pair per option, registered in this order.
    option_table = (
        (
            "--logfile",
            {
                "type": str,
                "default": logfile_default,
                "help": "the file information is logged to",
            },
        ),
        (
            "--loglevel",
            {
                "type": str,
                "default": loglevel_default,
                "choices": ["DEBUG", "INFO", "WARNING", "ERROR"],
                "help": "the log level applied for printing logs",
            },
        ),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    return parser


def timestamp_filename(filename: str):
    """
    Return ``filename`` with the current local time inserted before its extension.

    "run.log" becomes "run_<YYYY-mm-dd-HH:MM:SS>.log"; a name without an
    extension simply gets "_<timestamp>" appended.
    """
    # NOTE(review): the timestamp contains ":" which Windows does not allow in
    # file names -- confirm the target platforms before relying on this there.
    root, suffix = os.path.splitext(filename)
    stamp = f"{datetime.now():%Y-%m-%d-%H:%M:%S}"
    return root + "_" + stamp + suffix


def rotate_log_file(filename: str):
    """
    Move an existing log file aside under a timestamped name.

    Nothing happens when ``filename`` is not an existing regular file.
    Otherwise the file is moved to the name produced by
    ``timestamp_filename(filename)``. If that name is already taken, a numeric
    suffix ("_1", "_2", ...) is inserted before the extension until an unused
    name is found, so an earlier rotated log is never replaced.
    """
    if not os.path.isfile(filename):
        return

    target = timestamp_filename(filename)

    # The timestamp only has one-second resolution, so two rotations within the
    # same second would aim at the same name, and shutil.move may overwrite an
    # existing destination file. Pick a free name instead of losing the old log.
    stem, ext = os.path.splitext(target)
    counter = 1
    while os.path.exists(target):
        target = f"{stem}_{counter}{ext}"
        counter += 1

    shutil.move(filename, target)


def get_logger(logfile: Optional[str] = None, loglevel: Optional[str] = None, name: Optional[str] = None):
    """
    Return a logger that writes to the console and, optionally, to a file.

    Args:
        logfile: path of the log file. When given, an existing file at that
            path is first moved aside by ``rotate_log_file`` and a
            ``WatchedFileHandler`` for the path is attached.
        loglevel: name of a level constant of the ``logging`` module
            (e.g. "DEBUG"). A falsy value leaves the logger's level untouched.
        name: logger name; ``None`` selects the root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        AttributeError: if ``loglevel`` is not an attribute of ``logging``.

    Calling this function again for the same logger reuses the handlers it
    attached earlier instead of stacking new ones. A log file that already has
    a handler from this function is not rotated again.
    """
    logger = getLogger() if name is None else getLogger(name)

    if loglevel:
        logger.setLevel(getattr(logging, loglevel))

    # logging hands out the same logger object for the same name, so handlers
    # accumulate across calls; attaching fresh ones each time would emit every
    # record once per call. Only handlers carrying this module's FORMATTER are
    # considered ours -- handlers configured elsewhere are left alone.
    own_handlers = [h for h in logger.handlers if h.formatter is FORMATTER]

    # FileHandler derives from StreamHandler, hence the exact type comparison.
    if not any(type(h) is StreamHandler for h in own_handlers):
        console_handler = StreamHandler()
        console_handler.setFormatter(FORMATTER)
        logger.addHandler(console_handler)

    if logfile:
        # FileHandler stores the absolute path of its file in ``baseFilename``.
        target = os.path.abspath(logfile)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in own_handlers
        )
        if not already_attached:
            rotate_log_file(logfile)
            file_handler = WatchedFileHandler(logfile)
            file_handler.setFormatter(FORMATTER)
            logger.addHandler(file_handler)

    return logger


def get_logger_by_args(args):
    """
    Build a logger from an object exposing ``logfile`` and ``loglevel``
    attributes, e.g. the namespace of a parser set up by ``add_logging_args``.
    """
    logfile, loglevel = args.logfile, args.loglevel
    return get_logger(logfile=logfile, loglevel=loglevel)


@dataclass
class LogLine:
    """
    A single log record split into timestamp, level and message.
    """

    # Point in time the record carries.
    time: datetime
    # Level name, e.g. "INFO".
    loglevel: str
    # Remaining text of the record.
    message: str

    def __repr__(self):
        # Rendered as "<time> - <level right-aligned to 7 chars> - <message>",
        # with the time formatted by the module-wide ``date_format``.
        stamp = self.time.strftime(date_format)
        level = format(self.loglevel, ">7")
        return f"{stamp} - {level} - {self.message}"


def parse_line(logline):
    """
    Parse one line of the form "TIME - LOGLEVEL - MESSAGE" into a ``LogLine``.

    The line is cut at its first two "-" characters: the first piece is read
    as a timestamp in ``date_format``, the second is the level, and everything
    after that (further "-" included) is the message. Surrounding whitespace
    is removed from each piece.
    """
    stripped = logline.strip()
    # At most two cuts: any later "-" stays inside the message piece.
    pieces = stripped.split("-", 2)
    assert len(pieces) >= 3, (
        "A logged line should consists of a least three elements with the format "
        f"[TIME - LOGLEVEL - MESSAGE] but got [{stripped}]"
    )

    return LogLine(
        time=datetime.strptime(pieces[0].strip(), date_format),
        loglevel=pieces[1].strip(),
        message="-".join(pieces[2:]).strip(),
    )

def parse_file(logfile: str) -> "List[LogLine]":
    """
    Parse every record of the log file at ``logfile``.

    Each non-blank line is handed to ``parse_line``; the results are returned
    in file order. Blank or whitespace-only lines (such as a trailing empty
    line) carry no record and are skipped instead of failing the whole parse.
    Errors raised by ``parse_line`` for malformed lines propagate unchanged.
    """
    with open(logfile, mode="r") as f:
        # Iterate the file directly rather than materialising all lines first.
        return [parse_line(line) for line in f if line.strip()]

if __name__ == "__main__":
    # Manual demo: parse the logging options (defaulting to DEBUG level and
    # "tmp.log"), show them, then emit one record per level.
    cli_args = add_logging_args(
        argparse.ArgumentParser(),
        defaults={"--loglevel": "DEBUG", "--logfile": "tmp.log"},
    ).parse_args()
    print(cli_args)

    demo_logger = get_logger_by_args(cli_args)

    demo_logger.info("This is a test!")
    demo_logger.warning("The training loss is over 9000!")
    demo_logger.error("This is an error")
    demo_logger.debug("The end")