diff --git a/src/python_template/logger/logger.py b/src/python_template/logger/logger.py index 2228cd2..7b6e465 100644 --- a/src/python_template/logger/logger.py +++ b/src/python_template/logger/logger.py @@ -154,7 +154,7 @@ class LoggerMixin: # pylint: disable=too-few-public-methods _lock_logger = threading.Lock() # Add a lock for thread safety. _loggers: dict = {} - def __init__( + def __init__( # pylint: disable=too-many-branches,too-many-statements self, logger_filename: str | Path | None = None, logger_format: str | None = None, @@ -183,6 +183,10 @@ def __init__( LoggerMixin._lock_logger ): # Acquire the lock before making changes. # Get call stack info: + if hasattr(self, "logger"): + if DEBUG_PRINTS: + print(f"{type(self).__name__} already has a logger.") + return stack = traceback.extract_stack() caller = stack[-2] # The frame that called this __init__. caller_info = f"{caller.filename}:{caller.lineno} in {caller.name}" diff --git a/tests/test_logger/test_logger.py b/tests/test_logger/test_logger.py index 75002ee..24f7030 100644 --- a/tests/test_logger/test_logger.py +++ b/tests/test_logger/test_logger.py @@ -1,10 +1,11 @@ """This file contains tests for the LoggerMixin class.""" import logging +import os from datetime import datetime from pathlib import Path -# from python_template.logger import LoggerMixin +from python_template.logger import LoggerMixin from python_template.logger.logger import DebugCategoryNameFilter TEST_TEXT = "Some test text." @@ -166,3 +167,24 @@ def test_debug_category_verbosity_help(create_test_debug_category): assert lines[3] == level_3_text assert lines[4] == level_4_text assert lines[5] == level_5_text + + +def test_logger_duplication_guard(): + """Tests the duplication guard in LoggerMixin.""" + + class TestClass(LoggerMixin): # pylint: disable=too-few-public-methods + """A test class that inherits LoggerMixin.""" + + def __init__(self): + super().__init__() + # self.old_logger = self.logger.copy() + self.old_logger_filename = self.logger_filename + LoggerMixin.__init__(self) + + test_class = TestClass() + # assert test_class.old_logger == test_class.logger + assert test_class.old_logger_filename == test_class.logger_filename + for handler in test_class.logger.handlers: + handler.close() + if os.path.exists(test_class.logger_filename): + os.remove(test_class.logger_filename)