Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 51 additions & 8 deletions src/app/endpoints/rlsapi_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@
from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
from observability import InferenceEventData, build_inference_event, send_splunk_event
from utils.query import handle_known_apistatus_errors
from utils.query import (
extract_provider_and_model_from_model_id,
handle_known_apistatus_errors,
)
from utils.responses import (
extract_text_from_response_items,
extract_token_usage,
get_mcp_tools,
)
from utils.suid import get_suid
Expand Down Expand Up @@ -191,6 +195,7 @@ async def retrieve_simple_response(
store=False,
)
response = cast(OpenAIResponseObject, response)
extract_token_usage(response.usage, model_id)

return extract_text_from_response_items(response.output)

Expand Down Expand Up @@ -242,6 +247,8 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
request_id: str,
error: Exception,
start_time: float,
model: str,
provider: str,
) -> float:
"""Record metrics and queue Splunk event for an inference failure.

Expand All @@ -257,7 +264,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
The total inference time in seconds.
"""
inference_time = time.monotonic() - start_time
metrics.llm_calls_failures_total.inc()
metrics.llm_calls_failures_total.labels(provider, model).inc()
_queue_splunk_event(
background_tasks,
infer_request,
Expand All @@ -272,7 +279,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po

@router.post("/infer", responses=infer_responses)
@authorize(Action.RLSAPI_V1_INFER)
async def infer_endpoint(
async def infer_endpoint( # pylint: disable=R0914
infer_request: RlsapiV1InferRequest,
request: Request,
background_tasks: BackgroundTasks,
Expand Down Expand Up @@ -307,6 +314,7 @@ async def infer_endpoint(
input_source = infer_request.get_input_source()
instructions = _build_instructions(infer_request.context.systeminfo)
model_id = _get_default_model_id()
provider, model = extract_provider_and_model_from_model_id(model_id)
mcp_tools = await get_mcp_tools(request_headers=request.headers)
logger.debug(
"Request %s: Combined input source length: %d", request_id, len(input_source)
Expand All @@ -321,19 +329,40 @@ async def infer_endpoint(
except RuntimeError as e:
if "context_length" in str(e).lower():
_record_inference_failure(
background_tasks, infer_request, request, request_id, e, start_time
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error("Prompt too long for request %s: %s", request_id, e)
error_response = PromptTooLongResponse(model=model_id)
raise HTTPException(**error_response.model_dump()) from e
_record_inference_failure(
background_tasks, infer_request, request, request_id, e, start_time
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
raise
except APIConnectionError as e:
_record_inference_failure(
background_tasks, infer_request, request, request_id, e, start_time
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error(
"Unable to connect to Llama Stack for request %s: %s", request_id, e
Expand All @@ -345,7 +374,14 @@ async def infer_endpoint(
raise HTTPException(**error_response.model_dump()) from e
except RateLimitError as e:
_record_inference_failure(
background_tasks, infer_request, request, request_id, e, start_time
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error("Rate limit exceeded for request %s: %s", request_id, e)
error_response = QuotaExceededResponse(
Expand All @@ -355,7 +391,14 @@ async def infer_endpoint(
raise HTTPException(**error_response.model_dump()) from e
except (APIStatusError, OpenAIAPIStatusError) as e:
_record_inference_failure(
background_tasks, infer_request, request, request_id, e, start_time
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.exception("API error for request %s: %s", request_id, e)
error_response = handle_known_apistatus_errors(e, model_id)
Expand Down
4 changes: 3 additions & 1 deletion src/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
)

# Metric that counts how many LLM calls failed
llm_calls_failures_total = Counter("ls_llm_calls_failures_total", "LLM calls failures")
llm_calls_failures_total = Counter(
"ls_llm_calls_failures_total", "LLM calls failures", ["provider", "model"]
)

# Metric that counts how many LLM calls had validation errors
llm_calls_validation_errors_total = Counter(
Expand Down
2 changes: 1 addition & 1 deletion src/utils/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def extract_provider_and_model_from_model_id(model_id: str) -> tuple[str, str]:
model_id: The model ID to extract from.

Returns:
tuple[str, str]: The model and provider.
tuple[str, str]: The provider and model.
"""
split = model_id.split("/", 1)
if len(split) == 2:
Expand Down
1 change: 0 additions & 1 deletion tests/unit/app/endpoints/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ async def test_metrics_endpoint(mocker: MockerFixture) -> None:
assert "# TYPE ls_provider_model_configuration gauge" in response_body
assert "# TYPE ls_llm_calls_total counter" in response_body
assert "# TYPE ls_llm_calls_failures_total counter" in response_body
assert "# TYPE ls_llm_calls_failures_created gauge" in response_body
assert "# TYPE ls_llm_validation_errors_total counter" in response_body
assert "# TYPE ls_llm_validation_errors_created gauge" in response_body
assert "# TYPE ls_llm_token_sent_total counter" in response_body
Expand Down
Loading