Coverage for instanovo/utils/colorlogging.py: 93%
27 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-08 07:26 +0000
1import logging
2import os
3from typing import Literal
5from rich.console import Console
6from rich.logging import RichHandler
8from instanovo.__init__ import get_rank
9from instanovo.constants import LOGGING_SHOW_PATH, LOGGING_SHOW_TIME, USE_RICH_HANDLER
class DynamicRankFormatter(logging.Formatter):
    """A formatter that prefixes each log message with the current process rank.

    The rank is queried via ``get_rank()`` every time a record is formatted. When
    ``get_rank()`` returns ``None`` the message is formatted unchanged; otherwise
    it is prefixed with ``[RANK <rank>] ``.
    """

    def __init__(
        self,
        fmt: str | None = None,
        datefmt: str | None = None,
        style: Literal["%", "{", "$"] = "%",
    ) -> None:
        """Create the formatter.

        Args:
            fmt (str | None): Format string forwarded to ``logging.Formatter``.
            datefmt (str | None): Date format string forwarded to ``logging.Formatter``.
            style (Literal["%", "{", "$"]): Format style forwarded to ``logging.Formatter``.
        """
        super().__init__(fmt, datefmt, style)

    def format(self, record: logging.LogRecord) -> str:
        """Format a log record, prefixing its message with the rank when one is set.

        ``record.msg`` is only modified for the duration of this call and is
        restored afterwards, so other handlers receiving the same record neither
        see the prefix nor get it applied a second time.

        Args:
            record (logging.LogRecord): The log record to format.

        Returns:
            str: The formatted log record.
        """
        rank = get_rank()

        if rank is None:
            return super().format(record)

        original_msg = record.msg
        # Prefix the message (not the logger name) with the rank.
        record.msg = f"[RANK {rank}] {original_msg}"
        try:
            return super().format(record)
        finally:
            # The same LogRecord is passed to every handler in the logger chain;
            # restore the message so the prefix does not leak or stack up.
            record.msg = original_msg
class ColorLog:
    """A logging utility class that integrates with Rich for enhanced console output.

    (based on https://stackoverflow.com/a/79225597)

    Creating an instance configures the root logger through ``logging.basicConfig``
    (a no-op if the root logger already has handlers, so only the first
    configuration takes effect). It also raises the ``datasets`` logger to
    ``ERROR`` and exposes ``self.logger`` for the given name.
    """

    def __init__(self, console: Console, name: str) -> None:
        """Configure logging and create the named logger.

        Args:
            console (Console): Rich console that the ``RichHandler`` writes to.
            name (str): Logger name passed to ``logging.getLogger``.
        """
        # Only the message itself; it does not include the logger name.
        message_format = "%(message)s"
        date_format = "[%X]"

        if USE_RICH_HANDLER:
            # Time and path columns are hidden when AICHOR_LOGS_PATH is set.
            aichor_enabled = "AICHOR_LOGS_PATH" in os.environ
            rich_handler = RichHandler(
                console=console,
                show_time=LOGGING_SHOW_TIME and not aichor_enabled,
                show_path=LOGGING_SHOW_PATH and not aichor_enabled,
            )
            rich_handler.setLevel(logging.INFO)
            # basicConfig() only applies its own format/datefmt to handlers that have
            # no formatter yet, so the date format must be given to this formatter
            # directly; RichHandler takes its time format from ``formatter.datefmt``.
            rich_handler.setFormatter(
                DynamicRankFormatter(message_format, datefmt=date_format)
            )
            logging.basicConfig(
                level=logging.NOTSET,
                format=message_format,
                datefmt=date_format,
                handlers=[rich_handler],
            )
        else:
            # NOTE(review): this branch uses the default logging.Formatter, so no
            # rank prefix is added here; confirm that is intended.
            logging.basicConfig(
                level=logging.INFO,
                format=message_format,
            )

        # Suppress INFO logs from the datasets package
        logging.getLogger("datasets").setLevel(logging.ERROR)

        self.logger: logging.Logger = logging.getLogger(name)