diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index 6206f711e..9ed5e4ddf 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -5,6 +5,7 @@ from celery import Celery, Task, states from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.celery import TaskVisibility, celery_app from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger @@ -25,45 +26,6 @@ class ErrorResponse(TypedDict): error_metadata: str -class ErrorHandlingTask(Task): - """Sets a 'custom' field with error in the Task response for FAILURE. - - Used when services are ran via the Celery backend. - """ - - def after_return( - self, status: str, retval: Union[dict, Exception], task_id: str, args, kwargs, einfo - ) -> None: - """Handler that ensures custom error response information is available whenever a Task fails. - - Specifically, whenever the task's :param:`status` is `"FAILURE"` and the return value - :param:`retval` is an `Exception`, this handler extracts information from the `Exception` - and constructs a custom error response JSON value (see :func:`error_response` for details). - - This handler then re-propagates the Celery-required exception information (`"exc_type"` and - `"exc_message"`) while adding this new error response information under the `"custom"` key. 
- """ - if status == states.FAILURE and isinstance(retval, Exception): - logger.warning(f"Setting custom error response for failed task {task_id}") - - info: dict = raw_celery_response(self.backend, task_id) - result: dict = info["result"] - err: Exception = retval - - error_payload = error_response("Internal failure", err) - - # Inspired by pattern from: - # https://www.distributedpython.com/2018/09/28/celery-task-states/ - self.update_state( - state=states.FAILURE, - meta={ - "exc_type": result["exc_type"], - "exc_message": result["exc_message"], - "custom": json.dumps(error_payload, indent=False), - }, - ) - - def raw_celery_response(backend, task_id: str) -> Dict[str, Any]: key_info: str = backend.get_key_for_task(task_id) info_as_str: str = backend.get(key_info) @@ -103,6 +65,47 @@ def create_celery_service( else None, ) + class ErrorHandlingTask(Task): + """Sets a 'custom' field with error in the Task response for FAILURE. + + Used when services are ran via the Celery backend. + """ + + def after_return( + self, status: str, retval: Union[dict, Exception], task_id: str, args, kwargs, einfo + ) -> None: + """Handler that ensures custom error response information is available whenever a Task fails. + + Specifically, whenever the task's :param:`status` is `"FAILURE"` and the return value + :param:`retval` is an `Exception`, this handler extracts information from the `Exception` + and constructs a custom error response JSON value (see :func:`error_response` for details). + + This handler then re-propagates the Celery-required exception information (`"exc_type"` and + `"exc_message"`) while adding this new error response information under the `"custom"` key. 
+ """ + if status == states.FAILURE and isinstance(retval, Exception): + logger.warning(f"Setting custom error response for failed task {task_id}") + + info: dict = raw_celery_response(self.backend, task_id) + result: dict = info["result"] + err: Exception = retval + + error_payload = error_response("Internal failure", err) + + # Inspired by pattern from: + # https://www.distributedpython.com/2018/09/28/celery-task-states/ + self.update_state( + state=states.FAILURE, + meta={ + "exc_type": result["exc_type"], + "exc_message": result["exc_message"], + "custom": json.dumps(error_payload, indent=False), + }, + ) + request_params = args[0] + request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params) + forwarder.post_inference_hooks_handler.handle(request_params_pydantic, retval, task_id) # type: ignore + # See documentation for options: # https://docs.celeryproject.org/en/stable/userguide/tasks.html#list-of-options @app.task(base=ErrorHandlingTask, name=LIRA_CELERY_TASK_NAME, track_started=True) diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index 099fe7d4c..4bbe885dc 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -9,7 +9,6 @@ import sseclient import yaml from fastapi.responses import JSONResponse -from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.inference.common import get_endpoint_config from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( @@ -126,7 +125,6 @@ class Forwarder(ModelEngineSerializationMixin): forward_http_status: bool def __call__(self, json_payload: Any) -> Any: - request_obj = EndpointPredictV1Request.parse_obj(json_payload) json_payload, 
using_serialize_results_as_string = self.unwrap_json_payload(json_payload) json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload @@ -163,8 +161,6 @@ def __call__(self, json_payload: Any) -> Any: if self.wrap_response: response = self.get_response_payload(using_serialize_results_as_string, response) - # TODO: we actually want to do this after we've returned the response. - self.post_inference_hooks_handler.handle(request_obj, response) if self.forward_http_status: return JSONResponse(content=response, status_code=response_raw.status_code) else: diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py index f121bec2b..1fdb030ba 100644 --- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -4,7 +4,7 @@ import subprocess from functools import lru_cache -from fastapi import Depends, FastAPI +from fastapi import BackgroundTasks, Depends, FastAPI from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger @@ -70,11 +70,20 @@ def load_streaming_forwarder(): @app.post("/predict") def predict( request: EndpointPredictV1Request, + background_tasks: BackgroundTasks, forwarder=Depends(load_forwarder), limiter=Depends(get_concurrency_limiter), ): with limiter: - return forwarder(request.dict()) + try: + response = forwarder(request.dict()) + background_tasks.add_task( + forwarder.post_inference_hooks_handler.handle, request, response + ) + return response + except Exception: + logger.error(f"Failed to decode payload from: {request}") + raise @app.post("/stream") diff --git a/model-engine/model_engine_server/inference/post_inference_hooks.py 
b/model-engine/model_engine_server/inference/post_inference_hooks.py index 05dba3065..142e998e8 100644 --- a/model-engine/model_engine_server/inference/post_inference_hooks.py +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -1,7 +1,9 @@ +import json from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import requests +from fastapi.responses import JSONResponse from model_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger @@ -108,13 +110,17 @@ def __init__( def handle( self, request_payload: EndpointPredictV1Request, - response: Dict[str, Any], + response: Union[Dict[str, Any], JSONResponse], task_id: Optional[str] = None, ): + if isinstance(response, JSONResponse): + loaded_response = json.loads(response.body) + else: + loaded_response = response for hook_name, hook in self._hooks.items(): self._monitoring_metrics_gateway.emit_attempted_post_inference_hook(hook_name) try: - hook.handle(request_payload, response, task_id) + hook.handle(request_payload, loaded_response, task_id) # pragma: no cover self._monitoring_metrics_gateway.emit_successful_post_inference_hook(hook_name) except Exception: logger.exception(f"Hook {hook_name} failed.") diff --git a/model-engine/tests/unit/inference/test_http_forwarder.py b/model-engine/tests/unit/inference/test_http_forwarder.py index 43fbdfbd3..1ded1624d 100644 --- a/model-engine/tests/unit/inference/test_http_forwarder.py +++ b/model-engine/tests/unit/inference/test_http_forwarder.py @@ -1,12 +1,23 @@ import threading -import time +from dataclasses import dataclass +from typing import Mapping +from unittest import mock import pytest +from fastapi import BackgroundTasks +from fastapi.responses import JSONResponse from model_engine_server.common.dtos.tasks import 
EndpointPredictV1Request +from model_engine_server.inference.forwarding.forwarding import Forwarder from model_engine_server.inference.forwarding.http_forwarder import ( MultiprocessingConcurrencyLimiter, predict, ) +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) +from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler + +PAYLOAD: Mapping[str, str] = {"hello": "world"} class ExceptionCapturedThread(threading.Thread): @@ -26,21 +37,90 @@ def join(self): raise self.ex -def mock_forwarder(dict): - time.sleep(1) - return dict +def mocked_get(*args, **kwargs): # noqa + @dataclass + class mocked_static_status_code: + status_code: int = 200 + + return mocked_static_status_code() + + +def mocked_post(*args, **kwargs): # noqa + @dataclass + class mocked_static_json: + status_code: int = 200 + + def json(self) -> dict: + return PAYLOAD # type: ignore + + return mocked_static_json() + + +@pytest.fixture +def post_inference_hooks_handler(): + handler = PostInferenceHooksHandler( + endpoint_name="test_endpoint_name", + bundle_name="test_bundle_name", + post_inference_hooks=[], + user_id="test_user_id", + billing_queue="billing_queue", + billing_tags=[], + default_callback_url=None, + default_callback_auth=None, + monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + ) + return handler + +@pytest.fixture +def mock_request(): + return EndpointPredictV1Request( + url="test_url", + return_pickled=False, + args={"x": 1}, + ) -def test_http_service_429(): + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +def test_http_service_429(mock_request, post_inference_hooks_handler): + mock_forwarder = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=False, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=True, + forward_http_status=True, + ) 
limiter = MultiprocessingConcurrencyLimiter(1, True) t1 = ExceptionCapturedThread( - target=predict, args=(EndpointPredictV1Request(), mock_forwarder, limiter) + target=predict, args=(mock_request, BackgroundTasks(), mock_forwarder, limiter) ) t2 = ExceptionCapturedThread( - target=predict, args=(EndpointPredictV1Request(), mock_forwarder, limiter) + target=predict, args=(mock_request, BackgroundTasks(), mock_forwarder, limiter) ) t1.start() t2.start() t1.join() with pytest.raises(Exception): # 429 thrown t2.join() + + +def test_handler_response(mock_request, post_inference_hooks_handler): + try: + post_inference_hooks_handler.handle( + request_payload=mock_request, response=PAYLOAD, task_id="test_task_id" + ) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +def test_handler_json_response(mock_request, post_inference_hooks_handler): + try: + post_inference_hooks_handler.handle( + request_payload=mock_request, + response=JSONResponse(content=PAYLOAD), + task_id="test_task_id", + ) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}")