Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions model-engine/model_engine_server/api/model_endpoints_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import (
EndpointDeleteFailedException,
EndpointInfraStateNotFound,
EndpointLabelsException,
EndpointResourceInvalidRequestException,
ExistingEndpointOperationInProgressException,
Expand Down Expand Up @@ -67,14 +68,11 @@ async def create_model_endpoint(
status_code=400,
detail="The specified model endpoint already exists.",
) from exc
except EndpointLabelsException as exc:
raise HTTPException(
status_code=400,
detail=str(exc),
) from exc
except ObjectHasInvalidValueException as exc:
raise HTTPException(status_code=400, detail=str(exc))
except EndpointResourceInvalidRequestException as exc:
except (
EndpointLabelsException,
ObjectHasInvalidValueException,
EndpointResourceInvalidRequestException,
) as exc:
raise HTTPException(
status_code=400,
detail=str(exc),
Expand Down Expand Up @@ -148,7 +146,11 @@ async def update_model_endpoint(
return await use_case.execute(
user=auth, model_endpoint_id=model_endpoint_id, request=request
)
except EndpointLabelsException as exc:
except (
EndpointLabelsException,
ObjectHasInvalidValueException,
EndpointResourceInvalidRequestException,
) as exc:
raise HTTPException(
status_code=400,
detail=str(exc),
Expand All @@ -163,6 +165,11 @@ async def update_model_endpoint(
status_code=409,
detail="Existing operation on endpoint in progress, try again later.",
) from exc
except EndpointInfraStateNotFound as exc:
raise HTTPException(
status_code=500,
detail="Endpoint infra state not found, try again later.",
) from exc


@model_endpoint_router_v1.delete(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,22 @@ async def execute(
# infra_state to make sure that after the update, all resources are valid and in sync.
# E.g. If user only want to update gpus and leave gpu_type as None, we use the existing gpu_type
# from infra_state to avoid passing in None to validate_resource_requests.
raw_request = request.dict(exclude_unset=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch here. I'm guessing the bug was that pydantic fields with non-falsy defaults were being treated as user-set even when they were not explicitly included in the request?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah — primarily that a user could not specify None for fields where it is allowed, e.g. gpu_type.

validate_resource_requests(
bundle=bundle,
cpus=request.cpus or infra_state.resource_state.cpus,
memory=request.memory or infra_state.resource_state.memory,
storage=request.storage or infra_state.resource_state.storage,
gpus=request.gpus or infra_state.resource_state.gpus,
gpu_type=request.gpu_type or infra_state.resource_state.gpu_type,
cpus=(request.cpus if "cpus" in raw_request else infra_state.resource_state.cpus),
memory=(
request.memory if "memory" in raw_request else infra_state.resource_state.memory
),
storage=(
request.storage if "storage" in raw_request else infra_state.resource_state.storage
),
gpus=(request.gpus if "gpus" in raw_request else infra_state.resource_state.gpus),
gpu_type=(
request.gpu_type
if "gpu_type" in raw_request
else infra_state.resource_state.gpu_type
),
)

validate_deployment_resources(
Expand Down
298 changes: 298 additions & 0 deletions model-engine/tests/unit/domain/test_model_endpoint_use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ObjectNotFoundException,
)
from model_engine_server.domain.use_cases.model_endpoint_use_cases import (
CONVERTED_FROM_ARTIFACT_LIKE_KEY,
CreateModelEndpointV1UseCase,
DeleteModelEndpointByIdV1UseCase,
GetModelEndpointByIdV1UseCase,
Expand Down Expand Up @@ -855,6 +856,303 @@ async def test_update_model_endpoint_team_success(
assert isinstance(response, UpdateModelEndpointV1Response)


@pytest.mark.asyncio
async def test_update_model_endpoint_use_case_raises_invalid_value_exception(
    fake_model_bundle_repository,
    fake_model_endpoint_service,
    model_bundle_2: ModelBundle,
    model_endpoint_1: ModelEndpoint,
    update_model_endpoint_request: UpdateModelEndpointV1Request,
):
    """An update whose metadata carries the reserved CONVERTED_FROM_ARTIFACT_LIKE_KEY
    must be rejected with ObjectHasInvalidValueException."""
    # Wire the fakes together so the use case can resolve bundles and endpoints.
    fake_model_bundle_repository.add_model_bundle(model_bundle_2)
    fake_model_endpoint_service.add_model_endpoint(model_endpoint_1)
    fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
    update_use_case = UpdateModelEndpointByIdV1UseCase(
        model_bundle_repository=fake_model_bundle_repository,
        model_endpoint_service=fake_model_endpoint_service,
    )
    owner_id = model_endpoint_1.record.created_by
    caller = User(user_id=owner_id, team_id=owner_id, is_privileged_user=True)

    # The reserved key may not appear in user-supplied metadata, whatever its value.
    bad_request = update_model_endpoint_request.copy()
    bad_request.metadata = {CONVERTED_FROM_ARTIFACT_LIKE_KEY: False}
    with pytest.raises(ObjectHasInvalidValueException):
        await update_use_case.execute(
            user=caller,
            model_endpoint_id=model_endpoint_1.record.id,
            request=bad_request,
        )


@pytest.mark.asyncio
async def test_update_model_endpoint_use_case_raises_resource_request_exception(
    fake_model_bundle_repository,
    fake_model_endpoint_service,
    model_bundle_1: ModelBundle,
    model_bundle_2: ModelBundle,
    model_bundle_4: ModelBundle,
    model_bundle_6: ModelBundle,
    model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage: ModelBundle,
    model_endpoint_1: ModelEndpoint,
    model_endpoint_2: ModelEndpoint,
    update_model_endpoint_request: UpdateModelEndpointV1Request,
):
    """Every invalid resource combination in an update request must raise
    EndpointResourceInvalidRequestException."""
    # Register all bundles/endpoints the cases below refer to, in the same order
    # as they are used.
    for bundle in (
        model_bundle_1,
        model_bundle_2,
        model_bundle_4,
        model_bundle_6,
        model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage,
    ):
        fake_model_bundle_repository.add_model_bundle(bundle)
    fake_model_endpoint_service.add_model_endpoint(model_endpoint_1)
    fake_model_endpoint_service.add_model_endpoint(model_endpoint_2)
    fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
    use_case = UpdateModelEndpointByIdV1UseCase(
        model_bundle_repository=fake_model_bundle_repository,
        model_endpoint_service=fake_model_endpoint_service,
    )
    owner_id = model_endpoint_1.record.created_by
    caller = User(user_id=owner_id, team_id=owner_id, is_privileged_user=True)

    async def assert_rejected(endpoint_id, **overrides):
        # Copy the baseline request, explicitly set the given fields (so pydantic
        # marks them as set), and expect the resource validation to reject it.
        req = update_model_endpoint_request.copy()
        for field, value in overrides.items():
            setattr(req, field, value)
        with pytest.raises(EndpointResourceInvalidRequestException):
            await use_case.execute(user=caller, model_endpoint_id=endpoint_id, request=req)

    ep1 = model_endpoint_1.record.id
    ep2 = model_endpoint_2.record.id

    # Plain out-of-range / malformed resource values.
    await assert_rejected(ep1, cpus=-1)
    await assert_rejected(ep1, cpus=float("inf"))
    await assert_rejected(ep1, memory="invalid_memory_amount")
    await assert_rejected(ep1, memory=float("inf"))
    await assert_rejected(ep1, storage="invalid_storage_amount")
    await assert_rejected(ep1, storage=float("inf"))

    # min_workers == 0 is specifically invalid for the sync endpoint.
    await assert_rejected(ep2, min_workers=0)

    await assert_rejected(ep1, max_workers=2**63)
    await assert_rejected(ep1, gpus=0)
    await assert_rejected(ep1, gpu_type=None)
    await assert_rejected(ep1, gpu_type="invalid_gpu_type")

    instance_limits = REQUESTS_BY_GPU_TYPE[model_endpoint_1.infra_state.resource_state.gpu_type]

    # request.cpus + FORWARDER_CPU_USAGE > instance_limits["cpus"] should fail
    await assert_rejected(ep1, model_bundle_id=model_bundle_1.id, cpus=instance_limits["cpus"])
    # request.memory + FORWARDER_MEMORY_USAGE > instance_limits["memory"] should fail
    await assert_rejected(ep1, model_bundle_id=model_bundle_1.id, memory=instance_limits["memory"])
    # request.storage + FORWARDER_STORAGE_USAGE > STORAGE_LIMIT should fail
    await assert_rejected(ep1, model_bundle_id=model_bundle_1.id, storage=STORAGE_LIMIT)

    # Same three limit checks against a different bundle flavor.
    await assert_rejected(ep1, model_bundle_id=model_bundle_4.id, cpus=instance_limits["cpus"])
    await assert_rejected(ep1, model_bundle_id=model_bundle_4.id, memory=instance_limits["memory"])
    await assert_rejected(ep1, model_bundle_id=model_bundle_4.id, storage=STORAGE_LIMIT)

    # TritonEnhancedRunnableImageFlavor specific validation logic.
    # Triton requires gpus >= 1.
    await assert_rejected(ep1, model_bundle_id=model_bundle_6.id, gpus=0.9)
    # Triton requires gpu_type to be specified.
    await assert_rejected(ep1, model_bundle_id=model_bundle_6.id, gpu_type=None)
    # request.cpus + FORWARDER_CPU_USAGE + triton_num_cpu > instance_limits["cpus"] should fail
    await assert_rejected(
        ep1,
        model_bundle_id=model_bundle_6.id,
        cpus=instance_limits["cpus"] - FORWARDER_CPU_USAGE,
    )
    # request.memory + FORWARDER_MEMORY_USAGE + triton_memory > instance_limits["memory"] should fail
    await assert_rejected(
        ep1,
        model_bundle_id=model_bundle_6.id,
        memory=parse_mem_request(instance_limits["memory"])
        - parse_mem_request(FORWARDER_MEMORY_USAGE),
    )
    # request.storage + FORWARDER_STORAGE_USAGE + triton_storage > STORAGE_LIMIT should fail
    await assert_rejected(
        ep1,
        model_bundle_id=model_bundle_6.id,
        storage=parse_mem_request(STORAGE_LIMIT) - parse_mem_request(FORWARDER_STORAGE_USAGE),
    )
    # Triton also requires triton_num_cpu >= 1 on the bundle itself.
    await assert_rejected(
        ep1,
        model_bundle_id=model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage.id,
    )


@pytest.mark.asyncio
async def test_update_model_endpoint_raises_not_found(
fake_model_bundle_repository,
Expand Down