diff --git a/model-engine/model_engine_server/common/datadog_utils.py b/model-engine/model_engine_server/common/datadog_utils.py
index 3e3513cb7..5707d964c 100644
--- a/model-engine/model_engine_server/common/datadog_utils.py
+++ b/model-engine/model_engine_server/common/datadog_utils.py
@@ -1,10 +1,15 @@
+from typing import Optional
+
 from ddtrace import tracer
 
 
-def add_trace_request_id(request_id: str):
+def add_trace_request_id(request_id: Optional[str]):
     """Adds a custom tag to a given dd trace corresponding to the request id
     so that we can filter in Datadog easier
     """
+    if not request_id:
+        return
+
     current_span = tracer.current_span()
     if current_span:
         current_span.set_tag("launch.request_id", request_id)
diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py
index fc531c1f7..6e991e45e 100644
--- a/model-engine/model_engine_server/common/dtos/llms.py
+++ b/model-engine/model_engine_server/common/dtos/llms.py
@@ -199,7 +199,7 @@ class CompletionSyncV1Response(BaseModel):
     """
     Response object for a synchronous prompt completion task.
     """
-    request_id: str
+    request_id: Optional[str]
     output: Optional[CompletionOutput] = None
 
 
@@ -273,7 +273,7 @@ class CompletionStreamV1Response(BaseModel):
     """
     Response object for a stream prompt completion task.
     """
-    request_id: str
+    request_id: Optional[str]
     output: Optional[CompletionStreamOutput] = None
     error: Optional[StreamError] = None
     """Error of the response (if any)."""
diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
index 359c525b1..dcf0d0d26 100644
--- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
+++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
@@ -9,7 +9,6 @@
 import os
 from dataclasses import asdict
 from typing import Any, AsyncIterable, Dict, List, Optional, Union
-from uuid import uuid4
 
 from model_engine_server.common.config import hmi_config
 from model_engine_server.common.dtos.llms import (
@@ -35,7 +34,12 @@
 from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus
 from model_engine_server.common.resource_limits import validate_resource_requests
 from model_engine_server.core.auth.authentication_repository import User
-from model_engine_server.core.loggers import logger_name, make_logger
+from model_engine_server.core.loggers import (
+    LoggerTagKey,
+    LoggerTagManager,
+    logger_name,
+    make_logger,
+)
 from model_engine_server.domain.entities import (
     LLMInferenceFramework,
     LLMMetadata,
@@ -1448,7 +1452,7 @@ async def execute(
             ObjectNotAuthorizedException: If the owner does not own the model endpoint.
         """
 
-        request_id = str(uuid4())
+        request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
         add_trace_request_id(request_id)
 
         model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints(
@@ -1736,7 +1740,7 @@ async def execute(
             ObjectNotAuthorizedException: If the owner does not own the model endpoint.
         """
 
-        request_id = str(uuid4())
+        request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
        add_trace_request_id(request_id)
 
         model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints(
diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py
index c4fbb31f6..31a579a6f 100644
--- a/model-engine/tests/unit/domain/test_llm_use_cases.py
+++ b/model-engine/tests/unit/domain/test_llm_use_cases.py
@@ -948,7 +948,6 @@ async def test_completion_stream_use_case_success(
     output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."]
     i = 0
     async for message in response_1:
-        assert message.dict()["request_id"]
         assert message.dict()["output"]["text"] == output_texts[i]
         if i == 6:
             assert message.dict()["output"]["num_prompt_tokens"] == 7
@@ -1016,7 +1015,6 @@ async def test_completion_stream_text_generation_inference_use_case_success(
     output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."]
     i = 0
     async for message in response_1:
-        assert message.dict()["request_id"]
         assert message.dict()["output"]["text"] == output_texts[i]
         if i == 5:
             assert message.dict()["output"]["num_prompt_tokens"] == 7
@@ -1079,7 +1077,6 @@ async def test_completion_stream_trt_llm_use_case_success(
     output_texts = ["Machine", "learning", "is", "a", "branch"]
     i = 0
     async for message in response_1:
-        assert message.dict()["request_id"]
         assert message.dict()["output"]["text"] == output_texts[i]
         assert message.dict()["output"]["num_prompt_tokens"] == 7
         assert message.dict()["output"]["num_completion_tokens"] == i + 1