diff --git a/src/app/main.py b/src/app/main.py index 9aace6d3a..6038398ef 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -74,8 +74,6 @@ async def rest_api_metrics( if isinstance(route, (Mount, Route, WebSocketRoute)) ] -setup_model_metrics() - @app.on_event("startup") async def startup_event() -> None: @@ -83,4 +81,6 @@ async def startup_event() -> None: logger.info("Registering MCP servers") await register_mcp_servers_async(logger, configuration.configuration) get_logger("app.endpoints.handlers") + logger.info("Setting up model metrics") + await setup_model_metrics() logger.info("App startup complete") diff --git a/src/metrics/utils.py b/src/metrics/utils.py index 29e2bcce4..0dcd7b508 100644 --- a/src/metrics/utils.py +++ b/src/metrics/utils.py @@ -1,19 +1,24 @@ """Utility functions for metrics handling.""" from configuration import configuration -from client import LlamaStackClientHolder +from client import LlamaStackClientHolder, AsyncLlamaStackClientHolder from log import get_logger import metrics logger = get_logger(__name__) -def setup_model_metrics() -> None: +async def setup_model_metrics() -> None: """Perform setup of all metrics related to LLM model and provider.""" - client = LlamaStackClientHolder().get_client() + model_list = [] + if configuration.llama_stack_configuration.use_as_library_client: + model_list = await AsyncLlamaStackClientHolder().get_client().models.list() + else: + model_list = LlamaStackClientHolder().get_client().models.list() + models = [ model - for model in client.models.list() + for model in model_list if model.model_type == "llm" # pyright: ignore[reportAttributeAccessIssue] ] diff --git a/tests/unit/metrics/test_utis.py b/tests/unit/metrics/test_utis.py index e3e2c6ab4..295241cd6 100644 --- a/tests/unit/metrics/test_utis.py +++ b/tests/unit/metrics/test_utis.py @@ -3,7 +3,7 @@ from metrics.utils import setup_model_metrics -def test_setup_model_metrics(mocker): +async def test_setup_model_metrics(mocker): """Test the setup_model_metrics function.""" # Mock the LlamaStackAsLibraryClient @@ -51,7 +51,7 @@ def test_setup_model_metrics(mocker): model_1, ] - setup_model_metrics() + await setup_model_metrics() # Check that the provider_model_configuration metric was set correctly # The default model should have a value of 1, others should be 0