Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion batch_inference_v2/batch_inference_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,17 @@ def infer(
model_endpoint_sample_set: Union[
mlrun.DataItem, list, dict, pd.DataFrame, pd.Series, np.ndarray
] = None,

# the following parameters are deprecated and will be removed once the versioning mechanism is implemented
# TODO: Remove the following parameters once FHUB-13 is resolved
trigger_monitoring_job: Optional[bool] = None,
batch_image_job: Optional[str] = None,
model_endpoint_drift_threshold: Optional[float] = None,
model_endpoint_possible_drift_threshold: Optional[float] = None,

# prediction kwargs to pass to the model predict function
**predict_kwargs: Dict[str, Any],

):
"""
Perform a prediction on the provided dataset using the specified model.
Expand Down Expand Up @@ -173,10 +183,33 @@ def infer(
:param model_endpoint_sample_set: A sample dataset to give to compare the inputs in the drift analysis.
Can be provided as an input (DataItem) or as a parameter (e.g. string, list, DataFrame).
The default chosen sample set will always be the one who is set in the model artifact itself.
:param trigger_monitoring_job: Whether to trigger the batch drift analysis after the infer job.
:param batch_image_job: The image that will be used to register the monitoring batch job if not exist.
By default, the image is mlrun/mlrun.
:param model_endpoint_drift_threshold: The threshold of which to mark drifts. Defaulted to 0.7.
:param model_endpoint_possible_drift_threshold: The threshold of which to mark possible drifts. Defaulted to 0.5.

raises MLRunInvalidArgumentError: if both `model_path` and `endpoint_id` are not provided
"""


if trigger_monitoring_job:
context.logger.warning("The `trigger_monitoring_job` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
"if you are using mlrun<1.7.0, please import the previous version of this function, for example "
"'hub://batch_inference_v2:2.5.0'.")
if batch_image_job:
context.logger.warning("The `batch_image_job` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
"if you are using mlrun<1.7.0, please import the previous version of this function, for example "
"'hub://batch_inference_v2:2.5.0'.")
if model_endpoint_drift_threshold:
context.logger.warning("The `model_endpoint_drift_threshold` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
"if you are using mlrun<1.7.0, please import the previous version of this function, for example "
"'hub://batch_inference_v2:2.5.0'.")
if model_endpoint_possible_drift_threshold:
context.logger.warning("The `model_endpoint_possible_drift_threshold` parameter is deprecated and will be removed once the versioning mechanism is implemented. "
"if you are using mlrun<1.7.0, please import the previous version of this function, for example "
"'hub://batch_inference_v2:2.5.0'.")

# Loading the model:
context.logger.info(f"Loading model...")
if isinstance(model_path, mlrun.DataItem):
Expand Down Expand Up @@ -250,4 +283,4 @@ def infer(
model_endpoint_name=model_endpoint_name,
infer_results_df=result_set.copy(),
sample_set_statistics=sample_set_statistics,
)
)
Loading