Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions model-engine/model_engine_server/api/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Callable, Iterator, Optional
from typing import Callable, Optional

import aioredis
from fastapi import Depends, HTTPException, status
Expand All @@ -26,6 +26,7 @@
FileStorageGateway,
LLMArtifactGateway,
ModelPrimitiveGateway,
MonitoringMetricsGateway,
TaskQueueGateway,
)
from model_engine_server.domain.repositories import (
Expand Down Expand Up @@ -134,6 +135,24 @@ class ExternalInterfaces:
cron_job_gateway: CronJobGateway


def get_default_monitoring_metrics_gateway() -> MonitoringMetricsGateway:
monitoring_metrics_gateway = FakeMonitoringMetricsGateway()
return monitoring_metrics_gateway


def get_monitoring_metrics_gateway() -> MonitoringMetricsGateway:
try:
from plugins.dependencies import (
get_monitoring_metrics_gateway as get_custom_monitoring_metrics_gateway,
)

return get_custom_monitoring_metrics_gateway()
except ModuleNotFoundError:
return get_default_monitoring_metrics_gateway()
finally:
pass


def _get_external_interfaces(
read_only: bool, session: Callable[[], AsyncSession]
) -> ExternalInterfaces:
Expand All @@ -144,7 +163,7 @@ def _get_external_interfaces(
redis_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS)
redis_24h_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS_24H)
sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS)
monitoring_metrics_gateway = FakeMonitoringMetricsGateway()
monitoring_metrics_gateway = get_monitoring_metrics_gateway()
model_endpoint_record_repo = DbModelEndpointRecordRepository(
monitoring_metrics_gateway=monitoring_metrics_gateway,
session=session,
Expand Down Expand Up @@ -300,12 +319,21 @@ async def get_external_interfaces_read_only():
pass


def get_auth_repository() -> Iterator[AuthenticationRepository]:
def get_default_auth_repository() -> AuthenticationRepository:
auth_repo = FakeAuthenticationRepository()
return auth_repo


async def get_auth_repository():
"""
Dependency for an AuthenticationRepository. This implementation returns a fake repository.
"""
try:
yield FakeAuthenticationRepository()
from plugins.dependencies import get_auth_repository as get_custom_auth_repository

yield get_custom_auth_repository()
except ModuleNotFoundError:
yield get_default_auth_repository()
finally:
pass

Expand All @@ -318,15 +346,15 @@ async def verify_authentication(
Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise,
raises a 401.
"""
user_id = credentials.username if credentials is not None else None
if user_id is None:
username = credentials.username if credentials is not None else None
if username is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No user id was passed in",
detail="No authentication was passed in",
headers={"WWW-Authenticate": "Basic"},
)

auth = await auth_repo.get_auth_from_user_id_async(user_id=user_id)
auth = await auth_repo.get_auth_from_username_async(username=username)

if not auth:
raise HTTPException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,13 @@ def is_allowed_team(team: str) -> bool:
"""

@abstractmethod
def get_auth_from_user_id(self, user_id: str) -> Optional[User]:
def get_auth_from_username(self, username: str) -> Optional[User]:
"""
Returns authentication information associated with a given user_id.
Returns authentication information associated with a given Basic Auth username.
"""

@abstractmethod
def get_auth_from_api_key(self, api_key: str) -> Optional[User]:
async def get_auth_from_username_async(self, username: str) -> Optional[User]:
"""
Returns authentication information associated with a given api_key.
"""

@abstractmethod
async def get_auth_from_user_id_async(self, user_id: str) -> Optional[User]:
"""
Returns authentication information associated with a given user_id.
"""

@abstractmethod
async def get_auth_from_api_key_async(self, api_key: str) -> Optional[User]:
"""
Returns authentication information associated with a given api_key.
Returns authentication information associated with a given Basic Auth username.
"""
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,10 @@ def __init__(self, user_team_override: Optional[Dict[str, str]] = None):
def is_allowed_team(team: str) -> bool:
return True

def get_auth_from_user_id(self, user_id: str) -> Optional[User]:
team_id = self.user_team_override.get(user_id, user_id)
return User(user_id=user_id, team_id=team_id, is_privileged_user=True)
def get_auth_from_username(self, username: str) -> Optional[User]:
team_id = self.user_team_override.get(username, username)
return User(user_id=username, team_id=team_id, is_privileged_user=True)

async def get_auth_from_user_id_async(self, user_id: str) -> Optional[User]:
team_id = self.user_team_override.get(user_id, user_id)
return User(user_id=user_id, team_id=team_id, is_privileged_user=True)

def get_auth_from_api_key(self, api_key: str) -> Optional[User]:
return User(user_id=api_key, team_id=api_key, is_privileged_user=True)

async def get_auth_from_api_key_async(self, api_key: str) -> Optional[User]:
return User(user_id=api_key, team_id=api_key, is_privileged_user=True)
async def get_auth_from_username_async(self, username: str) -> Optional[User]:
team_id = self.user_team_override.get(username, username)
return User(user_id=username, team_id=team_id, is_privileged_user=True)
15 changes: 4 additions & 11 deletions model-engine/model_engine_server/entrypoints/k8s_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,13 @@

from kubernetes import config as kube_config
from kubernetes.config.config_exception import ConfigException
from model_engine_server.api.dependencies import get_monitoring_metrics_gateway
from model_engine_server.common.config import hmi_config
from model_engine_server.common.constants import READYZ_FPATH
from model_engine_server.common.env_vars import CIRCLECI, SKIP_AUTH
from model_engine_server.common.env_vars import CIRCLECI
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.db.base import SessionAsyncNullPool
from model_engine_server.domain.gateways import MonitoringMetricsGateway
from model_engine_server.domain.repositories import DockerRepository
from model_engine_server.infra.gateways import (
DatadogMonitoringMetricsGateway,
FakeMonitoringMetricsGateway,
)
from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import (
EndpointResourceGateway,
)
Expand Down Expand Up @@ -95,16 +91,13 @@ async def main(args: Any):
logger.info(f"Using cache redis url {redis_url}")
cache_repo = RedisModelEndpointCacheRepository(redis_info=redis_url)

monitoring_metrics_gateway: MonitoringMetricsGateway
if SKIP_AUTH:
monitoring_metrics_gateway = FakeMonitoringMetricsGateway()
else:
monitoring_metrics_gateway = DatadogMonitoringMetricsGateway()
monitoring_metrics_gateway = get_monitoring_metrics_gateway()
endpoint_record_repo = DbModelEndpointRecordRepository(
monitoring_metrics_gateway=monitoring_metrics_gateway,
session=SessionAsyncNullPool,
read_only=True,
)

sqs_delegate: SQSEndpointResourceDelegate
if CIRCLECI:
sqs_delegate = FakeSQSEndpointResourceDelegate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
from datetime import timedelta

import aioredis
from model_engine_server.api.dependencies import get_monitoring_metrics_gateway
from model_engine_server.common.config import hmi_config
from model_engine_server.common.dtos.model_endpoints import BrokerType
from model_engine_server.common.env_vars import CIRCLECI, SKIP_AUTH
from model_engine_server.common.env_vars import CIRCLECI
from model_engine_server.db.base import SessionAsyncNullPool
from model_engine_server.domain.entities import BatchJobSerializationFormat
from model_engine_server.domain.gateways import MonitoringMetricsGateway
from model_engine_server.infra.gateways import (
CeleryTaskQueueGateway,
DatadogMonitoringMetricsGateway,
FakeMonitoringMetricsGateway,
LiveAsyncModelEndpointInferenceGateway,
LiveBatchJobProgressGateway,
LiveModelEndpointInfraGateway,
Expand Down Expand Up @@ -58,11 +56,8 @@ async def run_batch_job(
redis = aioredis.Redis(connection_pool=pool)
redis_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS)
sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS)
monitoring_metrics_gateway: MonitoringMetricsGateway
if SKIP_AUTH:
monitoring_metrics_gateway = FakeMonitoringMetricsGateway()
else:
monitoring_metrics_gateway = DatadogMonitoringMetricsGateway()

monitoring_metrics_gateway = get_monitoring_metrics_gateway()
model_endpoint_record_repo = DbModelEndpointRecordRepository(
monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, read_only=False
)
Expand Down
2 changes: 0 additions & 2 deletions model-engine/model_engine_server/infra/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .batch_job_orchestration_gateway import BatchJobOrchestrationGateway
from .batch_job_progress_gateway import BatchJobProgressGateway
from .celery_task_queue_gateway import CeleryTaskQueueGateway
from .datadog_monitoring_metrics_gateway import DatadogMonitoringMetricsGateway
from .fake_model_primitive_gateway import FakeModelPrimitiveGateway
from .fake_monitoring_metrics_gateway import FakeMonitoringMetricsGateway
from .live_async_model_endpoint_inference_gateway import LiveAsyncModelEndpointInferenceGateway
Expand All @@ -26,7 +25,6 @@
"BatchJobOrchestrationGateway",
"BatchJobProgressGateway",
"CeleryTaskQueueGateway",
"DatadogMonitoringMetricsGateway",
"FakeModelPrimitiveGateway",
"FakeMonitoringMetricsGateway",
"LiveAsyncModelEndpointInferenceGateway",
Expand Down

This file was deleted.

18 changes: 4 additions & 14 deletions model-engine/model_engine_server/service_builder/tasks_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,17 @@

import aioredis
from celery.signals import worker_process_init
from model_engine_server.api.dependencies import get_monitoring_metrics_gateway
from model_engine_server.common.config import hmi_config
from model_engine_server.common.constants import READYZ_FPATH
from model_engine_server.common.dtos.endpoint_builder import (
BuildEndpointRequest,
BuildEndpointResponse,
)
from model_engine_server.common.env_vars import CIRCLECI, SKIP_AUTH
from model_engine_server.common.env_vars import CIRCLECI
from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway
from model_engine_server.db.base import SessionAsyncNullPool
from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway
from model_engine_server.infra.gateways import (
DatadogMonitoringMetricsGateway,
FakeMonitoringMetricsGateway,
S3FilesystemGateway,
)
from model_engine_server.infra.gateways import S3FilesystemGateway
from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import (
FakeSQSEndpointResourceDelegate,
)
Expand Down Expand Up @@ -61,14 +57,8 @@ def get_live_endpoint_builder_service(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
)
notification_gateway = FakeNotificationGateway()
monitoring_metrics_gateway: MonitoringMetricsGateway
if SKIP_AUTH:
monitoring_metrics_gateway = FakeMonitoringMetricsGateway()
else:
monitoring_metrics_gateway = DatadogMonitoringMetricsGateway()

monitoring_metrics_gateway = get_monitoring_metrics_gateway()
docker_repository = ECRDockerRepository() if not CIRCLECI else FakeDockerRepository()

service = LiveEndpointBuilderService(
docker_repository=docker_repository,
resource_gateway=LiveEndpointResourceGateway(
Expand Down
8 changes: 4 additions & 4 deletions model-engine/tests/unit/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def fake_verify_authentication(
Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise,
raises a 401.
"""
auth_user_id = credentials.username if credentials is not None else None
if not auth_user_id:
raise HTTPException(status_code=401, detail="No user id was passed in")
auth_username = credentials.username if credentials is not None else None
if not auth_username:
raise HTTPException(status_code=401, detail="No authentication was passed in")

auth = auth_repo.get_auth_from_user_id(user_id=auth_user_id)
auth = auth_repo.get_auth_from_username(username=auth_username)
if not auth:
raise HTTPException(status_code=401, detail="Could not authenticate user")

Expand Down