diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index b3a41dfdb..a26a8ddc6 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -27,10 +27,42 @@ logger_name, make_logger, ) +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware logger = make_logger(logger_name()) -app = FastAPI(title="launch", version="1.0.0", redoc_url="/api") + +class CustomMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + try: + LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) + return await call_next(request) + except Exception as e: + tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + structured_log = { + "error": str(e), + "request_id": str(request_id), + "traceback": "".join(tb_str), + } + logger.error("Unhandled exception: %s", structured_log) + return JSONResponse( + { + "status_code": 500, + "content": { + "error": "Internal error occurred. Our team has been notified.", + "timestamp": timestamp, + "request_id": request_id, + }, + } + ) + + +app = FastAPI( + title="launch", version="1.0.0", redoc_url="/api", middleware=[Middleware(CustomMiddleware)] +) app.include_router(batch_job_router_v1) app.include_router(inference_task_router_v1) @@ -44,33 +76,6 @@ app.include_router(trigger_router_v1) -@app.middleware("http") -async def dispatch(request: Request, call_next): - try: - LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) - return await call_next(request) - except Exception as e: - tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) - request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) - timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") - structured_log = { - "error": str(e), - "request_id": str(request_id), - "traceback": "".join(tb_str), - } - logger.error("Unhandled exception: %s", structured_log) - return JSONResponse( - { - "status_code": 500, - "content": { - "error": "Internal error occurred. Our team has been notified.", - "timestamp": timestamp, - "request_id": request_id, - }, - } - ) - - # TODO: Remove this once we have a better way to serve internal docs INTERNAL_DOCS_PATH = str(Path(__file__).parents[3] / "launch_internal/site") if os.path.exists(INTERNAL_DOCS_PATH): diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg index 1566418ea..a5f56d8a6 100644 --- a/model-engine/setup.cfg +++ b/model-engine/setup.cfg @@ -4,6 +4,7 @@ test=pytest [coverage:run] omit = model_engine_server/entrypoints/* + model_engine_server/api/app.py # TODO: Fix pylint errors # [pylint]