Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions model-engine/model_engine_server/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from model_engine_server.api.tasks_v1 import inference_task_router_v1
from model_engine_server.api.triggers_v1 import trigger_router_v1
from model_engine_server.core.loggers import (
LoggerTagKey,
LoggerTagManager,
filename_wo_ext,
get_request_id,
make_logger,
set_request_id,
)

logger = make_logger(filename_wo_ext(__name__))
Expand All @@ -47,11 +47,11 @@
@app.middleware("http")
async def dispatch(request: Request, call_next):
try:
set_request_id(str(uuid.uuid4()))
LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4()))
return await call_next(request)
except Exception as e:
tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
request_id = get_request_id()
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
structured_log = {
"error": str(e),
Expand Down
11 changes: 10 additions & 1 deletion model-engine/model_engine_server/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from model_engine_server.core.auth.fake_authentication_repository import (
FakeAuthenticationRepository,
)
from model_engine_server.core.loggers import filename_wo_ext, make_logger
from model_engine_server.core.loggers import (
LoggerTagKey,
LoggerTagManager,
filename_wo_ext,
make_logger,
)
from model_engine_server.db.base import SessionAsync, SessionReadOnlyAsync
from model_engine_server.domain.gateways import (
CronJobGateway,
Expand Down Expand Up @@ -330,6 +335,10 @@ async def verify_authentication(
headers={"WWW-Authenticate": "Basic"},
)

# set logger context with identity data
LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id)
LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id)

return auth


Expand Down
11 changes: 8 additions & 3 deletions model-engine/model_engine_server/api/llms_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
)
from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
from model_engine_server.core.auth.authentication_repository import User
from model_engine_server.core.loggers import filename_wo_ext, get_request_id, make_logger
from model_engine_server.core.loggers import (
LoggerTagKey,
LoggerTagManager,
filename_wo_ext,
make_logger,
)
from model_engine_server.domain.exceptions import (
EndpointDeleteFailedException,
EndpointLabelsException,
Expand Down Expand Up @@ -82,7 +87,7 @@ def handle_streaming_exception(
message: str,
):
tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
request_id = get_request_id()
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
structured_log = {
"error": message,
Expand Down Expand Up @@ -223,7 +228,7 @@ async def create_completion_sync_task(
user=auth, model_endpoint_name=model_endpoint_name, request=request
)
except UpstreamServiceError:
request_id = get_request_id()
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
logger.exception(f"Upstream service error for request {request_id}")
raise HTTPException(
status_code=500,
Expand Down
48 changes: 33 additions & 15 deletions model-engine/model_engine_server/core/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import sys
import warnings
from contextlib import contextmanager
from typing import Optional, Sequence
from enum import Enum
from typing import Dict, Optional, Sequence

import ddtrace
import json_log_formatter
Expand All @@ -16,8 +17,6 @@
LOG_FORMAT: str = "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] - %(message)s"
# REQUIRED FOR DATADOG COMPATIBILITY

ctx_var_request_id = contextvars.ContextVar("ctx_var_request_id", default=None)

__all__: Sequence[str] = (
# most common imports
"make_logger",
Expand All @@ -35,19 +34,37 @@
"loggers_at_level",
# utils
"filename_wo_ext",
"get_request_id",
"set_request_id",
"LoggerTagKey",
"LoggerTagManager",
)


def get_request_id() -> Optional[str]:
"""Get the request id from the context variable."""
return ctx_var_request_id.get()
class LoggerTagKey(str, Enum):
REQUEST_ID = "request_id"
TEAM_ID = "team_id"
USER_ID = "user_id"


class LoggerTagManager:
_context_vars: Dict[LoggerTagKey, contextvars.ContextVar] = {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wonder if a defaultdict works here, although I guess it's probably more effort to get the code into a form where it works with defaultdict


@classmethod
def get(cls, key: LoggerTagKey) -> Optional[str]:
"""Get the value from the context variable."""
ctx_var = cls._context_vars.get(key)
if ctx_var is not None:
return ctx_var.get()
return None

def set_request_id(request_id: str) -> None:
"""Set the request id in the context variable."""
ctx_var_request_id.set(request_id) # type: ignore
@classmethod
def set(cls, key: LoggerTagKey, value: Optional[str]) -> None:
"""Set the value in the context variable."""
if value is not None:
ctx_var = cls._context_vars.get(key)
if ctx_var is None:
ctx_var = contextvars.ContextVar(f"ctx_var_{key.name.lower()}", default=None)
cls._context_vars[key] = ctx_var
ctx_var.set(value)


def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Logger:
Expand Down Expand Up @@ -77,10 +94,11 @@ def json_record(self, message: str, extra: dict, record: logging.LogRecord) -> d
extra["lineno"] = record.lineno
extra["pathname"] = record.pathname

# add the http request id if it exists
request_id = ctx_var_request_id.get()
if request_id:
extra["request_id"] = request_id
# add additional logger tags
for tag_key in LoggerTagKey:
tag_value = LoggerTagManager.get(tag_key)
if tag_value:
extra[tag_key.value] = tag_value

current_span = tracer.current_span()
extra["dd.trace_id"] = current_span.trace_id if current_span else 0
Expand Down