Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ forwarder:
model_engine_unwrap: false
serialize_results_as_string: false
wrap_response: false
forward_http_status: true
async:
user_port: 5005
user_hostname: "localhost"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ forwarder:
batch_route: null
model_engine_unwrap: true
serialize_results_as_string: true
forward_http_status: true
async:
user_port: 5005
user_hostname: "localhost"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ forwarder:
batch_route: null
model_engine_unwrap: true
serialize_results_as_string: true
forward_http_status: true
stream:
user_port: 5005
user_hostname: "localhost"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import subprocess

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse

app = FastAPI()
Expand All @@ -21,6 +22,12 @@ async def predict(request: Request):
return await request.json()


@app.post("/predict500")
async def predict500(request: Request):
response = JSONResponse(content=await request.json(), status_code=500)
return response


@app.post("/stream")
async def stream(request: Request):
value = (await request.body()).decode()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import requests
import sseclient
import yaml
from fastapi.responses import JSONResponse
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.inference.common import get_endpoint_config
Expand Down Expand Up @@ -122,6 +123,7 @@ class Forwarder(ModelEngineSerializationMixin):
serialize_results_as_string: bool
post_inference_hooks_handler: PostInferenceHooksHandler
wrap_response: bool
forward_http_status: bool

def __call__(self, json_payload: Any) -> Any:
request_obj = EndpointPredictV1Request.parse_obj(json_payload)
Expand All @@ -131,32 +133,42 @@ def __call__(self, json_payload: Any) -> Any:
logger.info(f"Accepted request, forwarding {json_payload_repr=}")

try:
response: Any = requests.post(
response_raw: Any = requests.post(
self.predict_endpoint,
json=json_payload,
headers={
"Content-Type": "application/json",
},
).json()
)
response = response_raw.json()
except Exception:
logger.exception(
f"Failed to get response for request ({json_payload_repr}) "
"from user-defined inference service."
)
raise
if isinstance(response, dict):
logger.info(f"Got response from user-defined service: {response.keys()=}")
logger.info(
f"Got response from user-defined service: {response.keys()=}, {response_raw.status_code=}"
)
elif isinstance(response, list):
logger.info(f"Got response from user-defined service: {len(response)=}")
logger.info(
f"Got response from user-defined service: {len(response)=}, {response_raw.status_code=}"
)
else:
logger.info(f"Got response from user-defined service: {response=}")
logger.info(
f"Got response from user-defined service: {response=}, {response_raw.status_code=}"
)

if self.wrap_response:
response = self.get_response_payload(using_serialize_results_as_string, response)

# TODO: we actually want to do this after we've returned the response.
self.post_inference_hooks_handler.handle(request_obj, response)
return response
if self.forward_http_status:
return JSONResponse(content=response, status_code=response_raw.status_code)
else:
return response


@dataclass(frozen=True)
Expand All @@ -180,6 +192,7 @@ class LoadForwarder:
model_engine_unwrap: bool = True
serialize_results_as_string: bool = True
wrap_response: bool = True
forward_http_status: bool = False

def load(self, resources: Path, cache: Any) -> Forwarder:
if self.use_grpc:
Expand Down Expand Up @@ -278,6 +291,7 @@ def endpoint(route: str) -> str:
serialize_results_as_string=serialize_results_as_string,
post_inference_hooks_handler=handler,
wrap_response=self.wrap_response,
forward_http_status=self.forward_http_status,
)


Expand Down Expand Up @@ -492,7 +506,7 @@ def _set_value(config: dict, key_path: List[str], value: Any) -> None:
"""
key = key_path[0]
if len(key_path) == 1:
config[key] = value
config[key] = value if not value.isdigit() else int(value)
else:
if key not in config:
config[key] = dict()
Expand Down
68 changes: 67 additions & 1 deletion model-engine/tests/unit/inference/test_forwarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest import mock

import pytest
from fastapi.responses import JSONResponse
from model_engine_server.core.utils.env import environment
from model_engine_server.domain.entities import ModelEndpointConfig
from model_engine_server.inference.forwarding.forwarding import (
Expand Down Expand Up @@ -33,6 +34,19 @@ class mocked_static_status_code:
def mocked_post(*args, **kwargs): # noqa
@dataclass
class mocked_static_json:
status_code: int = 200

def json(self) -> dict:
return PAYLOAD # type: ignore

return mocked_static_json()


def mocked_post_500(*args, **kwargs): # noqa
@dataclass
class mocked_static_json:
status_code: int = 500

def json(self) -> dict:
return PAYLOAD # type: ignore

Expand Down Expand Up @@ -85,16 +99,27 @@ def test_forwarders(post_inference_hooks_handler):
serialize_results_as_string=False,
post_inference_hooks_handler=post_inference_hooks_handler,
wrap_response=True,
forward_http_status=True,
)
json_response = fwd({"ignore": "me"})
_check(json_response)


def _check(json_response) -> None:
json_response = (
json.loads(json_response.body.decode("utf-8"))
if isinstance(json_response, JSONResponse)
else json_response
)
assert json_response == {"result": PAYLOAD}


def _check_responses_not_wrapped(json_response) -> None:
json_response = (
json.loads(json_response.body.decode("utf-8"))
if isinstance(json_response, JSONResponse)
else json_response
)
assert json_response == PAYLOAD


Expand All @@ -121,12 +146,18 @@ def test_forwarders_serialize_results_as_string(post_inference_hooks_handler):
serialize_results_as_string=True,
post_inference_hooks_handler=post_inference_hooks_handler,
wrap_response=True,
forward_http_status=True,
)
json_response = fwd({"ignore": "me"})
_check_serialized(json_response)


def _check_serialized(json_response) -> None:
json_response = (
json.loads(json_response.body.decode("utf-8"))
if isinstance(json_response, JSONResponse)
else json_response
)
assert isinstance(json_response["result"], str)
assert len(json_response) == 1, f"expecting only 'result' key, but got {json_response=}"
assert json.loads(json_response["result"]) == PAYLOAD
Expand All @@ -141,17 +172,18 @@ def test_forwarders_override_serialize_results(post_inference_hooks_handler):
serialize_results_as_string=True,
post_inference_hooks_handler=post_inference_hooks_handler,
wrap_response=True,
forward_http_status=True,
)
json_response = fwd({"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: False})
_check(json_response)
assert json_response == {"result": PAYLOAD}

fwd = Forwarder(
"ignored",
model_engine_unwrap=True,
serialize_results_as_string=False,
post_inference_hooks_handler=post_inference_hooks_handler,
wrap_response=True,
forward_http_status=True,
)
json_response = fwd({"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: True})
_check_serialized(json_response)
Expand All @@ -166,11 +198,43 @@ def test_forwarder_does_not_wrap_response(post_inference_hooks_handler):
serialize_results_as_string=False,
post_inference_hooks_handler=post_inference_hooks_handler,
wrap_response=False,
forward_http_status=True,
)
json_response = fwd({"ignore": "me"})
_check_responses_not_wrapped(json_response)


@mock.patch("requests.post", mocked_post_500)
@mock.patch("requests.get", mocked_get)
def test_forwarder_return_status_code(post_inference_hooks_handler):
fwd = Forwarder(
"ignored",
model_engine_unwrap=True,
serialize_results_as_string=True,
post_inference_hooks_handler=post_inference_hooks_handler,
wrap_response=False,
forward_http_status=True,
)
json_response = fwd({"ignore": "me"})
_check_responses_not_wrapped(json_response)
assert json_response.status_code == 500


@mock.patch("requests.post", mocked_post_500)
@mock.patch("requests.get", mocked_get)
def test_forwarder_dont_return_status_code(post_inference_hooks_handler):
fwd = Forwarder(
"ignored",
model_engine_unwrap=True,
serialize_results_as_string=True,
post_inference_hooks_handler=post_inference_hooks_handler,
wrap_response=False,
forward_http_status=False,
)
json_response = fwd({"ignore": "me"})
assert json_response == PAYLOAD


@mock.patch("requests.post", mocked_post)
@mock.patch("requests.get", mocked_get)
@mock.patch(
Expand Down Expand Up @@ -219,6 +283,7 @@ def test_forwarder_serialize_within_args(post_inference_hooks_handler):
serialize_results_as_string=True,
post_inference_hooks_handler=post_inference_hooks_handler,
wrap_response=True,
forward_http_status=True,
)
# expected: no `serialize_results_as_string` at top-level nor in 'args'
json_response = fwd({"something": "to ignore", "args": {"my": "payload", "is": "here"}})
Expand All @@ -237,6 +302,7 @@ def test_forwarder_serialize_within_args(post_inference_hooks_handler):
serialize_results_as_string=True,
post_inference_hooks_handler=post_inference_hooks_handler,
wrap_response=True,
forward_http_status=True,
)
json_response = fwd(payload)
_check_serialized(json_response)
Expand Down