diff --git a/model-engine/model_engine_server/api/tasks_v1.py b/model-engine/model_engine_server/api/tasks_v1.py
index 524f2f462..663b3e0c2 100644
--- a/model-engine/model_engine_server/api/tasks_v1.py
+++ b/model-engine/model_engine_server/api/tasks_v1.py
@@ -18,6 +18,7 @@
 from model_engine_server.core.loggers import logger_name, make_logger
 from model_engine_server.domain.exceptions import (
     EndpointUnsupportedInferenceTypeException,
+    InvalidRequestException,
     ObjectNotAuthorizedException,
     ObjectNotFoundException,
     UpstreamServiceError,
@@ -66,6 +67,11 @@ async def create_async_inference_task(
             status_code=400,
             detail=f"Unsupported inference type: {str(exc)}",
         ) from exc
+    except InvalidRequestException as exc:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid request: {str(exc)}",
+        ) from exc
 
 
 @inference_task_router_v1.get("/async-tasks/{task_id}", response_model=GetAsyncTaskV1Response)
diff --git a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py
index 7a8f6911a..676a12743 100644
--- a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py
+++ b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List, Optional
 
+import botocore.exceptions
 from model_engine_server.common.dtos.model_endpoints import BrokerType
 from model_engine_server.common.dtos.tasks import (
     CreateAsyncTaskV1Response,
@@ -9,6 +10,7 @@
 from model_engine_server.core.celery import TaskVisibility, celery_app
 from model_engine_server.core.config import infra_config
 from model_engine_server.core.loggers import logger_name, make_logger
+from model_engine_server.domain.exceptions import InvalidRequestException
 from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway
 
 logger = make_logger(logger_name())
@@ -68,12 +70,17 @@ def send_task(
     ) -> CreateAsyncTaskV1Response:
         celery_dest = self._get_celery_dest()
 
-        res = celery_dest.send_task(
-            name=task_name,
-            args=args,
-            kwargs=kwargs,
-            queue=queue_name,
-        )
+        try:
+            res = celery_dest.send_task(
+                name=task_name,
+                args=args,
+                kwargs=kwargs,
+                queue=queue_name,
+            )
+        except botocore.exceptions.ClientError as e:
+            logger.exception(f"Error sending task to queue {queue_name}: {e}")
+            # Chain the botocore error so the root cause stays in the traceback.
+            raise InvalidRequestException(f"Error sending celery task: {e}") from e
         logger.info(f"Task {res.id} sent to queue {queue_name} from gateway")  # pragma: no cover
         return CreateAsyncTaskV1Response(task_id=res.id)
 
diff --git a/model-engine/tests/integration/inference/test_async_inference.py b/model-engine/tests/integration/inference/test_async_inference.py
index e96164d7a..db9bc9a77 100644
--- a/model-engine/tests/integration/inference/test_async_inference.py
+++ b/model-engine/tests/integration/inference/test_async_inference.py
@@ -4,7 +4,9 @@
 import subprocess
 from functools import lru_cache
 from typing import Any, List, Optional, Tuple
+from unittest.mock import MagicMock
 
+import botocore.exceptions
 import pytest
 import redis
 import requests
@@ -17,6 +19,7 @@
     TaskStatus,
 )
 from model_engine_server.common.env_vars import CIRCLECI
+from model_engine_server.domain.exceptions import InvalidRequestException
 from model_engine_server.infra.gateways import (
     CeleryTaskQueueGateway,
     LiveAsyncModelEndpointInferenceGateway,
@@ -157,3 +160,24 @@ def test_async_callbacks(
     assert actual_payload == expected_callback_payload
 
     assert callback_stats["last_auth"][callback_version] == expected_credentials
+
+
+def test_async_callbacks_botocore_exception(
+    queue: str,
+):
+    gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS)
+
+    mock_dest = MagicMock()
+    mock_dest.send_task = MagicMock(
+        side_effect=botocore.exceptions.ClientError(error_response={}, operation_name="")
+    )
+    mock_get = MagicMock()
+    mock_get.return_value = mock_dest
+    gateway._get_celery_dest = mock_get
+
+    with pytest.raises(InvalidRequestException):
+        gateway.send_task(
+            task_name="test_task",
+            queue_name=queue,
+            args=[1, 2],
+        )
diff --git a/model-engine/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py
index f9a0f0620..3d019016c 100644
--- a/model-engine/tests/unit/api/test_tasks.py
+++ b/model-engine/tests/unit/api/test_tasks.py
@@ -4,6 +4,7 @@
 from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
 from model_engine_server.domain.entities import ModelBundle, ModelEndpoint
 from model_engine_server.domain.exceptions import (
+    InvalidRequestException,
     ObjectNotAuthorizedException,
     ObjectNotFoundException,
     UpstreamServiceError,
@@ -104,6 +105,43 @@ def test_create_async_task_raises_404_not_found(
     assert response.status_code == 404
 
 
+def test_create_async_task_raises_400_invalid_requests(
+    model_bundle_1_v1: Tuple[ModelBundle, Any],
+    model_endpoint_1: Tuple[ModelEndpoint, Any],
+    endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]],
+    test_api_key: str,
+    get_test_client_wrapper,
+):
+    assert model_endpoint_1[0].infra_state is not None
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={
+            model_bundle_1_v1[0].id: model_bundle_1_v1[0],
+        },
+        fake_model_endpoint_record_repository_contents={
+            model_endpoint_1[0].record.id: model_endpoint_1[0].record,
+        },
+        fake_model_endpoint_infra_gateway_contents={
+            model_endpoint_1[0].infra_state.deployment_name: model_endpoint_1[0].infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    mock_use_case = MagicMock()
+    mock_use_case.return_value.execute = MagicMock(side_effect=InvalidRequestException)
+    with patch(
+        "model_engine_server.api.tasks_v1.CreateAsyncInferenceTaskV1UseCase",
+        mock_use_case,
+    ):
+        response = client.post(
+            "/v1/async-tasks?model_endpoint_id=invalid_model_endpoint_id",
+            auth=(test_api_key, ""),
+            json=endpoint_predict_request_1[1],
+        )
+    assert response.status_code == 400
+
+
 def test_get_async_task_success(
     model_bundle_1_v1: Tuple[ModelBundle, Any],
     model_endpoint_1: Tuple[ModelEndpoint, Any],