diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 6fc2a1d15..79961e2d0 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -33,6 +33,8 @@ StreamError, StreamErrorContent, TokenUsage, + UpdateLLMModelEndpointV1Request, + UpdateLLMModelEndpointV1Response, ) from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.core.auth.authentication_repository import User @@ -54,7 +56,6 @@ LLMFineTuningQuotaReached, ObjectAlreadyExistsException, ObjectHasInvalidValueException, - ObjectNotApprovedException, ObjectNotAuthorizedException, ObjectNotFoundException, UpstreamServiceError, @@ -70,11 +71,13 @@ from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( CompletionStreamV1UseCase, CompletionSyncV1UseCase, + CreateLLMModelBundleV1UseCase, CreateLLMModelEndpointV1UseCase, DeleteLLMEndpointByNameUseCase, GetLLMModelEndpointByNameV1UseCase, ListLLMModelEndpointsV1UseCase, ModelDownloadV1UseCase, + UpdateLLMModelEndpointV1UseCase, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase from sse_starlette.sse import EventSourceResponse @@ -151,13 +154,16 @@ async def create_model_endpoint( docker_repository=external_interfaces.docker_repository, model_primitive_gateway=external_interfaces.model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=create_model_bundle_use_case, model_bundle_repository=external_interfaces.model_bundle_repository, - model_endpoint_service=external_interfaces.model_endpoint_service, llm_artifact_gateway=external_interfaces.llm_artifact_gateway, docker_repository=external_interfaces.docker_repository, ) + use_case = CreateLLMModelEndpointV1UseCase( + 
create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, + model_endpoint_service=external_interfaces.model_endpoint_service, + ) return await use_case.execute(user=auth, request=request) except ObjectAlreadyExistsException as exc: raise HTTPException( @@ -176,11 +182,6 @@ async def create_model_endpoint( status_code=400, detail=str(exc), ) from exc - except ObjectNotApprovedException as exc: - raise HTTPException( - status_code=403, - detail="The specified model bundle was not approved yet.", - ) from exc except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: raise HTTPException( status_code=404, @@ -234,6 +235,63 @@ async def get_model_endpoint( ) from exc +@llm_router_v1.put( + "/model-endpoints/{model_endpoint_name}", response_model=UpdateLLMModelEndpointV1Response +) +async def update_model_endpoint( + model_endpoint_name: str, + request: UpdateLLMModelEndpointV1Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> UpdateLLMModelEndpointV1Response: + """ + Updates an LLM endpoint for the current user. 
+ """ + logger.info(f"PUT /llm/model-endpoints/{model_endpoint_name} with {request} for {auth}") + try: + create_model_bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=external_interfaces.model_bundle_repository, + docker_repository=external_interfaces.docker_repository, + model_primitive_gateway=external_interfaces.model_primitive_gateway, + ) + create_llm_model_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=create_model_bundle_use_case, + model_bundle_repository=external_interfaces.model_bundle_repository, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + docker_repository=external_interfaces.docker_repository, + ) + use_case = UpdateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + ) + return await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + except EndpointLabelsException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) + except EndpointResourceInvalidRequestException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified LLM endpoint could not be found.", + ) from exc + except DockerImageNotFoundException as exc: + raise HTTPException( + status_code=404, + detail="The specified docker image could not be found.", + ) from exc + + @llm_router_v1.post("/completions-sync", response_model=CompletionSyncV1Response) async def create_completion_sync_task( model_endpoint_name: str, diff --git a/model-engine/model_engine_server/api/model_endpoints_v1.py 
b/model-engine/model_engine_server/api/model_endpoints_v1.py index 3b45f0714..807393cdc 100644 --- a/model-engine/model_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -31,7 +31,6 @@ ExistingEndpointOperationInProgressException, ObjectAlreadyExistsException, ObjectHasInvalidValueException, - ObjectNotApprovedException, ObjectNotAuthorizedException, ObjectNotFoundException, ) @@ -80,11 +79,6 @@ async def create_model_endpoint( status_code=400, detail=str(exc), ) from exc - except ObjectNotApprovedException as exc: - raise HTTPException( - status_code=403, - detail="The specified model bundle was not approved yet.", - ) from exc except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: raise HTTPException( status_code=404, @@ -154,11 +148,6 @@ async def update_model_endpoint( return await use_case.execute( user=auth, model_endpoint_id=model_endpoint_id, request=request ) - except ObjectNotApprovedException as exc: - raise HTTPException( - status_code=403, - detail="The specified model bundle was not approved yet.", - ) from exc except EndpointLabelsException as exc: raise HTTPException( status_code=400, diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index dd2e06a04..346c9ae29 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -87,6 +87,7 @@ class GetLLMModelEndpointV1Response(BaseModel): inference_framework_image_tag: Optional[str] = None num_shards: Optional[int] = None quantize: Optional[Quantization] = None + checkpoint_path: Optional[str] = None spec: Optional[GetModelEndpointV1Response] = None @@ -94,7 +95,52 @@ class ListLLMModelEndpointsV1Response(BaseModel): model_endpoints: List[GetLLMModelEndpointV1Response] -# Delete and update use the default Launch endpoint APIs. 
+class UpdateLLMModelEndpointV1Request(BaseModel): + # LLM specific fields + model_name: Optional[str] + source: Optional[LLMSource] + inference_framework_image_tag: Optional[str] + num_shards: Optional[int] + """ + Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models + """ + + quantize: Optional[Quantization] + """ + Whether to quantize the model. Only affects behavior for text-generation-inference models + """ + + checkpoint_path: Optional[str] + """ + Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models + """ + + # General endpoint fields + metadata: Optional[Dict[str, Any]] + post_inference_hooks: Optional[List[str]] + cpus: Optional[CpuSpecificationType] + gpus: Optional[int] + memory: Optional[StorageSpecificationType] + gpu_type: Optional[GpuType] + storage: Optional[StorageSpecificationType] + optimize_costs: Optional[bool] + min_workers: Optional[int] + max_workers: Optional[int] + per_worker: Optional[int] + labels: Optional[Dict[str, str]] + prewarm: Optional[bool] + high_priority: Optional[bool] + billing_tags: Optional[Dict[str, Any]] + default_callback_url: Optional[HttpUrl] + default_callback_auth: Optional[CallbackAuth] + public_inference: Optional[bool] + + +class UpdateLLMModelEndpointV1Response(BaseModel): + endpoint_creation_task_id: str + + +# Delete uses the default Launch endpoint APIs. 
class CompletionSyncV1Request(BaseModel): diff --git a/model-engine/model_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py index 30ec89933..4da8c2787 100644 --- a/model-engine/model_engine_server/domain/entities/llm_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_entity.py @@ -28,3 +28,4 @@ class LLMMetadata: inference_framework_image_tag: str num_shards: int quantize: Optional[Quantization] = None + checkpoint_path: Optional[str] = None diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index 934a5e215..b78bb281c 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -31,12 +31,6 @@ class ObjectHasInvalidValueException(DomainException, ValueError): """ -class ObjectNotApprovedException(DomainException): - """ - Thrown when a required object is not approved, e.g. for a Bundle in review. 
- """ - - @dataclass class DockerImageNotFoundException(DomainException): """ diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index d379aac20..af14bf06a 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -27,6 +27,8 @@ ModelDownloadRequest, ModelDownloadResponse, TokenOutput, + UpdateLLMModelEndpointV1Request, + UpdateLLMModelEndpointV1Response, ) from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy @@ -48,6 +50,7 @@ ) from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, + EndpointInfraStateNotFound, EndpointLabelsException, EndpointUnsupportedInferenceTypeException, InvalidRequestException, @@ -70,6 +73,7 @@ from ..authorization.live_authorization_module import LiveAuthorizationModule from .model_bundle_use_cases import CreateModelBundleV2UseCase from .model_endpoint_use_cases import ( + CONVERTED_FROM_ARTIFACT_LIKE_KEY, _handle_post_inference_hooks, model_endpoint_entity_to_get_model_endpoint_response, validate_billing_tags, @@ -237,6 +241,7 @@ def _model_endpoint_entity_to_get_llm_model_endpoint_response( inference_framework_image_tag=llm_metadata["inference_framework_image_tag"], num_shards=llm_metadata["num_shards"], quantize=llm_metadata.get("quantize"), + checkpoint_path=llm_metadata.get("checkpoint_path"), spec=model_endpoint_entity_to_get_model_endpoint_response(model_endpoint), ) return response @@ -274,19 +279,17 @@ def validate_quantization( ) -class CreateLLMModelEndpointV1UseCase: +class CreateLLMModelBundleV1UseCase: def __init__( self, create_model_bundle_use_case: CreateModelBundleV2UseCase, model_bundle_repository: ModelBundleRepository, - 
model_endpoint_service: ModelEndpointService, llm_artifact_gateway: LLMArtifactGateway, docker_repository: DockerRepository, ): self.authz_module = LiveAuthorizationModule() self.create_model_bundle_use_case = create_model_bundle_use_case self.model_bundle_repository = model_bundle_repository - self.model_endpoint_service = model_endpoint_service self.llm_artifact_gateway = llm_artifact_gateway self.docker_repository = docker_repository @@ -302,7 +305,7 @@ def check_docker_image_exists_for_image_tag( tag=framework_image_tag, ) - async def create_model_bundle( + async def execute( self, user: User, endpoint_name: str, @@ -840,6 +843,17 @@ async def create_tensorrt_llm_bundle( ) ).model_bundle_id + +class CreateLLMModelEndpointV1UseCase: + def __init__( + self, + create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, + model_endpoint_service: ModelEndpointService, + ): + self.authz_module = LiveAuthorizationModule() + self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case + self.model_endpoint_service = model_endpoint_service + async def execute( self, user: User, request: CreateLLMModelEndpointV1Request ) -> CreateLLMModelEndpointV1Response: @@ -865,10 +879,10 @@ async def execute( ]: if request.endpoint_type != ModelEndpointType.STREAMING: raise ObjectHasInvalidValueException( - f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference, vLLM and LightLLM." + f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference, vLLM, LightLLM, and TensorRT-LLM." 
) - bundle = await self.create_model_bundle( + bundle = await self.create_llm_model_bundle_use_case.execute( user, endpoint_name=request.name, model_name=request.model_name, @@ -908,6 +922,7 @@ async def execute( inference_framework_image_tag=request.inference_framework_image_tag, num_shards=request.num_shards, quantize=request.quantize, + checkpoint_path=request.checkpoint_path, ) ) @@ -1025,6 +1040,158 @@ async def execute(self, user: User, model_endpoint_name: str) -> GetLLMModelEndp return _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) +class UpdateLLMModelEndpointV1UseCase: + def __init__( + self, + create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + ): + self.authz_module = LiveAuthorizationModule() + self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + + async def execute( + self, user: User, model_endpoint_name: str, request: UpdateLLMModelEndpointV1Request + ) -> UpdateLLMModelEndpointV1Response: + if request.labels is not None: + validate_labels(request.labels) + validate_billing_tags(request.billing_tags) + validate_post_inference_hooks(user, request.post_inference_hooks) + + model_endpoint = await self.llm_model_endpoint_service.get_llm_model_endpoint( + model_endpoint_name + ) + if not model_endpoint: + raise ObjectNotFoundException + if not self.authz_module.check_access_write_owned_entity(user, model_endpoint.record): + raise ObjectNotAuthorizedException + + endpoint_record = model_endpoint.record + model_endpoint_id = endpoint_record.id + bundle = endpoint_record.current_model_bundle + + # TODO: We may want to consider what happens if an endpoint gets stuck in UPDATE_PENDING + # on first creating it, and we need to find a way to get it unstuck. 
This would end up + # causing endpoint.infra_state to be None. + if model_endpoint.infra_state is None: + error_msg = f"Endpoint infra state not found for {model_endpoint_name=}" + logger.error(error_msg) + raise EndpointInfraStateNotFound(error_msg) + + infra_state = model_endpoint.infra_state + + if ( + request.model_name + or request.source + or request.inference_framework_image_tag + or request.num_shards + or request.quantize + or request.checkpoint_path + ): + llm_metadata = (model_endpoint.record.metadata or {}).get("_llm", {}) + inference_framework = llm_metadata["inference_framework"] + + model_name = request.model_name or llm_metadata["model_name"] + source = request.source or llm_metadata["source"] + inference_framework_image_tag = ( + request.inference_framework_image_tag + or llm_metadata["inference_framework_image_tag"] + ) + num_shards = request.num_shards or llm_metadata["num_shards"] + quantize = request.quantize or llm_metadata.get("quantize") + checkpoint_path = request.checkpoint_path or llm_metadata.get("checkpoint_path") + + validate_model_name(model_name, inference_framework) + validate_num_shards( + num_shards, inference_framework, request.gpus or infra_state.resource_state.gpus + ) + validate_quantization(quantize, inference_framework) + + bundle = await self.create_llm_model_bundle_use_case.execute( + user, + endpoint_name=model_endpoint_name, + model_name=model_name, + source=source, + framework=inference_framework, + framework_image_tag=inference_framework_image_tag, + endpoint_type=endpoint_record.endpoint_type, + num_shards=num_shards, + quantize=quantize, + checkpoint_path=checkpoint_path, + ) + + metadata = endpoint_record.metadata or {} + metadata["_llm"] = asdict( + LLMMetadata( + model_name=model_name, + source=source, + inference_framework=inference_framework, + inference_framework_image_tag=inference_framework_image_tag, + num_shards=num_shards, + quantize=quantize, + checkpoint_path=checkpoint_path, + ) + ) + request.metadata = 
metadata + + # For resources that are not specified in the update endpoint request, pass in resource from + # infra_state to make sure that after the update, all resources are valid and in sync. + # E.g. If user only want to update gpus and leave gpu_type as None, we use the existing gpu_type + # from infra_state to avoid passing in None to validate_resource_requests. + validate_resource_requests( + bundle=bundle, + cpus=request.cpus or infra_state.resource_state.cpus, + memory=request.memory or infra_state.resource_state.memory, + storage=request.storage or infra_state.resource_state.storage, + gpus=request.gpus or infra_state.resource_state.gpus, + gpu_type=request.gpu_type or infra_state.resource_state.gpu_type, + ) + + validate_deployment_resources( + min_workers=request.min_workers, + max_workers=request.max_workers, + endpoint_type=endpoint_record.endpoint_type, + ) + + if request.metadata is not None and CONVERTED_FROM_ARTIFACT_LIKE_KEY in request.metadata: + raise ObjectHasInvalidValueException( + f"{CONVERTED_FROM_ARTIFACT_LIKE_KEY} is a reserved metadata key and cannot be used by user." 
+ ) + + updated_endpoint_record = await self.model_endpoint_service.update_model_endpoint( + model_endpoint_id=model_endpoint_id, + model_bundle_id=bundle.id, + metadata=request.metadata, + post_inference_hooks=request.post_inference_hooks, + cpus=request.cpus, + gpus=request.gpus, + memory=request.memory, + gpu_type=request.gpu_type, + storage=request.storage, + optimize_costs=request.optimize_costs, + min_workers=request.min_workers, + max_workers=request.max_workers, + per_worker=request.per_worker, + labels=request.labels, + prewarm=request.prewarm, + high_priority=request.high_priority, + default_callback_url=request.default_callback_url, + default_callback_auth=request.default_callback_auth, + public_inference=request.public_inference, + ) + _handle_post_inference_hooks( + created_by=endpoint_record.created_by, + name=updated_endpoint_record.name, + post_inference_hooks=request.post_inference_hooks, + ) + + return UpdateLLMModelEndpointV1Response( + endpoint_creation_task_id=updated_endpoint_record.creation_task_id # type: ignore + ) + + class DeleteLLMEndpointByNameUseCase: """ Use case for deleting an LLM Model Endpoint of a given user by endpoint name. 
diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 06310666e..f433071cc 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -7,6 +7,7 @@ CompletionStreamV1Request, CompletionSyncV1Request, CreateLLMModelEndpointV1Request, + UpdateLLMModelEndpointV1Request, ) from model_engine_server.common.dtos.model_bundles import ( CreateModelBundleV1Request, @@ -218,6 +219,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request labels={"team": "infra", "product": "my_product"}, aws_role="test_aws_role", results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://test_checkpoint_path", ) @@ -247,6 +249,16 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req ) +@pytest.fixture +def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request: + return UpdateLLMModelEndpointV1Request( + checkpoint_path="s3://test_checkpoint_path", + memory="4G", + min_workers=0, + max_workers=1, + ) + + @pytest.fixture def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 589e453b5..c4fbb31f6 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -11,6 +11,7 @@ CreateLLMModelEndpointV1Response, ModelDownloadRequest, TokenOutput, + UpdateLLMModelEndpointV1Request, ) from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus from model_engine_server.core.auth.authentication_repository import User @@ -38,10 +39,12 @@ from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( CompletionStreamV1UseCase, CompletionSyncV1UseCase, + CreateLLMModelBundleV1UseCase, CreateLLMModelEndpointV1UseCase, 
DeleteLLMEndpointByNameUseCase, GetLLMModelEndpointByNameV1UseCase, ModelDownloadV1UseCase, + UpdateLLMModelEndpointV1UseCase, _include_safetensors_bin_or_pt, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -66,13 +69,17 @@ async def test_create_model_endpoint_use_case_success( docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_async) assert response_1.endpoint_creation_task_id @@ -93,6 +100,7 @@ async def test_create_model_endpoint_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_request_async.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_request_async.num_shards, "quantize": None, + "checkpoint_path": create_llm_model_endpoint_request_async.checkpoint_path, } } @@ -115,6 +123,7 @@ async def test_create_model_endpoint_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_request_sync.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_request_sync.num_shards, "quantize": None, + "checkpoint_path": None, } } @@ -139,6 +148,7 @@ async def test_create_model_endpoint_use_case_success( "inference_framework_image_tag": 
create_llm_model_endpoint_request_streaming.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_request_streaming.num_shards, "quantize": None, + "checkpoint_path": None, } } @@ -182,14 +192,16 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() request.inference_framework = inference_framework @@ -198,7 +210,7 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( if valid: await use_case.execute(user=user, request=request) else: - use_case.docker_repository = fake_docker_repository_image_never_exists + llm_bundle_use_case.docker_repository = fake_docker_repository_image_never_exists with pytest.raises(DockerImageNotFoundException): await use_case.execute(user=user, request=request) @@ -220,13 +232,16 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, 
llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( user=user, @@ -250,6 +265,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_text_generation_inference_request_streaming.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_text_generation_inference_request_streaming.num_shards, "quantize": create_llm_model_endpoint_text_generation_inference_request_streaming.quantize, + "checkpoint_path": create_llm_model_endpoint_text_generation_inference_request_streaming.checkpoint_path, } } @@ -277,13 +293,16 @@ async def test_create_model_endpoint_trt_llm_use_case_success( docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( user=user, @@ -307,6 +326,7 @@ async def test_create_model_endpoint_trt_llm_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_trt_llm_request_streaming.inference_framework_image_tag, 
"num_shards": create_llm_model_endpoint_trt_llm_request_streaming.num_shards, "quantize": create_llm_model_endpoint_trt_llm_request_streaming.quantize, + "checkpoint_path": create_llm_model_endpoint_trt_llm_request_streaming.checkpoint_path, } } @@ -333,13 +353,16 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): await use_case.execute( @@ -363,13 +386,16 @@ async def test_create_llm_model_endpoint_use_case_quantization_exception( docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with 
pytest.raises(ObjectHasInvalidValueException): await use_case.execute( @@ -410,6 +436,88 @@ async def test_get_llm_model_endpoint_use_case_raises_not_authorized( ) +@pytest.mark.asyncio +async def test_update_model_endpoint_use_case_success( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + fake_llm_model_endpoint_service, + create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, + update_llm_model_endpoint_request: UpdateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + create_use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) + update_use_case = UpdateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + + await create_use_case.execute(user=user, request=create_llm_model_endpoint_request_streaming) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + 
fake_llm_model_endpoint_service.add_model_endpoint(endpoint) + + update_response = await update_use_case.execute( + user=user, + model_endpoint_name=create_llm_model_endpoint_request_streaming.name, + request=update_llm_model_endpoint_request, + ) + assert update_response.endpoint_creation_task_id + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_streaming.model_name, + "source": create_llm_model_endpoint_request_streaming.source, + "inference_framework": create_llm_model_endpoint_request_streaming.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_request_streaming.num_shards, + "quantize": None, + "checkpoint_path": update_llm_model_endpoint_request.checkpoint_path, + } + } + assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory + assert ( + endpoint.infra_state.deployment_state.min_workers + == update_llm_model_endpoint_request.min_workers + ) + assert ( + endpoint.infra_state.deployment_state.max_workers + == update_llm_model_endpoint_request.max_workers + ) + + def mocked_auto_tokenizer_from_pretrained(*args, **kwargs): # noqa class mocked_encode: def encode(self, input: str) -> List[Any]: # noqa