Skip to content
Snippets Groups Projects
logging.py 2.19 KiB
import argparse
import datetime
import logging
import logging.handlers
import os
import shutil
from typing import Dict, Optional


# Formatter shared by every handler this module installs.
# The time uses the 24-hour clock (%H): the previous 12-hour %I without an
# AM/PM marker made 03:00 and 15:00 indistinguishable in the logs.
FORMATTER = logging.Formatter(
    fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)


def add_logging_args(
    parser: argparse.ArgumentParser, defaults: Optional[Dict[str, str]] = None
) -> argparse.ArgumentParser:
    """
    Add logging arguments to an argument parser.

    Two options are registered: ``--logfile`` (the file logs are written to)
    and ``--loglevel`` (the minimum level that is emitted).

    Args:
        parser: Parser the options are added to (mutated in place).
        defaults: Optional mapping from option string (``"--logfile"``,
            ``"--loglevel"``) to its default value. Missing entries fall back
            to ``None`` for ``--logfile`` and ``"INFO"`` for ``--loglevel``.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    if defaults is None:
        defaults = {}

    parser.add_argument(
        "--logfile",
        type=str,
        default=defaults.get("--logfile", None),
        help="the file information is logged to",
    )
    parser.add_argument(
        "--loglevel",
        type=str,
        default=defaults.get("--loglevel", "INFO"),
        # All standard logging levels; CRITICAL was previously unselectable.
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="the log level applied for printing logs",
    )
    return parser


def timestamp_filename(filename: str, now: Optional[datetime.datetime] = None) -> str:
    """
    Return ``filename`` with a timestamp inserted before its extension.

    ``run.log`` becomes e.g. ``run_2024-01-02-03-04-05.log``. The timestamp
    uses only ``-`` as a separator: ``:`` is not allowed in Windows file names,
    so the previous ``%H:%M:%S`` form made rotation fail there.

    Args:
        filename: Path whose name receives the timestamp suffix.
        now: Moment to stamp; defaults to the current local time.

    Returns:
        The new path string; nothing is touched on disk.
    """
    if now is None:
        now = datetime.datetime.now()
    stamp = now.strftime("%Y-%m-%d-%H-%M-%S")
    basename, ext = os.path.splitext(filename)
    return f"{basename}_{stamp}{ext}"


def rotate_log_file(filename: str) -> Optional[str]:
    """
    Move an existing log file aside so a fresh one can be started.

    The file is renamed to the name produced by ``timestamp_filename``. If that
    name is already taken (e.g. two rotations within the same second), a
    numeric suffix is added. Without it, ``shutil.move`` would silently
    overwrite the older archive on POSIX systems.

    Args:
        filename: Path of the log file to rotate.

    Returns:
        The path the file was moved to, or ``None`` if ``filename`` is not an
        existing regular file (nothing to rotate).
    """
    if not os.path.isfile(filename):
        return None

    target = timestamp_filename(filename)
    base, ext = os.path.splitext(target)
    counter = 1
    while os.path.exists(target):
        target = f"{base}.{counter}{ext}"
        counter += 1

    shutil.move(filename, target)
    return target


# Attribute set on handlers created by ``get_logger`` so a later call can find
# and replace them without touching handlers installed by other code.
_OWNED_HANDLER_ATTR = "_installed_by_get_logger"


def get_logger(
    logfile: Optional[str] = None, loglevel: Optional[str] = None
) -> logging.Logger:
    """
    Configure and return the root logger.

    A console handler (``logging.StreamHandler``, i.e. stderr) is always
    attached. When ``logfile`` is given, an existing file at that path is first
    rotated aside, then a ``WatchedFileHandler`` writing UTF-8 to ``logfile``
    is attached too. Both handlers use ``FORMATTER``.

    Calling this function again replaces the handlers installed by the previous
    call instead of stacking new ones on top. Stacking would print every record
    once per call. Handlers added by other code are kept.

    Args:
        logfile: Optional path of the log file.
        loglevel: Optional level name such as ``"INFO"`` (case-insensitive).
            When falsy, the logger's current level is left unchanged.

    Returns:
        The root logger.

    Raises:
        AttributeError: if ``loglevel`` is not an attribute of ``logging``.
    """
    logger = logging.getLogger()

    if loglevel:
        logger.setLevel(getattr(logging, loglevel.upper()))

    # Remove and close our handlers from an earlier call. Closing the old file
    # handler before rotating lets the rename succeed on platforms that lock
    # open files.
    stale = [h for h in logger.handlers if getattr(h, _OWNED_HANDLER_ATTR, False)]
    for handler in stale:
        logger.removeHandler(handler)
        handler.close()

    handlers = [logging.StreamHandler()]
    if logfile:
        rotate_log_file(logfile)
        handlers.append(logging.handlers.WatchedFileHandler(logfile, encoding="utf-8"))

    for handler in handlers:
        handler.setFormatter(FORMATTER)
        setattr(handler, _OWNED_HANDLER_ATTR, True)
        logger.addHandler(handler)

    return logger


def get_logger_by_args(args):
    """
    Configure logging from parsed command-line arguments.

    ``args`` must provide ``logfile`` and ``loglevel`` attributes, e.g. the
    namespace from a parser prepared with ``add_logging_args``. Both values are
    forwarded unchanged to ``get_logger``, whose result is returned.
    """
    logfile, loglevel = args.logfile, args.loglevel
    return get_logger(logfile=logfile, loglevel=loglevel)


if __name__ == "__main__":
    # Small manual demo: parse the logging options, configure logging, then
    # emit a few records at different levels.
    cli = argparse.ArgumentParser()
    add_logging_args(cli, defaults={"--loglevel": "DEBUG", "--logfile": "tmp.log"})
    parsed = cli.parse_args()
    print(parsed)

    demo_logger = get_logger_by_args(parsed)

    demo_records = (
        (demo_logger.info, "This is a test!"),
        (demo_logger.warning, "The training loss is over 9000!"),
        (demo_logger.error, "This is an error"),
        (demo_logger.info, "The end"),
    )
    for emit, text in demo_records:
        emit(text)