@@ -10,17 +10,7 @@
 from model_engine_server.core.utils.timer import timer
 from model_engine_server.domain.entities import ModelEndpointConfig
 from model_engine_server.inference.async_inference.celery import async_inference_service
-from model_engine_server.inference.common import (
-    get_endpoint_config,
-    load_predict_fn_or_cls,
-    run_predict,
-)
-from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import (
-    DatadogInferenceMonitoringMetricsGateway,
-)
-from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import (
-    FirehoseStreamingStorageGateway,
-)
+from model_engine_server.inference.common import load_predict_fn_or_cls, run_predict
 from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler
 
 logger = make_logger(logger_name())
@@ -38,23 +28,6 @@ def init_worker_global():
     with timer(logger=logger, name="load_predict_fn_or_cls"):
         predict_fn_or_cls = load_predict_fn_or_cls()
 
-    endpoint_config = get_endpoint_config()
-    hooks = PostInferenceHooksHandler(
-        endpoint_name=endpoint_config.endpoint_name,
-        bundle_name=endpoint_config.bundle_name,
-        post_inference_hooks=endpoint_config.post_inference_hooks,
-        user_id=endpoint_config.user_id,
-        billing_queue=endpoint_config.billing_queue,
-        billing_tags=endpoint_config.billing_tags,
-        default_callback_url=endpoint_config.default_callback_url,
-        default_callback_auth=endpoint_config.default_callback_auth,
-        monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(),
-        endpoint_id=endpoint_config.endpoint_id,
-        endpoint_type=endpoint_config.endpoint_type,
-        bundle_id=endpoint_config.bundle_id,
-        labels=endpoint_config.labels,
-        streaming_storage_gateway=FirehoseStreamingStorageGateway(),
-    )
     # k8s health check
     with open(READYZ_FPATH, "w") as f:
         f.write("READY")
@@ -96,11 +69,6 @@ def predict(self, request_params, return_pickled):
         request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params)
         return run_predict(predict_fn_or_cls, request_params_pydantic)  # type: ignore
 
-    def on_success(self, retval, task_id, args, kwargs):
-        request_params = args[0]
-        request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params)
-        hooks.handle(request_params_pydantic, retval, task_id)  # type: ignore
-
 
 @async_inference_service.task(
     base=InferenceTask,
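The removed on_success override uses Celery's standard task-lifecycle hook, which the worker invokes after a task body returns successfully. Below is a minimal, self-contained sketch of that pattern; the app name, broker settings, and task are illustrative, not from this PR.

from celery import Celery, Task

# Illustrative app; the real code registers its tasks on async_inference_service.
app = Celery("demo", broker="memory://", backend="cache+memory://")

class AuditedTask(Task):
    def on_success(self, retval, task_id, args, kwargs):
        # Celery calls this hook on the worker after the task succeeds; the
        # removed code above used the same hook to run post-inference hooks.
        print(f"task {task_id} succeeded with {retval!r}")

@app.task(base=AuditedTask)
def double(x):
    return x * 2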
@@ -1,22 +1,11 @@
 import traceback
 from functools import wraps
 
-from fastapi import BackgroundTasks, FastAPI, HTTPException, Response, status
+from fastapi import FastAPI, HTTPException, Response, status
 from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
 from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
 from model_engine_server.core.loggers import logger_name, make_logger
-from model_engine_server.inference.common import (
-    get_endpoint_config,
-    load_predict_fn_or_cls,
-    run_predict,
-)
-from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import (
-    DatadogInferenceMonitoringMetricsGateway,
-)
-from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import (
-    FirehoseStreamingStorageGateway,
-)
-from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler
+from model_engine_server.inference.common import load_predict_fn_or_cls, run_predict
 from model_engine_server.inference.sync_inference.constants import (
     CONCURRENCY,
     FAIL_ON_CONCURRENCY_LIMIT,
@@ -44,23 +33,6 @@ def _inner_2(*args, **kwargs):
 # How does this interact with threads?
 # Analogous to init_worker() inside async_inference
 predict_fn = load_predict_fn_or_cls()
-endpoint_config = get_endpoint_config()
-hooks = PostInferenceHooksHandler(
-    endpoint_name=endpoint_config.endpoint_name,
-    bundle_name=endpoint_config.bundle_name,
-    post_inference_hooks=endpoint_config.post_inference_hooks,
-    user_id=endpoint_config.user_id,
-    billing_queue=endpoint_config.billing_queue,
-    billing_tags=endpoint_config.billing_tags,
-    default_callback_url=endpoint_config.default_callback_url,
-    default_callback_auth=endpoint_config.default_callback_auth,
-    monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(),
-    endpoint_id=endpoint_config.endpoint_id,
-    endpoint_type=endpoint_config.endpoint_type,
-    bundle_id=endpoint_config.bundle_id,
-    labels=endpoint_config.labels,
-    streaming_storage_gateway=FirehoseStreamingStorageGateway(),
-)
 
 
 @app.get("/healthcheck")
@@ -72,14 +44,13 @@ def healthcheck():
 
 @app.post("/predict")
 @with_concurrency_limit(concurrency_limiter)
-def predict(payload: EndpointPredictV1Request, background_tasks: BackgroundTasks):
+def predict(payload: EndpointPredictV1Request):
     """
     Assumption: payload is a JSON with format {"url": <url>, "args": <dictionary of args>, "return_pickled": boolean}
     Returns: Results of running the predict function on the request url. See `run_predict`.
     """
     try:
         result = run_predict(predict_fn, payload)
-        background_tasks.add_task(hooks.handle, payload, result)
         return result
     except Exception:
         raise HTTPException(status_code=500, detail=dict(traceback=str(traceback.format_exc())))
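For reference, a hedged usage sketch of the /predict route, assuming the payload shape described in the docstring; the host, port, and payload values are placeholders.

import requests

# Illustrative client call against a locally running sync inference server.
response = requests.post(
    "http://localhost:5000/predict",
    json={"url": None, "args": {"input": "hello"}, "return_pickled": False},
)
response.raise_for_status()
print(response.json())  # output of run_predict for this payload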