Commit befedd4

Add labels to the /v1/infer failure metrics
Get the provider and model in order to pass them to _record_inference_failure, and add model and provider labels to the Counter.
Parent: 8212b0a

3 files changed: 51 additions & 9 deletions

src/app/endpoints/rlsapi_v1.py (48 additions & 7 deletions)

```diff
@@ -33,7 +33,10 @@
 from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
 from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
 from observability import InferenceEventData, build_inference_event, send_splunk_event
-from utils.query import handle_known_apistatus_errors
+from utils.query import (
+    extract_provider_and_model_from_model_id,
+    handle_known_apistatus_errors,
+)
 from utils.responses import (
     extract_text_from_response_items,
     extract_token_usage,
@@ -244,6 +247,8 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
     request_id: str,
     error: Exception,
     start_time: float,
+    model: str,
+    provider: str,
 ) -> float:
     """Record metrics and queue Splunk event for an inference failure.

@@ -259,7 +264,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
         The total inference time in seconds.
     """
     inference_time = time.monotonic() - start_time
-    metrics.llm_calls_failures_total.inc()
+    metrics.llm_calls_failures_total.labels(provider, model).inc()
     _queue_splunk_event(
         background_tasks,
         infer_request,
@@ -309,6 +314,7 @@ async def infer_endpoint(
     input_source = infer_request.get_input_source()
     instructions = _build_instructions(infer_request.context.systeminfo)
     model_id = _get_default_model_id()
+    model, provider = extract_provider_and_model_from_model_id(model_id)
     mcp_tools = await get_mcp_tools(request_headers=request.headers)
     logger.debug(
         "Request %s: Combined input source length: %d", request_id, len(input_source)
@@ -323,19 +329,40 @@
     except RuntimeError as e:
         if "context_length" in str(e).lower():
             _record_inference_failure(
-                background_tasks, infer_request, request, request_id, e, start_time
+                background_tasks,
+                infer_request,
+                request,
+                request_id,
+                e,
+                start_time,
+                model,
+                provider,
             )
             logger.error("Prompt too long for request %s: %s", request_id, e)
             error_response = PromptTooLongResponse(model=model_id)
             raise HTTPException(**error_response.model_dump()) from e
         _record_inference_failure(
-            background_tasks, infer_request, request, request_id, e, start_time
+            background_tasks,
+            infer_request,
+            request,
+            request_id,
+            e,
+            start_time,
+            model,
+            provider,
        )
         logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
         raise
     except APIConnectionError as e:
         _record_inference_failure(
-            background_tasks, infer_request, request, request_id, e, start_time
+            background_tasks,
+            infer_request,
+            request,
+            request_id,
+            e,
+            start_time,
+            model,
+            provider,
         )
         logger.error(
             "Unable to connect to Llama Stack for request %s: %s", request_id, e
@@ -347,7 +374,14 @@
         raise HTTPException(**error_response.model_dump()) from e
     except RateLimitError as e:
         _record_inference_failure(
-            background_tasks, infer_request, request, request_id, e, start_time
+            background_tasks,
+            infer_request,
+            request,
+            request_id,
+            e,
+            start_time,
+            model,
+            provider,
         )
         logger.error("Rate limit exceeded for request %s: %s", request_id, e)
         error_response = QuotaExceededResponse(
@@ -357,7 +391,14 @@
         raise HTTPException(**error_response.model_dump()) from e
     except (APIStatusError, OpenAIAPIStatusError) as e:
         _record_inference_failure(
-            background_tasks, infer_request, request, request_id, e, start_time
+            background_tasks,
+            infer_request,
+            request,
+            request_id,
+            e,
+            start_time,
+            model,
+            provider,
         )
         logger.exception("API error for request %s: %s", request_id, e)
         error_response = handle_known_apistatus_errors(e, model_id)
```
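
The new import, `extract_provider_and_model_from_model_id`, is used above but its body is not part of this diff. A minimal hypothetical sketch of what such a helper might look like, assuming model IDs follow a `provider/model` convention; note that the call site unpacks the result as `model, provider`, i.e. model first:

```python
# Hypothetical sketch only: the real utils.query implementation is not shown
# in this diff. Assumes model IDs look like "<provider>/<model>" and returns
# (model, provider) to match the unpacking order at the call site above.
def extract_provider_and_model_from_model_id(model_id: str) -> tuple[str, str]:
    provider, sep, model = model_id.partition("/")
    if not sep:
        # No separator found: treat the whole ID as the model name.
        return model_id, "unknown"
    return model, provider
```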

src/metrics/__init__.py (3 additions & 1 deletion)

```diff
@@ -33,7 +33,9 @@
 )

 # Metric that counts how many LLM calls failed
-llm_calls_failures_total = Counter("ls_llm_calls_failures_total", "LLM calls failures")
+llm_calls_failures_total = Counter(
+    "ls_llm_calls_failures_total", "LLM calls failures", ["provider", "model"]
+)

 # Metric that counts how many LLM calls had validation errors
 llm_calls_validation_errors_total = Counter(
```
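
Declaring the Counter with `["provider", "model"]` turns it into a labelled parent metric: in `prometheus_client`, every distinct label combination becomes its own child time series, and incrementing the parent without label values raises `ValueError`. A minimal sketch of the new usage pattern (the provider and model values are illustrative):

```python
from prometheus_client import CollectorRegistry, Counter

# Throwaway registry so the example does not pollute the global default one.
registry = CollectorRegistry()
llm_calls_failures_total = Counter(
    "ls_llm_calls_failures_total",
    "LLM calls failures",
    ["provider", "model"],
    registry=registry,
)

# Each distinct (provider, model) pair gets its own time series.
llm_calls_failures_total.labels("openai", "gpt-4o").inc()
llm_calls_failures_total.labels("openai", "gpt-4o").inc()

# Incrementing the parent directly is now an error, which is why
# _record_inference_failure needs the provider and model to call .labels().
try:
    llm_calls_failures_total.inc()
except ValueError:
    pass  # "metric is missing label values"
```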

tests/unit/app/endpoints/test_metrics.py (0 additions & 1 deletion)

```diff
@@ -42,7 +42,6 @@ async def test_metrics_endpoint(mocker: MockerFixture) -> None:
     assert "# TYPE ls_provider_model_configuration gauge" in response_body
     assert "# TYPE ls_llm_calls_total counter" in response_body
     assert "# TYPE ls_llm_calls_failures_total counter" in response_body
-    assert "# TYPE ls_llm_calls_failures_created gauge" in response_body
     assert "# TYPE ls_llm_validation_errors_total counter" in response_body
     assert "# TYPE ls_llm_validation_errors_created gauge" in response_body
     assert "# TYPE ls_llm_token_sent_total counter" in response_body
```
