Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions model-engine/model_engine_server/api/llms_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
make_logger,
)
from model_engine_server.domain.exceptions import (
DockerImageNotFoundException,
EndpointDeleteFailedException,
EndpointLabelsException,
EndpointResourceInvalidRequestException,
Expand Down Expand Up @@ -144,6 +145,7 @@ async def create_model_endpoint(
model_bundle_repository=external_interfaces.model_bundle_repository,
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
docker_repository=external_interfaces.docker_repository,
)
return await use_case.execute(user=auth, request=request)
except ObjectAlreadyExistsException as exc:
Expand Down Expand Up @@ -173,6 +175,11 @@ async def create_model_endpoint(
status_code=404,
detail="The specified model bundle could not be found.",
) from exc
except DockerImageNotFoundException as exc:
raise HTTPException(
status_code=404,
detail="The specified docker image could not be found.",
) from exc


@llm_router_v1.get("/model-endpoints", response_model=ListLLMModelEndpointsV1Response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
StreamingEnhancedRunnableImageFlavor,
)
from model_engine_server.domain.exceptions import (
DockerImageNotFoundException,
EndpointLabelsException,
EndpointUnsupportedInferenceTypeException,
InvalidRequestException,
Expand All @@ -57,6 +58,7 @@
)
from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway
from model_engine_server.domain.repositories import ModelBundleRepository
from model_engine_server.domain.repositories.docker_repository import DockerRepository
from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService
from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway

Expand Down Expand Up @@ -254,12 +256,26 @@ def __init__(
model_bundle_repository: ModelBundleRepository,
model_endpoint_service: ModelEndpointService,
llm_artifact_gateway: LLMArtifactGateway,
docker_repository: DockerRepository,
):
self.authz_module = LiveAuthorizationModule()
self.create_model_bundle_use_case = create_model_bundle_use_case
self.model_bundle_repository = model_bundle_repository
self.model_endpoint_service = model_endpoint_service
self.llm_artifact_gateway = llm_artifact_gateway
self.docker_repository = docker_repository

def check_docker_image_exists_for_image_tag(
    self, framework_image_tag: str, repository_name: str
):
    """Validate that ``framework_image_tag`` exists in ``repository_name``.

    Queries the docker repository gateway; raises
    ``DockerImageNotFoundException`` when the tag is absent, otherwise
    returns ``None``.
    """
    image_found = self.docker_repository.image_exists(
        image_tag=framework_image_tag,
        repository_name=repository_name,
    )
    if image_found:
        return
    raise DockerImageNotFoundException(
        repository=repository_name,
        tag=framework_image_tag,
    )

async def create_model_bundle(
self,
Expand All @@ -276,6 +292,7 @@ async def create_model_bundle(
) -> ModelBundle:
if source == LLMSource.HUGGING_FACE:
if framework == LLMInferenceFramework.DEEPSPEED:
self.check_docker_image_exists_for_image_tag(framework_image_tag, "instant-llm")
bundle_id = await self.create_deepspeed_bundle(
user,
model_name,
Expand All @@ -284,6 +301,9 @@ async def create_model_bundle(
endpoint_name,
)
elif framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE:
self.check_docker_image_exists_for_image_tag(
framework_image_tag, hmi_config.tgi_repository
)
bundle_id = await self.create_text_generation_inference_bundle(
user,
model_name,
Expand All @@ -294,6 +314,9 @@ async def create_model_bundle(
checkpoint_path,
)
elif framework == LLMInferenceFramework.VLLM:
self.check_docker_image_exists_for_image_tag(
framework_image_tag, hmi_config.vllm_repository
)
bundle_id = await self.create_vllm_bundle(
user,
model_name,
Expand All @@ -304,6 +327,9 @@ async def create_model_bundle(
checkpoint_path,
)
elif framework == LLMInferenceFramework.LIGHTLLM:
self.check_docker_image_exists_for_image_tag(
framework_image_tag, hmi_config.lightllm_repository
)
bundle_id = await self.create_lightllm_bundle(
user,
model_name,
Expand Down Expand Up @@ -713,7 +739,6 @@ async def execute(
if request.inference_framework in [
LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
LLMInferenceFramework.VLLM,
LLMInferenceFramework.LIGHTLLM,
]:
if request.endpoint_type != ModelEndpointType.STREAMING:
raise ObjectHasInvalidValueException(
Expand Down Expand Up @@ -952,10 +977,7 @@ def validate_and_update_completion_params(
if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE:
request.top_k = None if request.top_k == -1 else request.top_k
request.top_p = None if request.top_p == 1.0 else request.top_p
if inference_framework in [
LLMInferenceFramework.VLLM,
LLMInferenceFramework.LIGHTLLM,
]:
if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
request.top_k = -1 if request.top_k is None else request.top_k
request.top_p = 1.0 if request.top_p is None else request.top_p
else:
Expand All @@ -965,10 +987,7 @@ def validate_and_update_completion_params(
)

# presence_penalty, frequency_penalty
if inference_framework in [
LLMInferenceFramework.VLLM,
LLMInferenceFramework.LIGHTLLM,
]:
if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
request.presence_penalty = (
0.0 if request.presence_penalty is None else request.presence_penalty
)
Expand Down Expand Up @@ -1005,7 +1024,6 @@ def model_output_to_completion_output(
with_token_probs: Optional[bool],
) -> CompletionOutput:
model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)

if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
completion_token_count = len(model_output["token_probs"]["tokens"])
tokens = None
Expand Down Expand Up @@ -1043,10 +1061,7 @@ def model_output_to_completion_output(
tokens = None
if with_token_probs:
tokens = [
TokenOutput(
token=model_output["tokens"][index],
log_prob=list(t.values())[0],
)
TokenOutput(token=model_output["tokens"][index], log_prob=list(t.values())[0])
Comment thread
tiffzhao5 marked this conversation as resolved.
for index, t in enumerate(model_output["log_probs"])
]
return CompletionOutput(
Expand Down Expand Up @@ -1160,8 +1175,7 @@ async def execute(
timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
)
predict_result = await inference_gateway.predict(
topic=model_endpoint.record.destination,
predict_request=inference_request,
topic=model_endpoint.record.destination, predict_request=inference_request
)

if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None:
Expand Down Expand Up @@ -1204,8 +1218,7 @@ async def execute(
timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
)
predict_result = await inference_gateway.predict(
topic=model_endpoint.record.destination,
predict_request=inference_request,
topic=model_endpoint.record.destination, predict_request=inference_request
)

if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
Expand Down Expand Up @@ -1244,8 +1257,7 @@ async def execute(
timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
)
predict_result = await inference_gateway.predict(
topic=model_endpoint.record.destination,
predict_request=inference_request,
topic=model_endpoint.record.destination, predict_request=inference_request
)

if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
Expand Down Expand Up @@ -1287,8 +1299,7 @@ async def execute(
timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS,
)
predict_result = await inference_gateway.predict(
topic=model_endpoint.record.destination,
predict_request=inference_request,
topic=model_endpoint.record.destination, predict_request=inference_request
)

if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
Expand Down
8 changes: 4 additions & 4 deletions model-engine/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3690,7 +3690,7 @@ def llm_model_endpoint_sync_tgi(
"model_name": "llama-7b",
"source": "hugging_face",
"inference_framework": "text_generation_inference",
"inference_framework_image_tag": "123",
"inference_framework_image_tag": "0.9.4",
"num_shards": 4,
}
},
Expand Down Expand Up @@ -3752,7 +3752,7 @@ def llm_model_endpoint_sync_tgi(
"source": "hugging_face",
"status": "READY",
"inference_framework": "text_generation_inference",
"inference_framework_image_tag": "123",
"inference_framework_image_tag": "0.9.4",
"num_shards": 4,
"spec": {
"id": "test_llm_model_endpoint_id_2",
Expand All @@ -3765,7 +3765,7 @@ def llm_model_endpoint_sync_tgi(
"model_name": "llama-7b",
"source": "hugging_face",
"inference_framework": "text_generation_inference",
"inference_framework_image_tag": "123",
"inference_framework_image_tag": "0.9.4",
"num_shards": 4,
}
},
Expand Down Expand Up @@ -3887,7 +3887,7 @@ def llm_model_endpoint_text_generation_inference(
"model_name": "llama-7b",
"source": "hugging_face",
"inference_framework": "text_generation_inference",
"inference_framework_image_tag": "123",
"inference_framework_image_tag": "0.9.4",
"num_shards": 4,
}
},
Expand Down
4 changes: 2 additions & 2 deletions model-engine/tests/unit/domain/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Reque
model_name="llama-2-7b",
source="hugging_face",
inference_framework="text_generation_inference",
inference_framework_image_tag="test_tag",
inference_framework_image_tag="0.9.4",
num_shards=2,
endpoint_type=ModelEndpointType.STREAMING,
metadata={},
Expand Down Expand Up @@ -310,7 +310,7 @@ def create_llm_model_endpoint_text_generation_inference_request_async() -> (
model_name="mpt-7b",
source="hugging_face",
inference_framework="text_generation_inference",
inference_framework_image_tag="test_tag",
inference_framework_image_tag="0.9.4",
num_shards=2,
quantize=Quantization.BITSANDBYTES,
endpoint_type=ModelEndpointType.ASYNC,
Expand Down
61 changes: 60 additions & 1 deletion model-engine/tests/unit/domain/test_llm_use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
)
from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus
from model_engine_server.core.auth.authentication_repository import User
from model_engine_server.domain.entities import ModelEndpoint, ModelEndpointType
from model_engine_server.domain.entities import (
LLMInferenceFramework,
ModelEndpoint,
ModelEndpointType,
)
from model_engine_server.domain.exceptions import (
DockerImageNotFoundException,
EndpointUnsupportedInferenceTypeException,
InvalidRequestException,
LLMFineTuningQuotaReached,
Expand Down Expand Up @@ -66,6 +71,7 @@ async def test_create_model_endpoint_use_case_success(
model_bundle_repository=fake_model_bundle_repository,
model_endpoint_service=fake_model_endpoint_service,
llm_artifact_gateway=fake_llm_artifact_gateway,
docker_repository=fake_docker_repository_image_always_exists,
)
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
response_1 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_async)
Expand Down Expand Up @@ -147,6 +153,56 @@ async def test_create_model_endpoint_use_case_success(
assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1]


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "valid, inference_framework, inference_framework_image_tag",
    [
        (False, LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "0.9.2"),
        (True, LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "0.9.3"),
        (False, LLMInferenceFramework.VLLM, "0.1.6"),
        (True, LLMInferenceFramework.VLLM, "0.1.3.6"),
    ],
)
async def test_create_model_bundle_inference_framework_image_tag_validation(
    test_api_key: str,
    fake_model_bundle_repository,
    fake_model_endpoint_service,
    fake_docker_repository_image_always_exists,
    fake_docker_repository_image_never_exists,
    fake_model_primitive_gateway,
    fake_llm_artifact_gateway,
    create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request,
    valid,
    inference_framework,
    inference_framework_image_tag,
):
    """The image tag of each inference framework is checked against the docker repo.

    A tag whose image is missing must surface as DockerImageNotFoundException;
    a tag whose image exists must let endpoint creation proceed normally.
    """
    fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository

    endpoint_use_case = CreateLLMModelEndpointV1UseCase(
        create_model_bundle_use_case=CreateModelBundleV2UseCase(
            model_bundle_repository=fake_model_bundle_repository,
            docker_repository=fake_docker_repository_image_always_exists,
            model_primitive_gateway=fake_model_primitive_gateway,
        ),
        model_bundle_repository=fake_model_bundle_repository,
        model_endpoint_service=fake_model_endpoint_service,
        llm_artifact_gateway=fake_llm_artifact_gateway,
        docker_repository=fake_docker_repository_image_always_exists,
    )

    endpoint_request = (
        create_llm_model_endpoint_text_generation_inference_request_streaming.copy()
    )
    endpoint_request.inference_framework = inference_framework
    endpoint_request.inference_framework_image_tag = inference_framework_image_tag
    caller = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)

    if not valid:
        # Swap in a repository that reports every image as missing.
        endpoint_use_case.docker_repository = fake_docker_repository_image_never_exists
        with pytest.raises(DockerImageNotFoundException):
            await endpoint_use_case.execute(user=caller, request=endpoint_request)
    else:
        await endpoint_use_case.execute(user=caller, request=endpoint_request)


@pytest.mark.asyncio
async def test_create_model_endpoint_text_generation_inference_use_case_success(
test_api_key: str,
Expand All @@ -169,6 +225,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success(
model_bundle_repository=fake_model_bundle_repository,
model_endpoint_service=fake_model_endpoint_service,
llm_artifact_gateway=fake_llm_artifact_gateway,
docker_repository=fake_docker_repository_image_always_exists,
)
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
response_1 = await use_case.execute(
Expand Down Expand Up @@ -224,6 +281,7 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception
model_bundle_repository=fake_model_bundle_repository,
model_endpoint_service=fake_model_endpoint_service,
llm_artifact_gateway=fake_llm_artifact_gateway,
docker_repository=fake_docker_repository_image_always_exists,
)
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
with pytest.raises(ObjectHasInvalidValueException):
Expand Down Expand Up @@ -253,6 +311,7 @@ async def test_create_llm_model_endpoint_use_case_quantization_exception(
model_bundle_repository=fake_model_bundle_repository,
model_endpoint_service=fake_model_endpoint_service,
llm_artifact_gateway=fake_llm_artifact_gateway,
docker_repository=fake_docker_repository_image_always_exists,
)
user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
with pytest.raises(ObjectHasInvalidValueException):
Expand Down