Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/app/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
logger = get_logger(__name__)

engine: Engine | None = None
SessionLocal: sessionmaker | None = None
session_local: sessionmaker | None = None


def get_engine() -> Engine:
Expand All @@ -33,11 +33,11 @@ def create_tables() -> None:

def get_session() -> Session:
"""Get a database session. Raises an error if not initialized."""
if SessionLocal is None:
if session_local is None:
raise RuntimeError(
"Database session not initialized. Call initialize_database() first."
)
return SessionLocal()
return session_local()


def _create_sqlite_engine(config: SQLiteDatabaseConfiguration, **kwargs: Any) -> Engine:
Expand Down Expand Up @@ -102,7 +102,7 @@ def initialize_database() -> None:
"""Initialize the database engine."""
db_config = configuration.database_configuration

global engine, SessionLocal # pylint: disable=global-statement
global engine, session_local # pylint: disable=global-statement

# Debug print all SQL statements if our logger is at-least DEBUG level
echo = bool(logger.isEnabledFor(logging.DEBUG))
Expand All @@ -126,4 +126,4 @@ def initialize_database() -> None:
assert isinstance(postgres_config, PostgreSQLDatabaseConfiguration)
engine = _create_postgres_engine(postgres_config, **create_engine_kwargs)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
12 changes: 6 additions & 6 deletions tests/unit/app/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
def reset_database_state_fixture():
"""Reset global database state before and after tests."""
original_engine = database.engine
original_session_local = database.SessionLocal
original_session_local = database.session_local

# Reset state before test
database.engine = None
database.SessionLocal = None
database.session_local = None

yield

# Restore original state after test
database.engine = original_engine
database.SessionLocal = original_session_local
database.session_local = original_session_local


@pytest.fixture(name="base_postgres_config")
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_get_session_when_initialized(self, mocker):
mock_session_local = mocker.MagicMock()
mock_session = mocker.MagicMock(spec=Session)
mock_session_local.return_value = mock_session
database.SessionLocal = mock_session_local
database.session_local = mock_session_local

result = database.get_session()

Expand All @@ -81,7 +81,7 @@ def test_get_session_when_initialized(self, mocker):

def test_get_session_when_not_initialized(self):
"""Test get_session raises RuntimeError when not initialized."""
database.SessionLocal = None
database.session_local = None

with pytest.raises(RuntimeError, match="Database session not initialized"):
database.get_session()
Expand Down Expand Up @@ -257,7 +257,7 @@ def _verify_common_assertions(
autocommit=False, autoflush=False, bind=mock_engine
)
assert database.engine is mock_engine
assert database.SessionLocal is mock_session_local
assert database.session_local is mock_session_local

def test_initialize_database_sqlite(
self,
Expand Down
Loading