diff --git a/charts/model-engine/templates/gateway_deployment.yaml b/charts/model-engine/templates/gateway_deployment.yaml
index a58717a33..e52833193 100644
--- a/charts/model-engine/templates/gateway_deployment.yaml
+++ b/charts/model-engine/templates/gateway_deployment.yaml
@@ -49,13 +49,6 @@ spec:
               port: 5000
             periodSeconds: 2
             failureThreshold: 30
-          livenessProbe:
-            httpGet:
-              path: /healthz
-              port: 5000
-            initialDelaySeconds: 5
-            periodSeconds: 2
-            failureThreshold: 10
           command:
             - dumb-init
             - --
diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py
index a26a8ddc6..90f5620c8 100644
--- a/model-engine/model_engine_server/api/app.py
+++ b/model-engine/model_engine_server/api/app.py
@@ -5,7 +5,7 @@
 from pathlib import Path
 
 import pytz
-from fastapi import FastAPI, Request, Response
+from fastapi import FastAPI, HTTPException, Request, Response
 from fastapi.responses import JSONResponse
 from fastapi.staticfiles import StaticFiles
 from model_engine_server.api.batch_jobs_v1 import batch_job_router_v1
@@ -21,6 +21,7 @@
 from model_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1
 from model_engine_server.api.tasks_v1 import inference_task_router_v1
 from model_engine_server.api.triggers_v1 import trigger_router_v1
+from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
 from model_engine_server.core.loggers import (
     LoggerTagKey,
     LoggerTagManager,
@@ -32,12 +33,34 @@
 
 logger = make_logger(logger_name())
 
+# Allows us to make the Uvicorn worker concurrency in model_engine_server/api/worker.py very high
+MAX_CONCURRENCY = 500
+
+concurrency_limiter = MultiprocessingConcurrencyLimiter(
+    concurrency=MAX_CONCURRENCY, fail_on_concurrency_limit=True
+)
+
+healthcheck_routes = ["/healthcheck", "/healthz", "/readyz"]
+
 
 class CustomMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         try:
             LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4()))
-            return await call_next(request)
+            # we intentionally exclude healthcheck routes from the concurrency limiter
+            if request.url.path in healthcheck_routes:
+                return await call_next(request)
+            with concurrency_limiter:
+                return await call_next(request)
+        except HTTPException as e:
+            timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
+            return JSONResponse(
+                status_code=e.status_code,
+                content={
+                    "error": e.detail,
+                    "timestamp": timestamp,
+                },
+            )
         except Exception as e:
             tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
             request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
@@ -49,14 +72,12 @@ async def dispatch(self, request: Request, call_next):
             }
             logger.error("Unhandled exception: %s", structured_log)
             return JSONResponse(
-                {
-                    "status_code": 500,
-                    "content": {
-                        "error": "Internal error occurred. Our team has been notified.",
-                        "timestamp": timestamp,
-                        "request_id": request_id,
-                    },
-                }
+                status_code=500,
+                content={
+                    "error": "Internal error occurred. Our team has been notified.",
+                    "timestamp": timestamp,
+                    "request_id": request_id,
+                },
             )
 
 
@@ -91,9 +112,10 @@ def load_redis():
     get_or_create_aioredis_pool()
 
 
-@app.get("/healthcheck")
-@app.get("/healthz")
-@app.get("/readyz")
 def healthcheck() -> Response:
     """Returns 200 if the app is healthy."""
     return Response(status_code=200)
+
+
+for endpoint in healthcheck_routes:
+    app.get(endpoint)(healthcheck)
diff --git a/model-engine/model_engine_server/api/worker.py b/model-engine/model_engine_server/api/worker.py
index d08113b51..289640c88 100644
--- a/model-engine/model_engine_server/api/worker.py
+++ b/model-engine/model_engine_server/api/worker.py
@@ -1,8 +1,9 @@
 from uvicorn.workers import UvicornWorker
 
-# Gunicorn returns 503 instead of 429 when concurrency exceeds the limit, before adding rate limiting just increase the concurrency
+# Gunicorn returns 503 instead of 429 when concurrency exceeds the limit
 # We'll autoscale at target concurrency of a much lower number (around 50), and this just makes sure we don't 503 with bursty traffic
-CONCURRENCY_LIMIT = 1000
+# We set this very high since model_engine_server/api/app.py sets a lower per-pod concurrency at which we start returning 429s
+CONCURRENCY_LIMIT = 10000
 
 
 class LaunchWorker(UvicornWorker):
diff --git a/model-engine/model_engine_server/common/concurrency_limiter.py b/model-engine/model_engine_server/common/concurrency_limiter.py
new file mode 100644
index 000000000..b4e10c814
--- /dev/null
+++ b/model-engine/model_engine_server/common/concurrency_limiter.py
@@ -0,0 +1,36 @@
+from multiprocessing import BoundedSemaphore
+from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType
+from typing import Optional
+
+from fastapi import HTTPException
+from model_engine_server.core.loggers import logger_name, make_logger
+
+logger = make_logger(logger_name())
+
+
+class MultiprocessingConcurrencyLimiter:
+    def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool):
+        self.concurrency = concurrency
+        if concurrency is not None:
+            if concurrency < 1:
+                raise ValueError("Concurrency should be at least 1")
+            self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency)
+            self.blocking = (
+                not fail_on_concurrency_limit
+            )  # we want to block if we want to queue up requests
+        else:
+            self.semaphore = None
+            self.blocking = False  # Unused
+
+    def __enter__(self):
+        logger.debug("Entering concurrency limiter semaphore")
+        if self.semaphore and not self.semaphore.acquire(block=self.blocking):
+            logger.warning(f"Too many requests (max {self.concurrency}), returning 429")
+            raise HTTPException(status_code=429, detail="Too many requests")
+            # Just raises an HTTPException.
+            # __exit__ should not run; otherwise the release() doesn't have an acquire()
+
+    def __exit__(self, type, value, traceback):
+        logger.debug("Exiting concurrency limiter semaphore")
+        if self.semaphore:
+            self.semaphore.release()
diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py
index 5943bc50a..f121bec2b 100644
--- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py
+++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py
@@ -3,11 +3,9 @@
 import os
 import subprocess
 from functools import lru_cache
-from multiprocessing import BoundedSemaphore
-from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType
-from typing import Optional
 
-from fastapi import Depends, FastAPI, HTTPException
+from fastapi import Depends, FastAPI
+from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
 from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
 from model_engine_server.core.loggers import logger_name, make_logger
 from model_engine_server.inference.forwarding.forwarding import (
@@ -21,33 +19,6 @@
 app = FastAPI()
 
 
-class MultiprocessingConcurrencyLimiter:
-    def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool):
-        if concurrency is not None:
-            if concurrency < 1:
-                raise ValueError("Concurrency should be at least 1")
-            self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency)
-            self.blocking = (
-                not fail_on_concurrency_limit
-            )  # we want to block if we want to queue up requests
-        else:
-            self.semaphore = None
-            self.blocking = False  # Unused
-
-    def __enter__(self):
-        logger.debug("Entering concurrency limiter semaphore")
-        if self.semaphore and not self.semaphore.acquire(block=self.blocking):
-            logger.warning("Too many requests, returning 429")
-            raise HTTPException(status_code=429, detail="Too many requests")
-            # Just raises an HTTPException.
-            # __exit__ should not run; otherwise the release() doesn't have an acquire()
-
-    def __exit__(self, type, value, traceback):
-        logger.debug("Exiting concurrency limiter semaphore")
-        if self.semaphore:
-            self.semaphore.release()
-
-
 @app.get("/healthz")
 @app.get("/readyz")
 def healthcheck():
diff --git a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py
index bec1c50cf..02b68ecaf 100644
--- a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py
+++ b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py
@@ -1,10 +1,8 @@
 import traceback
 from functools import wraps
-from multiprocessing import BoundedSemaphore
-from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType
-from typing import Optional
 
 from fastapi import BackgroundTasks, FastAPI, HTTPException, Response, status
+from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
 from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
 from model_engine_server.core.loggers import logger_name, make_logger
 from model_engine_server.inference.common import (
@@ -25,33 +23,6 @@
 logger = make_logger(logger_name())
 
 
-class MultiprocessingConcurrencyLimiter:
-    def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool):
-        if concurrency is not None:
-            if concurrency < 1:
-                raise ValueError("Concurrency should be at least 1")
-            self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency)
-            self.blocking = (
-                not fail_on_concurrency_limit
-            )  # we want to block if we want to queue up requests
-        else:
-            self.semaphore = None
-            self.blocking = False  # Unused
-
-    def __enter__(self):
-        logger.debug("Entering concurrency limiter semaphore")
-        if self.semaphore and not self.semaphore.acquire(block=self.blocking):
-            logger.warning("Too many requests, returning 429")
-            raise HTTPException(status_code=429, detail="Too many requests")
-            # Just raises an HTTPException.
-            # __exit__ should not run; otherwise the release() doesn't have an acquire()
-
-    def __exit__(self, type, value, traceback):
-        logger.debug("Exiting concurrency limiter semaphore")
-        if self.semaphore:
-            self.semaphore.release()
-
-
 def with_concurrency_limit(concurrency_limiter: MultiprocessingConcurrencyLimiter):
     def _inner(flask_func):
         @wraps(flask_func)
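
Note (not part of the diff): below is a minimal sketch of how the extracted MultiprocessingConcurrencyLimiter is intended to be wired into a FastAPI service, mirroring the CustomMiddleware change in model_engine_server/api/app.py above. The names app, limiter, ConcurrencyLimitMiddleware, and healthz are illustrative assumptions; only the limiter class and its 429 behavior come from this patch.

    from fastapi import FastAPI, HTTPException, Request, Response
    from fastapi.responses import JSONResponse
    from starlette.middleware.base import BaseHTTPMiddleware

    from model_engine_server.common.concurrency_limiter import (
        MultiprocessingConcurrencyLimiter,
    )

    app = FastAPI()

    # fail_on_concurrency_limit=True makes acquire() non-blocking: once the
    # semaphore is exhausted, __enter__ raises HTTPException(429) instead of queueing.
    limiter = MultiprocessingConcurrencyLimiter(concurrency=500, fail_on_concurrency_limit=True)

    HEALTHCHECK_ROUTES = ["/healthcheck", "/healthz", "/readyz"]


    class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request: Request, call_next):
            # Health checks bypass the limiter so Kubernetes probes keep passing
            # even when the pod is saturated with real traffic.
            if request.url.path in HEALTHCHECK_ROUTES:
                return await call_next(request)
            try:
                with limiter:  # acquire the shared semaphore on entry, release on exit
                    return await call_next(request)
            except HTTPException as e:  # raised by the limiter when over the cap
                return JSONResponse(status_code=e.status_code, content={"error": e.detail})


    app.add_middleware(ConcurrencyLimitMiddleware)


    @app.get("/healthz")
    def healthz() -> Response:
        return Response(status_code=200)

The division of labor follows the worker.py comments: Gunicorn's own CONCURRENCY_LIMIT is raised to 10000 so it no longer 503s on bursts, and the app-level limiter becomes the component that sheds load, returning 429 once MAX_CONCURRENCY = 500 requests are in flight on a pod.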