diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index 3df2ea81b..d13af3584 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -270,8 +270,6 @@ env: value: {{ .Values.azure.abs_account_name }} - name: SERVICEBUS_NAMESPACE value: {{ .Values.azure.servicebus_namespace }} - - name: SERVICEBUS_SAS_KEY - value: {{ .Values.azure.servicebus_sas_key }} {{- end }} {{- end }} diff --git a/model-engine/Dockerfile b/model-engine/Dockerfile index 809395592..23eacd9c7 100644 --- a/model-engine/Dockerfile +++ b/model-engine/Dockerfile @@ -1,6 +1,6 @@ # syntax = docker/dockerfile:experimental -FROM python:3.8.8-slim as model-engine +FROM python:3.8.18-slim as model-engine WORKDIR /workspace RUN apt-get update && apt-get install -y \ diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index b65f11898..713938d11 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -1,5 +1,6 @@ import asyncio import os +import time from dataclasses import dataclass from typing import Callable, Optional @@ -442,6 +443,7 @@ async def verify_authentication( def get_or_create_aioredis_pool() -> aioredis.ConnectionPool: global _pool - if _pool is None: + expiration_timestamp = hmi_config.cache_redis_url_expiration_timestamp + if _pool is None or (expiration_timestamp is not None and time.time() > expiration_timestamp): _pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) return _pool diff --git a/model-engine/model_engine_server/api/files_v1.py b/model-engine/model_engine_server/api/files_v1.py index 556566d5e..d3c093f02 100644 --- a/model-engine/model_engine_server/api/files_v1.py +++ b/model-engine/model_engine_server/api/files_v1.py @@ -44,7 +44,7 @@ async def upload_file( ) return await use_case.execute( user=auth, - filename=file.filename, + filename=file.filename or "", content=file.file.read(), ) diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 6c7088fc2..ac92cf435 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -30,6 +30,8 @@ SERVICE_CONFIG_PATH = os.environ.get("DEPLOY_SERVICE_CONFIG_PATH", DEFAULT_SERVICE_CONFIG_PATH) +redis_cache_expiration_timestamp = None + # duplicated from llm/ia3_finetune def get_model_cache_directory_name(model_name: str): @@ -81,9 +83,17 @@ def cache_redis_url(self) -> str: assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure" username = os.getenv("AZURE_OBJECT_ID") - password = DefaultAzureCredential().get_token("https://redis.azure.com/.default").token + token = DefaultAzureCredential().get_token("https://redis.azure.com/.default") + password = token.token + global redis_cache_expiration_timestamp + redis_cache_expiration_timestamp = token.expires_on return f"rediss://{username}:{password}@{self.cache_redis_azure_host}" + @property + def cache_redis_url_expiration_timestamp(self) -> Optional[int]: + global redis_cache_expiration_timestamp + return redis_cache_expiration_timestamp + @property def cache_redis_host_port(self) -> str: # redis://redis.url:6379/ diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py index b69496173..9acf95c09 100644 --- a/model-engine/model_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -58,7 +58,7 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool user = os.environ.get("AZURE_IDENTITY_NAME") password = ( DefaultAzureCredential() - .get_token("https://ossrdbms-aad.database.windows.net") + .get_token("https://ossrdbms-aad.database.windows.net/.default") .token ) logger.info(f"Connecting to db {db} as user {user}") @@ -81,7 +81,9 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool # For async postgres, we need to use an async dialect. if not sync: - engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://") + engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://").replace( + "sslmode", "ssl" + ) return engine_url diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index 32a16bd80..c64e3beb7 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -41,6 +41,12 @@ class DockerImageNotFoundException(DomainException): tag: str +class DockerRepositoryNotFoundException(DomainException): + """ + Thrown when a Docker repository that is trying to be accessed doesn't exist. + """ + + class DockerBuildFailedException(DomainException): """ Thrown if the server failed to build a docker image. diff --git a/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py index 1f6dd7b09..5fac2841a 100644 --- a/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py @@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, Sequence, Set, Type, Union from fastapi import routing +from fastapi._compat import GenerateJsonSchema, get_model_definitions +from fastapi.openapi.constants import REF_TEMPLATE from fastapi.openapi.utils import get_openapi_path -from fastapi.utils import get_model_definitions from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, GetAsyncTaskV1Response, @@ -119,8 +120,13 @@ def get_openapi( if isinstance(route, routing.APIRoute): prefix = model_endpoint_name model_name_map = LiveModelEndpointsSchemaGateway.get_model_name_map(prefix) + schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) result = get_openapi_path( - route=route, model_name_map=model_name_map, operation_ids=operation_ids + route=route, + model_name_map=model_name_map, + operation_ids=operation_ids, + schema_generator=schema_generator, + field_mapping={}, ) if result: path, security_schemes, path_definitions = result diff --git a/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py index 2d6e1cc3f..7f9137feb 100644 --- a/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py @@ -6,6 +6,7 @@ from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException from model_engine_server.domain.repositories import DockerRepository logger = make_logger(logger_name()) @@ -36,7 +37,11 @@ def get_latest_image_tag(self, repository_name: str) -> str: credential = DefaultAzureCredential() client = ContainerRegistryClient(endpoint, credential) - image = client.list_manifest_properties( - repository_name, order_by="time_desc", results_per_page=1 - ).next() - return image.tags[0] + try: + image = client.list_manifest_properties( + repository_name, order_by="time_desc", results_per_page=1 + ).next() + # Azure automatically deletes empty ACR repositories, so repos will always have at least one image + return image.tags[0] + except ResourceNotFoundError: + raise DockerRepositoryNotFoundException diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index 5d5c9d135..beab3ec86 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -7,6 +7,7 @@ from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import GpuType, ModelEndpointInfraState +from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException from model_engine_server.domain.repositories import DockerRepository from model_engine_server.infra.gateways.resources.image_cache_gateway import ( CachedImages, @@ -78,11 +79,14 @@ def _cache_finetune_llm_images( vllm_image_032 = DockerImage( f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2" ) - latest_tag = ( - self.docker_repository.get_latest_image_tag(hmi_config.batch_inference_vllm_repository) - if not CIRCLECI - else "fake_docker_repository_latest_image_tag" - ) + latest_tag = "fake_docker_repository_latest_image_tag" + if not CIRCLECI: + try: # pragma: no cover + latest_tag = self.docker_repository.get_latest_image_tag( + hmi_config.batch_inference_vllm_repository + ) + except DockerRepositoryNotFoundException: + pass vllm_batch_image_latest = DockerImage( f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}", latest_tag, diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 49984a543..756df6c3a 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -9,7 +9,7 @@ azure-identity~=1.15.0 azure-keyvault-secrets~=4.7.0 azure-servicebus~=7.11.4 azure-storage-blob~=12.19.0 -boto3-stubs[essential]==1.26.67 +boto3-stubs[essential]~=1.26.67 boto3~=1.21 botocore~=1.24 build==0.8.0 @@ -17,13 +17,14 @@ celery[redis,sqs,tblib]~=5.3.6 click~=8.1 cloudpickle==2.1.0 croniter==1.4.1 +cryptography>=42.0.4 # not used directly, but needs to be pinned for Microsoft security scan dataclasses-json>=0.5.7 datadog-api-client==2.11.0 datadog~=0.47.0 ddtrace==1.8.3 deprecation~=2.1 docker~=5.0 -fastapi==0.78.0 +fastapi~=0.110.0 gitdb2~=2.0 gunicorn~=20.0 httptools==0.5.0 @@ -45,9 +46,10 @@ rich~=12.6 sentencepiece==0.1.99 sh~=1.13 smart-open~=5.2 -sqlalchemy[asyncio]==2.0.4 -sse-starlette==1.6.1 +sqlalchemy[asyncio]~=2.0.4 +sse-starlette==2.0.0 sseclient-py==1.7.2 +starlette[full]>=0.35.0 # not used directly, but needs to be pinned for Microsoft security scan stringcase==1.2.0 tenacity>=6.0.0,<=6.2.0 testing-postgresql==1.3.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 47d2fcef8..bc0052c1e 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -8,19 +8,21 @@ aiofiles==23.1.0 # via quart aiohttp==3.9.2 # via - # -r requirements.in + # -r model-engine/requirements.in # kubernetes-asyncio aioredis==2.0.1 - # via -r requirements.in + # via -r model-engine/requirements.in aiosignal==1.3.1 # via aiohttp alembic==1.8.1 - # via -r requirements.in + # via -r model-engine/requirements.in amqp==5.1.1 # via kombu anyio==3.7.1 # via # azure-core + # httpx + # sse-starlette # starlette asgiref==3.7.2 # via uvicorn @@ -32,7 +34,7 @@ async-timeout==4.0.2 # aioredis # redis asyncpg==0.27.0 - # via -r requirements.in + # via -r model-engine/requirements.in attrs==23.1.0 # via # aiohttp @@ -43,7 +45,7 @@ attrs==23.1.0 azure-common==1.1.28 # via azure-keyvault-secrets azure-containerregistry==1.2.0 - # via -r requirements.in + # via -r model-engine/requirements.in azure-core==1.29.6 # via # azure-containerregistry @@ -52,13 +54,13 @@ azure-core==1.29.6 # azure-servicebus # azure-storage-blob azure-identity==1.15.0 - # via -r requirements.in + # via -r model-engine/requirements.in azure-keyvault-secrets==4.7.0 - # via -r requirements.in + # via -r model-engine/requirements.in azure-servicebus==7.11.4 - # via -r requirements.in + # via -r model-engine/requirements.in azure-storage-blob==12.19.0 - # via -r requirements.in + # via -r model-engine/requirements.in backports-zoneinfo[tzdata]==0.2.1 # via # celery @@ -71,22 +73,20 @@ blinker==1.6.2 # via quart boto3==1.28.1 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # kombu boto3-stubs[essential]==1.26.67 - # via - # -r requirements.in - # boto3-stubs + # via -r model-engine/requirements.in botocore==1.31.1 # via - # -r requirements.in + # -r model-engine/requirements.in # boto3 # s3transfer botocore-stubs==1.29.165 # via boto3-stubs build==0.8.0 - # via -r requirements.in + # via -r model-engine/requirements.in bytecode==0.14.2 # via ddtrace cachetools==5.3.1 @@ -94,12 +94,12 @@ cachetools==5.3.1 cattrs==23.1.2 # via ddtrace celery[redis,sqs,tblib]==5.3.6 - # via - # -r requirements.in - # celery + # via -r model-engine/requirements.in certifi==2023.7.22 # via # datadog-api-client + # httpcore + # httpx # kubernetes # kubernetes-asyncio # requests @@ -109,7 +109,7 @@ charset-normalizer==3.2.0 # via requests click==8.1.4 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # click-didyoumean # click-plugins @@ -123,34 +123,35 @@ click-plugins==1.1.1 click-repl==0.3.0 # via celery cloudpickle==2.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in colorama==0.4.6 # via twine commonmark==0.9.1 # via rich croniter==1.4.1 - # via -r requirements.in -cryptography==41.0.7 + # via -r model-engine/requirements.in +cryptography==42.0.5 # via + # -r model-engine/requirements.in # azure-identity # azure-storage-blob # msal # pyjwt # secretstorage dataclasses-json==0.5.9 - # via -r requirements.in + # via -r model-engine/requirements.in datadog==0.47.0 - # via -r requirements.in + # via -r model-engine/requirements.in datadog-api-client==2.11.0 - # via -r requirements.in + # via -r model-engine/requirements.in ddsketch==2.0.4 # via ddtrace ddtrace==1.8.3 - # via -r requirements.in + # via -r model-engine/requirements.in deprecation==2.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in docker==5.0.3 - # via -r requirements.in + # via -r model-engine/requirements.in docutils==0.20.1 # via readme-renderer envier==0.4.0 @@ -159,8 +160,8 @@ exceptiongroup==1.2.0 # via # anyio # cattrs -fastapi==0.78.0 - # via -r requirements.in +fastapi==0.110.0 + # via -r model-engine/requirements.in filelock==3.13.1 # via # huggingface-hub @@ -174,17 +175,18 @@ fsspec==2023.10.0 gitdb==4.0.10 # via gitpython gitdb2==2.0.6 - # via -r requirements.in + # via -r model-engine/requirements.in gitpython==3.1.41 - # via -r requirements.in + # via -r model-engine/requirements.in google-auth==2.21.0 # via kubernetes greenlet==2.0.2 # via sqlalchemy gunicorn==20.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in h11==0.14.0 # via + # httpcore # hypercorn # uvicorn # wsproto @@ -192,8 +194,12 @@ h2==4.1.0 # via hypercorn hpack==4.0.0 # via h2 +httpcore==1.0.4 + # via httpx httptools==0.5.0 - # via -r requirements.in + # via -r model-engine/requirements.in +httpx==0.27.0 + # via starlette huggingface-hub==0.20.3 # via # tokenizers @@ -205,6 +211,7 @@ hyperframe==6.0.1 idna==3.4 # via # anyio + # httpx # requests # yarl importlib-metadata==6.8.0 @@ -226,7 +233,9 @@ isodate==0.6.1 # azure-servicebus # azure-storage-blob itsdangerous==2.1.2 - # via quart + # via + # quart + # starlette jaraco-classes==3.3.0 # via keyring jeepney==0.8.0 @@ -235,14 +244,15 @@ jeepney==0.8.0 # secretstorage jinja2==3.0.3 # via - # -r requirements.in + # -r model-engine/requirements.in # quart + # starlette jmespath==1.0.1 # via # boto3 # botocore json-log-formatter==0.5.2 - # via -r requirements.in + # via -r model-engine/requirements.in jsonschema==4.19.0 # via ddtrace jsonschema-specifications==2023.7.1 @@ -252,11 +262,11 @@ keyring==24.2.0 kombu[sqs]==5.3.5 # via celery kubeconfig==1.1.1 - # via -r requirements.in + # via -r model-engine/requirements.in kubernetes==25.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in kubernetes-asyncio==25.11.0 - # via -r requirements.in + # via -r model-engine/requirements.in mako==1.2.4 # via alembic markupsafe==2.1.3 @@ -304,7 +314,7 @@ numpy==1.24.4 oauthlib==3.2.2 # via requests-oauthlib orjson==3.9.15 - # via -r requirements.in + # via -r model-engine/requirements.in packaging==23.1 # via # build @@ -330,13 +340,13 @@ prompt-toolkit==3.0.39 # via click-repl protobuf==3.20.3 # via - # -r requirements.in + # -r model-engine/requirements.in # ddsketch # ddtrace psycopg2-binary==2.9.3 - # via -r requirements.in + # via -r model-engine/requirements.in py-xid==0.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in pyasn1==0.5.0 # via # pyasn1-modules @@ -347,21 +357,19 @@ pycparser==2.21 # via cffi pycurl==7.45.2 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # kombu pydantic==1.10.11 # via - # -r requirements.in + # -r model-engine/requirements.in # fastapi pygments==2.15.1 # via # readme-renderer # rich pyjwt[crypto]==2.8.0 - # via - # msal - # pyjwt + # via msal python-dateutil==2.8.2 # via # botocore @@ -372,16 +380,19 @@ python-dateutil==2.8.2 # kubernetes-asyncio # pg8000 python-multipart==0.0.7 - # via -r requirements.in + # via + # -r model-engine/requirements.in + # starlette pyyaml==6.0.1 # via # huggingface-hub # kubeconfig # kubernetes # kubernetes-asyncio + # starlette # transformers quart==0.18.3 - # via -r requirements.in + # via -r model-engine/requirements.in readme-renderer==40.0 # via twine redis==4.6.0 @@ -394,7 +405,7 @@ regex==2023.10.3 # via transformers requests==2.31.0 # via - # -r requirements.in + # -r model-engine/requirements.in # azure-core # datadog # docker @@ -407,7 +418,7 @@ requests==2.31.0 # transformers # twine requests-auth-aws-sigv4==0.7 - # via -r requirements.in + # via -r model-engine/requirements.in requests-oauthlib==1.3.1 # via kubernetes requests-toolbelt==1.0.0 @@ -415,7 +426,7 @@ requests-toolbelt==1.0.0 rfc3986==2.0.0 # via twine rich==12.6.0 - # via -r requirements.in + # via -r model-engine/requirements.in rpds-py==0.10.0 # via # jsonschema @@ -431,9 +442,9 @@ scramp==1.4.4 secretstorage==3.3.3 # via keyring sentencepiece==0.1.99 - # via -r requirements.in + # via -r model-engine/requirements.in sh==1.14.3 - # via -r requirements.in + # via -r model-engine/requirements.in six==1.16.0 # via # azure-core @@ -447,7 +458,7 @@ six==1.16.0 # python-dateutil # tenacity smart-open==5.2.1 - # via -r requirements.in + # via -r model-engine/requirements.in smmap==5.0.0 # via # gitdb @@ -455,35 +466,37 @@ smmap==5.0.0 smmap2==3.0.1 # via gitdb2 sniffio==1.3.0 - # via anyio + # via + # anyio + # httpx sqlalchemy[asyncio]==2.0.4 # via - # -r requirements.in + # -r model-engine/requirements.in # alembic - # sqlalchemy -sse-starlette==1.6.1 - # via -r requirements.in +sse-starlette==2.0.0 + # via -r model-engine/requirements.in sseclient-py==1.7.2 - # via -r requirements.in -starlette==0.19.1 + # via -r model-engine/requirements.in +starlette[full]==0.36.3 # via + # -r model-engine/requirements.in # fastapi # sse-starlette stringcase==1.2.0 - # via -r requirements.in + # via -r model-engine/requirements.in tblib==2.0.0 # via celery tenacity==6.2.0 # via - # -r requirements.in + # -r model-engine/requirements.in # ddtrace testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in tokenizers==0.15.2 # via - # -r requirements.in + # -r model-engine/requirements.in # transformers tomli==2.0.1 # via @@ -492,21 +505,21 @@ tomli==2.0.1 # pep517 tqdm==4.65.0 # via - # -r requirements.in + # -r model-engine/requirements.in # huggingface-hub # transformers # twine transformers==4.38.0 - # via -r requirements.in + # via -r model-engine/requirements.in twine==3.7.1 - # via -r requirements.in + # via -r model-engine/requirements.in types-awscrt==0.16.23 # via # botocore-stubs # types-s3transfer types-s3transfer==0.6.1 # via boto3-stubs -typing-extensions==4.7.1 +typing-extensions==4.10.0 # via # aioredis # asgiref @@ -520,6 +533,7 @@ typing-extensions==4.7.1 # cattrs # datadog-api-client # ddtrace + # fastapi # huggingface-hub # kombu # mypy-boto3-cloudformation @@ -551,9 +565,11 @@ urllib3==1.26.16 # kubernetes-asyncio # requests uvicorn==0.17.6 - # via -r requirements.in + # via + # -r model-engine/requirements.in + # sse-starlette uvloop==0.17.0 - # via -r requirements.in + # via -r model-engine/requirements.in vine==5.1.0 # via # amqp @@ -575,7 +591,7 @@ xmltodict==0.13.0 # via ddtrace yarl==1.9.2 # via - # -r requirements.in + # -r model-engine/requirements.in # aiohttp zipp==3.16.0 # via diff --git a/model-engine/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py index 611195bd3..f9a0f0620 100644 --- a/model-engine/tests/unit/api/test_tasks.py +++ b/model-engine/tests/unit/api/test_tasks.py @@ -360,15 +360,14 @@ def test_create_streaming_task_success( fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, ) - response = client.post( - f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}", + with client.stream( + method="POST", + url=f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}", auth=(test_api_key, ""), json=endpoint_predict_request_1[1], - stream=True, - ) - assert response.status_code == 200 - count = 0 - for message in response: - assert message == b'data: {"status": "SUCCESS", "result": null, "traceback": null}\r\n\r\n' - count += 1 - assert count == 1 + ) as response: + assert response.status_code == 200 + assert ( + response.read() + == b'data: {"status": "SUCCESS", "result": null, "traceback": null}\r\n\r\n' + )