Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions charts/model-engine/templates/service_template_config_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,19 @@ data:
cpu: ${CPUS}
memory: ${MEMORY}
controlledResources: ["cpu", "memory"]
pod-disruption-budget.yaml: |-
apiVersion: policy/v1
kind: PodDisruptionBudget
metadata:
name: ${RESOURCE_NAME}
namespace: ${NAMESPACE}
labels:
{{- $service_template_labels | nindent 8 }}
spec:
minAvailable: 1
selector:
matchLabels:
app: ${RESOURCE_NAME}
batch-job-orchestration-job.yaml: |-
apiVersion: batch/v1
kind: Job
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ class CachedImages(TypedDict):
t4: List[str]


KUBERNETES_MAX_LENGTH = 64


class ImageCacheGateway:
async def create_or_update_image_cache(self, cached_images: CachedImages) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
_kubernetes_core_api = None
_kubernetes_autoscaling_api = None
_kubernetes_batch_api = None
_kubernetes_policy_api = None
_kubernetes_custom_objects_api = None
_kubernetes_cluster_version = None

Expand Down Expand Up @@ -147,6 +148,16 @@ def get_kubernetes_batch_client(): # pragma: no cover
return _kubernetes_batch_api


def get_kubernetes_policy_client():  # pragma: no cover
    """Return a ``PolicyV1Api`` client for the Kubernetes policy API group.

    When ``_lazy_load_kubernetes_clients`` is truthy the module-level
    ``_kubernetes_policy_api`` instance is reused across calls; otherwise the
    cached value is discarded and a new client is built on every call.
    """
    # ``global`` is function-wide in Python, so declaring it up front is
    # equivalent to declaring it inside one branch of a conditional.
    global _kubernetes_policy_api
    if not _lazy_load_kubernetes_clients:
        # Eager mode: drop whatever was cached so a fresh client is created.
        _kubernetes_policy_api = None
    if not _kubernetes_policy_api:
        _kubernetes_policy_api = kubernetes_asyncio.client.PolicyV1Api()
    return _kubernetes_policy_api


def get_kubernetes_custom_objects_client(): # pragma: no cover
if _lazy_load_kubernetes_clients:
global _kubernetes_custom_objects_api
Expand Down Expand Up @@ -599,6 +610,37 @@ async def _create_vpa(vpa: Dict[str, Any], name: str) -> None:
logger.exception("Got an exception when trying to apply the VerticalPodAutoscaler")
raise

@staticmethod
async def _create_pdb(pdb: Dict[str, Any], name: str) -> None:
    """
    Create a k8s PodDisruptionBudget (pdb), patching it if it already exists.

    Args:
        pdb: PDB body (a nested Dict in the format specified by Kubernetes)
        name: The name of the pdb on K8s; used for the patch call and logging

    Returns:
        Nothing; raises a k8s ApiException if failure

    """
    client = get_kubernetes_policy_client()
    target_namespace = hmi_config.endpoint_namespace
    try:
        await client.create_namespaced_pod_disruption_budget(
            namespace=target_namespace,
            body=pdb,
        )
    except ApiException as exc:
        # Anything other than a 409 (conflict) is logged and propagated.
        if exc.status != 409:
            logger.exception("Got an exception when trying to apply the PodDisruptionBudget")
            raise

        # 409: an object with this name already exists, so patch it in place.
        logger.info(f"PodDisruptionBudget {name} already exists, replacing")
        await client.patch_namespaced_pod_disruption_budget(
            name=name,
            namespace=hmi_config.endpoint_namespace,
            body=pdb,
        )

@staticmethod
async def _create_keda_scaled_object(scaled_object: Dict[str, Any], name: str) -> None:
custom_objects_api = get_kubernetes_custom_objects_client()
Expand Down Expand Up @@ -1035,6 +1077,27 @@ async def _delete_hpa(endpoint_id: str, deployment_name: str) -> bool:
return False
return True

@staticmethod
async def _delete_pdb(endpoint_id: str) -> bool:
    """
    Delete the PodDisruptionBudget belonging to an endpoint.

    Args:
        endpoint_id: Endpoint id; mapped to the k8s resource group name that
            is used as the PDB name.

    Returns:
        True if the PDB was deleted or did not exist (404); False if the
        delete call failed with any other ApiException.
    """
    policy_api = get_kubernetes_policy_client()
    k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id)
    try:
        await policy_api.delete_namespaced_pod_disruption_budget(
            namespace=hmi_config.endpoint_namespace,
            name=k8s_resource_group_name,
        )
    except ApiException as exc:
        if exc.status != 404:
            logger.exception(
                f"Deletion of PodDisruptionBudget {k8s_resource_group_name} failed"
            )
            return False
        # A missing PDB is only worth a warning; the call still counts as a success.
        logger.warning(
            f"Trying to delete nonexistent PodDisruptionBudget {k8s_resource_group_name}"
        )
    return True

@staticmethod
async def _delete_keda_scaled_object(endpoint_id: str) -> bool:
custom_objects_client = get_kubernetes_custom_objects_client()
Expand Down Expand Up @@ -1152,6 +1215,19 @@ async def _create_or_update_resources(
name=k8s_resource_group_name,
)

pdb_config_arguments = get_endpoint_resource_arguments_from_request(
k8s_resource_group_name=k8s_resource_group_name,
request=request,
sqs_queue_name=sqs_queue_name_str,
sqs_queue_url=sqs_queue_url_str,
endpoint_resource_name="pod-disruption-budget",
)
pdb_template = load_k8s_yaml("pod-disruption-budget.yaml", pdb_config_arguments)
await self._create_pdb(
pdb=pdb_template,
name=k8s_resource_group_name,
)

if model_endpoint_record.endpoint_type in {
ModelEndpointType.SYNC,
ModelEndpointType.STREAMING,
Expand Down Expand Up @@ -1561,6 +1637,7 @@ async def _delete_resources_async(self, endpoint_id: str, deployment_name: str)
endpoint_id=endpoint_id, deployment_name=deployment_name
)
await self._delete_vpa(endpoint_id=endpoint_id)
await self._delete_pdb(endpoint_id=endpoint_id)
return deployment_delete_succeeded and config_map_delete_succeeded

async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) -> bool:
Expand All @@ -1582,6 +1659,7 @@ async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) -
endpoint_id=endpoint_id
)
await self._delete_vpa(endpoint_id=endpoint_id)
await self._delete_pdb(endpoint_id=endpoint_id)

destination_rule_delete_succeeded = await self._delete_destination_rule(
endpoint_id=endpoint_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
LAUNCH_HIGH_PRIORITY_CLASS = "model-engine-high-priority"
LAUNCH_DEFAULT_PRIORITY_CLASS = "model-engine-default-priority"

KUBERNETES_MAX_LENGTH = 64
IMAGE_HASH_MAX_LENGTH = 32
FORWARDER_PORT = 5000
USER_CONTAINER_PORT = 5005
ARTIFACT_LIKE_CONTAINER_PORT = FORWARDER_PORT
Expand Down Expand Up @@ -329,6 +329,12 @@ class VerticalPodAutoscalerArguments(_BaseEndpointArguments):
MEMORY: str


class PodDisruptionBudgetArguments(_BaseEndpointArguments):
    """Substitution arguments for the pod disruption budget template.

    Declares no fields of its own; every key comes from
    ``_BaseEndpointArguments``.
    """


class VirtualServiceArguments(_BaseEndpointArguments):
"""Keyword-arguments for substituting into virtual-service templates."""

Expand Down Expand Up @@ -432,7 +438,7 @@ class VerticalAutoscalingEndpointParams(TypedDict):


def compute_image_hash(image: str) -> str:
    """Return a truncated SHA-256 hex digest identifying ``image``.

    Args:
        image: Image reference; it is passed through ``str()`` before hashing.

    Returns:
        The first ``IMAGE_HASH_MAX_LENGTH`` characters of the hex digest.
    """
    # The block previously held two consecutive ``return`` statements: a stale
    # md5 line sliced by KUBERNETES_MAX_LENGTH, followed by this sha256 line,
    # which was therefore unreachable. Only the sha256 form is kept.
    # ``hexdigest()`` already returns ``str``, so no extra conversion is needed.
    return hashlib.sha256(str(image).encode()).hexdigest()[:IMAGE_HASH_MAX_LENGTH]


def container_start_triton_cmd(
Expand Down Expand Up @@ -1184,5 +1190,18 @@ def get_endpoint_resource_arguments_from_request(
CPUS=str(build_endpoint_request.cpus),
MEMORY=str(build_endpoint_request.memory),
)
elif endpoint_resource_name == "pod-disruption-budget":
return PodDisruptionBudgetArguments(
# Base resource arguments
RESOURCE_NAME=k8s_resource_group_name,
NAMESPACE=hmi_config.endpoint_namespace,
ENDPOINT_ID=model_endpoint_record.id,
ENDPOINT_NAME=model_endpoint_record.name,
TEAM=team,
PRODUCT=product,
CREATED_BY=created_by,
OWNER=owner,
GIT_TAG=GIT_TAG,
)
else:
raise Exception(f"Unknown resource name: {endpoint_resource_name}")
Original file line number Diff line number Diff line change
Expand Up @@ -2728,6 +2728,31 @@ data:
cpu: ${CPUS}
memory: ${MEMORY}
controlledResources: ["cpu", "memory"]
pod-disruption-budget.yaml: |-
apiVersion: policy/v1
kind: PodDisruptionBudget
metadata:
name: ${RESOURCE_NAME}
namespace: ${NAMESPACE}
labels:
user_id: ${OWNER}
team: ${TEAM}
product: ${PRODUCT}
created_by: ${CREATED_BY}
owner: ${OWNER}
env: circleci
managed-by: model-engine
use_scale_launch_endpoint_network_policy: "true"
tags.datadoghq.com/env: circleci
tags.datadoghq.com/version: ${GIT_TAG}
tags.datadoghq.com/service: ${ENDPOINT_NAME}
endpoint_id: ${ENDPOINT_ID}
endpoint_name: ${ENDPOINT_NAME}
spec:
minAvailable: 1
selector:
matchLabels:
app: ${RESOURCE_NAME}
batch-job-orchestration-job.yaml: |-
apiVersion: batch/v1
kind: Job
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ def mock_autoscaling_client():
yield mock_client


@pytest.fixture
def mock_policy_client():
    """Yield an ``AsyncMock`` standing in for the Kubernetes policy client.

    While the fixture is active, ``get_kubernetes_policy_client`` in the
    module under test is patched to return the mock.
    """
    policy_client = AsyncMock()
    patch_target = f"{MODULE_PATH}.get_kubernetes_policy_client"
    with patch(patch_target, return_value=policy_client):
        yield policy_client


@pytest.fixture
def mock_custom_objects_client():
mock_client = AsyncMock()
Expand Down Expand Up @@ -276,6 +286,7 @@ async def test_create_async_endpoint_has_correct_labels(
mock_apps_client,
mock_core_client,
mock_autoscaling_client,
mock_policy_client,
mock_custom_objects_client,
mock_get_kubernetes_cluster_version,
create_resources_request_async_runnable_image: CreateOrUpdateResourcesRequest,
Expand Down Expand Up @@ -323,6 +334,11 @@ async def test_create_async_endpoint_has_correct_labels(
)
assert delete_custom_object_call_args_list == []

# Verify PDB labels
create_pdb_call_args = mock_policy_client.create_namespaced_pod_disruption_budget.call_args
pdb_body = create_pdb_call_args.kwargs["body"]
_verify_non_deployment_labels(pdb_body, request)

if build_endpoint_request.model_endpoint_record.endpoint_type == ModelEndpointType.SYNC:
assert create_custom_object_call_args_list == []
_verify_custom_object_plurals(
Expand All @@ -339,6 +355,7 @@ async def test_create_streaming_endpoint_has_correct_labels(
mock_apps_client,
mock_core_client,
mock_autoscaling_client,
mock_policy_client,
mock_custom_objects_client,
mock_get_kubernetes_cluster_version,
create_resources_request_streaming_runnable_image: CreateOrUpdateResourcesRequest,
Expand All @@ -365,6 +382,11 @@ async def test_create_streaming_endpoint_has_correct_labels(
config_map_body = create_config_map_call_args.kwargs["body"]
_verify_non_deployment_labels(config_map_body, request)

# Verify PDB labels
create_pdb_call_args = mock_policy_client.create_namespaced_pod_disruption_budget.call_args
pdb_body = create_pdb_call_args.kwargs["body"]
_verify_non_deployment_labels(pdb_body, request)

# Verify HPA labels
create_hpa_call_args = (
mock_autoscaling_client.create_namespaced_horizontal_pod_autoscaler.call_args
Expand Down Expand Up @@ -406,6 +428,7 @@ async def test_create_sync_endpoint_has_correct_labels(
mock_apps_client,
mock_core_client,
mock_autoscaling_client,
mock_policy_client,
mock_custom_objects_client,
mock_get_kubernetes_cluster_version,
create_resources_request_sync_runnable_image: CreateOrUpdateResourcesRequest,
Expand Down Expand Up @@ -441,6 +464,11 @@ async def test_create_sync_endpoint_has_correct_labels(
hpa_body = create_hpa_call_args.kwargs["body"]
_verify_non_deployment_labels(hpa_body, request)

# Verify PDB labels
create_pdb_call_args = mock_policy_client.create_namespaced_pod_disruption_budget.call_args
pdb_body = create_pdb_call_args.kwargs["body"]
_verify_non_deployment_labels(pdb_body, request)

# Make sure that an VPA is created if optimize_costs is True.
build_endpoint_request = request.build_endpoint_request
optimize_costs = build_endpoint_request.optimize_costs
Expand Down Expand Up @@ -477,6 +505,7 @@ async def test_create_sync_endpoint_has_correct_k8s_service_type(
mock_apps_client,
mock_core_client,
mock_autoscaling_client,
mock_policy_client,
mock_custom_objects_client,
mock_get_kubernetes_cluster_version,
create_resources_request_sync_runnable_image: CreateOrUpdateResourcesRequest,
Expand Down Expand Up @@ -531,6 +560,7 @@ async def test_get_resources_async_success(
mock_apps_client,
mock_core_client,
mock_autoscaling_client,
mock_policy_client,
mock_custom_objects_client,
):
k8s_endpoint_resource_delegate.__setattr__(
Expand Down Expand Up @@ -590,6 +620,7 @@ async def test_get_resources_sync_success(
mock_apps_client,
mock_core_client,
mock_autoscaling_client,
mock_policy_client,
mock_custom_objects_client,
):
k8s_endpoint_resource_delegate.__setattr__(
Expand Down Expand Up @@ -653,6 +684,7 @@ async def test_delete_resources_async_success(
mock_apps_client,
mock_core_client,
mock_autoscaling_client,
mock_policy_client,
mock_custom_objects_client,
):
deleted = await k8s_endpoint_resource_delegate.delete_resources(
Expand All @@ -667,6 +699,7 @@ async def test_delete_resources_sync_success(
mock_apps_client,
mock_core_client,
mock_autoscaling_client,
mock_policy_client,
mock_custom_objects_client,
):
deleted = await k8s_endpoint_resource_delegate.delete_resources(
Expand Down