Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from model_engine_server.common.constants import SUPPORTED_POST_INFERENCE_HOOKS
Expand Down Expand Up @@ -46,6 +47,8 @@
CONVERTED_FROM_ARTIFACT_LIKE_KEY = "_CONVERTED_FROM_ARTIFACT_LIKE"
MODEL_BUNDLE_CHANGED_KEY = "_MODEL_BUNDLE_CHANGED"

DEFAULT_DISALLOWED_TEAMS = ["_INVALID_TEAM"]

logger = make_logger(logger_name())


Expand Down Expand Up @@ -118,6 +121,20 @@ def validate_deployment_resources(
)


@dataclass
class ValidationResult:
passed: bool
message: str


# Placeholder team and product label validator that only checks for a single invalid team
def simple_team_product_validator(team: str, product: str) -> ValidationResult:
if team in DEFAULT_DISALLOWED_TEAMS:
return ValidationResult(False, "Invalid team")
else:
return ValidationResult(True, "Valid team")


def validate_labels(labels: Dict[str, str]) -> None:
for required_label in REQUIRED_ENDPOINT_LABELS:
if required_label not in labels:
Expand All @@ -129,6 +146,7 @@ def validate_labels(labels: Dict[str, str]) -> None:
if restricted_label in labels:
raise EndpointLabelsException(f"Cannot specify '{restricted_label}' in labels")

# TODO: remove after we fully migrate to the new team + product validator
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😄

try:
from plugins.known_users import ALLOWED_TEAMS

Expand All @@ -138,6 +156,15 @@ def validate_labels(labels: Dict[str, str]) -> None:
except ModuleNotFoundError:
pass

try:
from shared_plugins.team_product_label_validation import validate_team_product_label
except ModuleNotFoundError:
validate_team_product_label = simple_team_product_validator

validation_result = validate_team_product_label(labels["team"], labels["product"])
if not validation_result.passed:
raise EndpointLabelsException(validation_result.message)

# Check k8s will accept the label values
regex_pattern = "(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?" # k8s label regex
for label_value in labels.values():
Expand Down
11 changes: 6 additions & 5 deletions model-engine/tests/unit/api/test_model_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastapi.testclient import TestClient
from model_engine_server.common.dtos.model_endpoints import GetModelEndpointV1Response
from model_engine_server.domain.entities import ModelBundle, ModelEndpoint, ModelEndpointStatus
from model_engine_server.domain.use_cases.model_endpoint_use_cases import DEFAULT_DISALLOWED_TEAMS


def test_create_model_endpoint_success(
Expand Down Expand Up @@ -40,7 +41,6 @@ def test_create_model_endpoint_success(
assert response_2.status_code == 200


@pytest.mark.skip(reason="TODO: team validation is currently disabled")
def test_create_model_endpoint_invalid_team_returns_400(
model_bundle_1_v1: Tuple[ModelBundle, Any],
create_model_endpoint_request_sync: Dict[str, Any],
Expand All @@ -59,15 +59,16 @@ def test_create_model_endpoint_invalid_team_returns_400(
fake_batch_job_progress_gateway_contents={},
fake_docker_image_batch_job_bundle_repository_contents={},
)
create_model_endpoint_request_sync["labels"]["team"] = "some_invalid_team"
invalid_team_name = DEFAULT_DISALLOWED_TEAMS[0]
create_model_endpoint_request_sync["labels"]["team"] = invalid_team_name
response_1 = client.post(
"/v1/model-endpoints",
auth=(test_api_key, ""),
json=create_model_endpoint_request_sync,
)
assert response_1.status_code == 400

create_model_endpoint_request_async["labels"]["team"] = "some_invalid_team"
create_model_endpoint_request_async["labels"]["team"] = invalid_team_name
response_2 = client.post(
"/v1/model-endpoints",
auth=(test_api_key, ""),
Expand Down Expand Up @@ -394,7 +395,6 @@ def test_update_model_endpoint_by_id_success(
assert response.json()["endpoint_creation_task_id"]


@pytest.mark.skip(reason="TODO: team validation is currently disabled")
def test_update_model_endpoint_by_id_invalid_team_returns_400(
model_bundle_1_v1: Tuple[ModelBundle, Any],
model_endpoint_1: Tuple[ModelEndpoint, Any],
Expand All @@ -418,8 +418,9 @@ def test_update_model_endpoint_by_id_invalid_team_returns_400(
fake_batch_job_progress_gateway_contents={},
fake_docker_image_batch_job_bundle_repository_contents={},
)
invalid_team_name = DEFAULT_DISALLOWED_TEAMS[0]
update_model_endpoint_request["labels"] = {
"team": "some_invalid_team",
"team": invalid_team_name,
"product": "my_product",
}
response = client.put(
Expand Down