Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions charts/model-engine/templates/_helpers.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,6 @@ env:
value: {{ .Values.azure.abs_account_name }}
- name: SERVICEBUS_NAMESPACE
value: {{ .Values.azure.servicebus_namespace }}
- name: SERVICEBUS_SAS_KEY
value: {{ .Values.azure.servicebus_sas_key }}
{{- end }}
{{- end }}

Expand Down
2 changes: 1 addition & 1 deletion model-engine/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# syntax = docker/dockerfile:experimental

FROM python:3.8.8-slim as model-engine
FROM python:3.8.18-slim as model-engine
WORKDIR /workspace

RUN apt-get update && apt-get install -y \
Expand Down
4 changes: 3 additions & 1 deletion model-engine/model_engine_server/api/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
import time
from dataclasses import dataclass
from typing import Callable, Optional

Expand Down Expand Up @@ -442,6 +443,7 @@ async def verify_authentication(
def get_or_create_aioredis_pool() -> aioredis.ConnectionPool:
global _pool

if _pool is None:
expiration_timestamp = hmi_config.cache_redis_url_expiration_timestamp
if _pool is None or (expiration_timestamp is not None and time.time() > expiration_timestamp):
_pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url)
return _pool
2 changes: 1 addition & 1 deletion model-engine/model_engine_server/api/files_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def upload_file(
)
return await use_case.execute(
user=auth,
filename=file.filename,
filename=file.filename or "",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when does file.filename equal None, should we just throw a 4xx in that case (if it is user fault that file.filename could be None?)

maybe this is fine if filename doesn't get used as the only part of an identifier actually, tbh I'm not sure though

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm not sure, I'm assuming this came up from lint because file.filename changed from str to Optional[str] between the old and new FastAPI versions... kinda hoping it's just always defined lol 😅

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we know when filename can be None? eg does the fastapi documentation say anything about it

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I couldn't find anything in the documentation, but it does seem like there are methods of calling where this can be user-set

content=file.file.read(),
)

Expand Down
12 changes: 11 additions & 1 deletion model-engine/model_engine_server/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

SERVICE_CONFIG_PATH = os.environ.get("DEPLOY_SERVICE_CONFIG_PATH", DEFAULT_SERVICE_CONFIG_PATH)

redis_cache_expiration_timestamp = None


# duplicated from llm/ia3_finetune
def get_model_cache_directory_name(model_name: str):
Expand Down Expand Up @@ -81,9 +83,17 @@ def cache_redis_url(self) -> str:

assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure"
username = os.getenv("AZURE_OBJECT_ID")
password = DefaultAzureCredential().get_token("https://redis.azure.com/.default").token
token = DefaultAzureCredential().get_token("https://redis.azure.com/.default")
password = token.token
global redis_cache_expiration_timestamp
redis_cache_expiration_timestamp = token.expires_on
return f"rediss://{username}:{password}@{self.cache_redis_azure_host}"

@property
def cache_redis_url_expiration_timestamp(self) -> Optional[int]:
global redis_cache_expiration_timestamp
return redis_cache_expiration_timestamp

@property
def cache_redis_host_port(self) -> str:
# redis://redis.url:6379/<db_index>
Expand Down
6 changes: 4 additions & 2 deletions model-engine/model_engine_server/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool
user = os.environ.get("AZURE_IDENTITY_NAME")
password = (
DefaultAzureCredential()
.get_token("https://ossrdbms-aad.database.windows.net")
.get_token("https://ossrdbms-aad.database.windows.net/.default")
.token
)
logger.info(f"Connecting to db {db} as user {user}")
Expand All @@ -81,7 +81,9 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool

# For async postgres, we need to use an async dialect.
if not sync:
engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://")
engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://").replace(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does replace('sslmode', 'ssl') do? does it remain compatible with our aws db setup?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sync, psycopg2 needs sslmode, but for async, asyncpg needs ssl. This shouldn't affect AWS because the sslmode=require param is only added for Azure

"sslmode", "ssl"
)
return engine_url


Expand Down
6 changes: 6 additions & 0 deletions model-engine/model_engine_server/domain/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class DockerImageNotFoundException(DomainException):
tag: str


class DockerRepositoryNotFoundException(DomainException):
"""
Thrown when a Docker repository that is trying to be accessed doesn't exist.
"""


class DockerBuildFailedException(DomainException):
"""
Thrown if the server failed to build a docker image.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from typing import Any, Callable, Dict, Sequence, Set, Type, Union

from fastapi import routing
from fastapi._compat import GenerateJsonSchema, get_model_definitions
from fastapi.openapi.constants import REF_TEMPLATE
from fastapi.openapi.utils import get_openapi_path
from fastapi.utils import get_model_definitions
from model_engine_server.common.dtos.tasks import (
EndpointPredictV1Request,
GetAsyncTaskV1Response,
Expand Down Expand Up @@ -119,8 +120,13 @@ def get_openapi(
if isinstance(route, routing.APIRoute):
prefix = model_endpoint_name
model_name_map = LiveModelEndpointsSchemaGateway.get_model_name_map(prefix)
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
result = get_openapi_path(
route=route, model_name_map=model_name_map, operation_ids=operation_ids
route=route,
model_name_map=model_name_map,
operation_ids=operation_ids,
schema_generator=schema_generator,
field_mapping={},
)
if result:
path, security_schemes, path_definitions = result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse
from model_engine_server.core.config import infra_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException
from model_engine_server.domain.repositories import DockerRepository

logger = make_logger(logger_name())
Expand Down Expand Up @@ -36,7 +37,11 @@ def get_latest_image_tag(self, repository_name: str) -> str:
credential = DefaultAzureCredential()
client = ContainerRegistryClient(endpoint, credential)

image = client.list_manifest_properties(
repository_name, order_by="time_desc", results_per_page=1
).next()
return image.tags[0]
try:
image = client.list_manifest_properties(
repository_name, order_by="time_desc", results_per_page=1
).next()
# Azure automatically deletes empty ACR repositories, so repos will always have at least one image
return image.tags[0]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to throw an error if there are 0 image tags?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unless that's already handled in the ResourceNotFoundError

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just tested, it looks like Azure automatically deletes repositories that are empty, so it'll be a ResourceNotFoundError 🤔

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok sounds good, could we note it in the code so we know why we're not gonna IndexError?

except ResourceNotFoundError:
raise DockerRepositoryNotFoundException
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from model_engine_server.core.config import infra_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.entities import GpuType, ModelEndpointInfraState
from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException
from model_engine_server.domain.repositories import DockerRepository
from model_engine_server.infra.gateways.resources.image_cache_gateway import (
CachedImages,
Expand Down Expand Up @@ -78,11 +79,14 @@ def _cache_finetune_llm_images(
vllm_image_032 = DockerImage(
f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2"
)
latest_tag = (
self.docker_repository.get_latest_image_tag(hmi_config.batch_inference_vllm_repository)
if not CIRCLECI
else "fake_docker_repository_latest_image_tag"
)
latest_tag = "fake_docker_repository_latest_image_tag"
if not CIRCLECI:
try: # pragma: no cover
latest_tag = self.docker_repository.get_latest_image_tag(
hmi_config.batch_inference_vllm_repository
)
except DockerRepositoryNotFoundException:
pass
vllm_batch_image_latest = DockerImage(
f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}",
latest_tag,
Expand Down
10 changes: 6 additions & 4 deletions model-engine/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,22 @@ azure-identity~=1.15.0
azure-keyvault-secrets~=4.7.0
azure-servicebus~=7.11.4
azure-storage-blob~=12.19.0
boto3-stubs[essential]==1.26.67
boto3-stubs[essential]~=1.26.67
boto3~=1.21
botocore~=1.24
build==0.8.0
celery[redis,sqs,tblib]~=5.3.6
click~=8.1
cloudpickle==2.1.0
croniter==1.4.1
cryptography>=42.0.4 # not used directly, but needs to be pinned for Microsoft security scan
dataclasses-json>=0.5.7
datadog-api-client==2.11.0
datadog~=0.47.0
ddtrace==1.8.3
deprecation~=2.1
docker~=5.0
fastapi==0.78.0
fastapi~=0.110.0
gitdb2~=2.0
gunicorn~=20.0
httptools==0.5.0
Expand All @@ -45,9 +46,10 @@ rich~=12.6
sentencepiece==0.1.99
sh~=1.13
smart-open~=5.2
sqlalchemy[asyncio]==2.0.4
sse-starlette==1.6.1
sqlalchemy[asyncio]~=2.0.4
sse-starlette==2.0.0
sseclient-py==1.7.2
starlette[full]>=0.35.0 # not used directly, but needs to be pinned for Microsoft security scan
stringcase==1.2.0
tenacity>=6.0.0,<=6.2.0
testing-postgresql==1.3.0
Expand Down
Loading