Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 66 additions & 8 deletions model-engine/model_engine_server/api/llms_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
StreamError,
StreamErrorContent,
TokenUsage,
UpdateLLMModelEndpointV1Request,
UpdateLLMModelEndpointV1Response,
)
from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
from model_engine_server.core.auth.authentication_repository import User
Expand All @@ -54,7 +56,6 @@
LLMFineTuningQuotaReached,
ObjectAlreadyExistsException,
ObjectHasInvalidValueException,
ObjectNotApprovedException,
ObjectNotAuthorizedException,
ObjectNotFoundException,
UpstreamServiceError,
Expand All @@ -70,11 +71,13 @@
from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import (
CompletionStreamV1UseCase,
CompletionSyncV1UseCase,
CreateLLMModelBundleV1UseCase,
CreateLLMModelEndpointV1UseCase,
DeleteLLMEndpointByNameUseCase,
GetLLMModelEndpointByNameV1UseCase,
ListLLMModelEndpointsV1UseCase,
ModelDownloadV1UseCase,
UpdateLLMModelEndpointV1UseCase,
)
from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase
from sse_starlette.sse import EventSourceResponse
Expand Down Expand Up @@ -151,13 +154,16 @@ async def create_model_endpoint(
docker_repository=external_interfaces.docker_repository,
model_primitive_gateway=external_interfaces.model_primitive_gateway,
)
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case = CreateLLMModelBundleV1UseCase(
create_model_bundle_use_case=create_model_bundle_use_case,
model_bundle_repository=external_interfaces.model_bundle_repository,
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
docker_repository=external_interfaces.docker_repository,
)
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
model_endpoint_service=external_interfaces.model_endpoint_service,
)
return await use_case.execute(user=auth, request=request)
except ObjectAlreadyExistsException as exc:
raise HTTPException(
Expand All @@ -176,11 +182,6 @@ async def create_model_endpoint(
status_code=400,
detail=str(exc),
) from exc
except ObjectNotApprovedException as exc:
raise HTTPException(
status_code=403,
detail="The specified model bundle was not approved yet.",
) from exc
except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
raise HTTPException(
status_code=404,
Expand Down Expand Up @@ -234,6 +235,63 @@ async def get_model_endpoint(
) from exc


@llm_router_v1.put(
    "/model-endpoints/{model_endpoint_name}", response_model=UpdateLLMModelEndpointV1Response
)
async def update_model_endpoint(
    model_endpoint_name: str,
    request: UpdateLLMModelEndpointV1Request,
    auth: User = Depends(verify_authentication),
    external_interfaces: ExternalInterfaces = Depends(get_external_interfaces),
) -> UpdateLLMModelEndpointV1Response:
    """
    Updates an LLM endpoint for the current user.

    Wires up the bundle-creation use cases (the endpoint's bundle may need to be
    rebuilt when LLM-specific fields such as the inference framework image tag
    change) and delegates to UpdateLLMModelEndpointV1UseCase.

    Raises:
        HTTPException: 400 for invalid labels, invalid values, or invalid
            resource requests; 404 when the endpoint or docker image is not
            found (or the caller is not authorized to see it).
    """
    logger.info(f"PUT /llm/model-endpoints/{model_endpoint_name} with {request} for {auth}")
    try:
        create_model_bundle_use_case = CreateModelBundleV2UseCase(
            model_bundle_repository=external_interfaces.model_bundle_repository,
            docker_repository=external_interfaces.docker_repository,
            model_primitive_gateway=external_interfaces.model_primitive_gateway,
        )
        create_llm_model_bundle_use_case = CreateLLMModelBundleV1UseCase(
            create_model_bundle_use_case=create_model_bundle_use_case,
            model_bundle_repository=external_interfaces.model_bundle_repository,
            llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
            docker_repository=external_interfaces.docker_repository,
        )
        use_case = UpdateLLMModelEndpointV1UseCase(
            create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
            model_endpoint_service=external_interfaces.model_endpoint_service,
            llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
        )
        return await use_case.execute(
            user=auth, model_endpoint_name=model_endpoint_name, request=request
        )
    except EndpointLabelsException as exc:
        raise HTTPException(
            status_code=400,
            detail=str(exc),
        ) from exc
    except ObjectHasInvalidValueException as exc:
        # Fix: `from exc` preserves the exception chain, consistent with every
        # other handler in this route (the original dropped it here).
        raise HTTPException(
            status_code=400,
            detail=str(exc),
        ) from exc
    except EndpointResourceInvalidRequestException as exc:
        raise HTTPException(
            status_code=400,
            detail=str(exc),
        ) from exc
    except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
        # NotAuthorized is deliberately mapped to 404 (not 403) to avoid
        # leaking endpoint existence to unauthorized callers.
        raise HTTPException(
            status_code=404,
            detail="The specified LLM endpoint could not be found.",
        ) from exc
    except DockerImageNotFoundException as exc:
        raise HTTPException(
            status_code=404,
            detail="The specified docker image could not be found.",
        ) from exc


@llm_router_v1.post("/completions-sync", response_model=CompletionSyncV1Response)
async def create_completion_sync_task(
model_endpoint_name: str,
Expand Down
11 changes: 0 additions & 11 deletions model-engine/model_engine_server/api/model_endpoints_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
ExistingEndpointOperationInProgressException,
ObjectAlreadyExistsException,
ObjectHasInvalidValueException,
ObjectNotApprovedException,
ObjectNotAuthorizedException,
ObjectNotFoundException,
)
Expand Down Expand Up @@ -80,11 +79,6 @@ async def create_model_endpoint(
status_code=400,
detail=str(exc),
) from exc
except ObjectNotApprovedException as exc:
raise HTTPException(
status_code=403,
detail="The specified model bundle was not approved yet.",
) from exc
except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
raise HTTPException(
status_code=404,
Expand Down Expand Up @@ -154,11 +148,6 @@ async def update_model_endpoint(
return await use_case.execute(
user=auth, model_endpoint_id=model_endpoint_id, request=request
)
except ObjectNotApprovedException as exc:
raise HTTPException(
status_code=403,
detail="The specified model bundle was not approved yet.",
) from exc
except EndpointLabelsException as exc:
raise HTTPException(
status_code=400,
Expand Down
48 changes: 47 additions & 1 deletion model-engine/model_engine_server/common/dtos/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,60 @@ class GetLLMModelEndpointV1Response(BaseModel):
inference_framework_image_tag: Optional[str] = None
num_shards: Optional[int] = None
quantize: Optional[Quantization] = None
checkpoint_path: Optional[str] = None
spec: Optional[GetModelEndpointV1Response] = None


class ListLLMModelEndpointsV1Response(BaseModel):
model_endpoints: List[GetLLMModelEndpointV1Response]


# Delete and update use the default Launch endpoint APIs.
class UpdateLLMModelEndpointV1Request(BaseModel):
    """
    Request body for updating an existing LLM model endpoint.

    Every field is optional; unset fields leave the corresponding endpoint
    attribute unchanged. Explicit ``= None`` defaults make the optional
    semantics explicit (pydantic v1 only implies them for bare ``Optional``)
    and keep the model forward-compatible with pydantic v2, where an
    un-defaulted ``Optional`` field is required.
    """

    # LLM specific fields
    model_name: Optional[str] = None
    source: Optional[LLMSource] = None
    inference_framework_image_tag: Optional[str] = None

    # Number of shards to distribute the model onto GPUs.
    # Only affects behavior for text-generation-inference models.
    num_shards: Optional[int] = None

    # Whether to quantize the model.
    # Only affects behavior for text-generation-inference models.
    quantize: Optional[Quantization] = None

    # Path to the checkpoint to load the model from.
    # Only affects behavior for text-generation-inference models.
    checkpoint_path: Optional[str] = None

    # General endpoint fields
    metadata: Optional[Dict[str, Any]] = None
    post_inference_hooks: Optional[List[str]] = None
    cpus: Optional[CpuSpecificationType] = None
    gpus: Optional[int] = None
    memory: Optional[StorageSpecificationType] = None
    gpu_type: Optional[GpuType] = None
    storage: Optional[StorageSpecificationType] = None
    optimize_costs: Optional[bool] = None
    min_workers: Optional[int] = None
    max_workers: Optional[int] = None
    per_worker: Optional[int] = None
    labels: Optional[Dict[str, str]] = None
    prewarm: Optional[bool] = None
    high_priority: Optional[bool] = None
    billing_tags: Optional[Dict[str, Any]] = None
    default_callback_url: Optional[HttpUrl] = None
    default_callback_auth: Optional[CallbackAuth] = None
    public_inference: Optional[bool] = None


class UpdateLLMModelEndpointV1Response(BaseModel):
    """Response for an LLM endpoint update; the update runs asynchronously."""

    # ID of the async task performing the endpoint (re)creation. Kept as
    # `endpoint_creation_task_id` (not `endpoint_update_task_id`) for
    # consistency with the non-LLM UpdateModelEndpointV1Response.
    endpoint_creation_task_id: str
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we call this endpoint_update_task_id?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The non-LLM specific UpdateModelEndpointV1Response has endpoint_creation_task_id... do we want to keep these consistent? 🤔



# Delete uses the default Launch endpoint APIs.


class CompletionSyncV1Request(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ class LLMMetadata:
inference_framework_image_tag: str
num_shards: int
quantize: Optional[Quantization] = None
checkpoint_path: Optional[str] = None
6 changes: 0 additions & 6 deletions model-engine/model_engine_server/domain/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ class ObjectHasInvalidValueException(DomainException, ValueError):
"""


class ObjectNotApprovedException(DomainException):
"""
Thrown when a required object is not approved, e.g. for a Bundle in review.
"""


@dataclass
class DockerImageNotFoundException(DomainException):
"""
Expand Down
Loading