diff --git a/.gitignore b/.gitignore
index 3b58d13fc..4ae1304e3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,6 +50,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 cover/
+tests/test_results/
 
 # Translations
 *.mo
diff --git a/pdm.lock b/pdm.lock
index 92f684197..5de43914b 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@ groups = ["default", "dev"]
 strategy = ["inherit_metadata"]
 lock_version = "4.5.0"
-content_hash = "sha256:db82049ff8c8d98dacd64aa05d47b871510a406ed837a638ea81b30eeece7ab5"
+content_hash = "sha256:f3dde2e916169abc41f23400e9488ae43f2eeb203c4dbb2e505ba28b9a853677"
 
 [[metadata.targets]]
 requires_python = ">=3.11.1,<=3.12.10"
@@ -1061,6 +1061,20 @@ files = [
     {file = "pytest_cov-6.1.1.tar.gz", hash = "sha256:46935f7aaefba760e716c2ebfbe1c216240b9592966e7da99ea8292d4d3e2a0a"},
 ]
 
+[[package]]
+name = "pytest-mock"
+version = "3.14.1"
+requires_python = ">=3.8"
+summary = "Thin-wrapper around the mock package for easier use with pytest"
+groups = ["dev"]
+dependencies = [
+    "pytest>=6.2.5",
+]
+files = [
+    {file = "pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0"},
+    {file = "pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e"},
+]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"
diff --git a/pyproject.toml b/pyproject.toml
index b5f751c9a..e41a71a42 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ dev = [
     "black>=25.1.0",
     "pytest>=8.3.2",
     "pytest-cov>=5.0.0",
+    "pytest-mock>=3.14.0",
 ]
 
 [tool.pdm.scripts]
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index abe2eaa9f..bea2162d5 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -7,46 +7,82 @@ from llama_stack_client import LlamaStackClient  # type: ignore
 from llama_stack_client.types import UserMessage  # type: ignore
 
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, Request, HTTPException, status
 
 from client import get_llama_stack_client
 from configuration import configuration
 from models.responses import QueryResponse
+from models.requests import QueryRequest, Attachment
+import constants
 
 logger = logging.getLogger("app.endpoints.handlers")
-router = APIRouter(tags=["models"])
+router = APIRouter(tags=["query"])
 
 query_response: dict[int | str, dict[str, Any]] = {
     200: {
-        "query": "User query",
-        "answer": "LLM ansert",
+        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
+        "response": "LLM answer",
     },
 }
 
 
 @router.post("/query", responses=query_response)
-def query_endpoint_handler(request: Request, query: str) -> QueryResponse:
+def query_endpoint_handler(
+    request: Request, query_request: QueryRequest
+) -> QueryResponse:
     llama_stack_config = configuration.llama_stack_configuration
     logger.info("LLama stack config: %s", llama_stack_config)
-
     client = get_llama_stack_client(llama_stack_config)
-
-    # retrieve list of available models
-    models = client.models.list()
-
-    # select the first LLM
-    llm = next(m for m in models if m.model_type == "llm")
-    model_id = llm.identifier
-
-    logger.info("Model: %s", model_id)
-
-    response = retrieve_response(client, model_id, query)
-
-    return QueryResponse(query=query, response=response)
+    model_id = select_model_id(client, query_request)
+    response = retrieve_response(client, model_id, query_request)
+    return QueryResponse(
+        conversation_id=query_request.conversation_id, response=response
+    )
 
 
-def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> str:
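+# Model selection is split out of the handler: honor the model/provider
+# pair from the request when given, otherwise fall back to the first
+# available LLM reported by llama-stack.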
+def select_model_id(client: LlamaStackClient, query_request: QueryRequest) -> str:
+    """Select the model ID based on the request or available models."""
+    models = client.models.list()
+    model_id = query_request.model
+    provider_id = query_request.provider
+
+    # TODO(lucasagomes): support default model selection via configuration
+    if not model_id:
+        logger.info("No model specified in request, using the first available LLM")
+        try:
+            return next(m for m in models if m.model_type == "llm").identifier
+        except (StopIteration, AttributeError):
+            message = "No LLM model found in available models"
+            logger.error(message)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "response": constants.UNABLE_TO_PROCESS_RESPONSE,
+                    "cause": message,
+                },
+            )
+
+    logger.info(f"Searching for model: {model_id}, provider: {provider_id}")
+    if not any(
+        m.identifier == model_id and m.provider_id == provider_id for m in models
+    ):
+        message = f"Model {model_id} from provider {provider_id} not found in available models"
+        logger.error(message)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={
+                "response": constants.UNABLE_TO_PROCESS_RESPONSE,
+                "cause": message,
+            },
+        )
+
+    return model_id
+
+
+def retrieve_response(
+    client: LlamaStackClient, model_id: str, query_request: QueryRequest
+) -> str:
     available_shields = [shield.identifier for shield in client.shields.list()]
 
     if not available_shields:
@@ -54,18 +90,61 @@ def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> str:
     else:
         logger.info(f"Available shields found: {available_shields}")
 
+    # use system prompt from request or default one
+    system_prompt = (
+        query_request.system_prompt
+        if query_request.system_prompt
+        else constants.DEFAULT_SYSTEM_PROMPT
+    )
+    logger.debug(f"Using system prompt: {system_prompt}")
+
+    # TODO(lucasagomes): redact attachments content before sending to LLM
+    # if attachments are provided, validate them
+    if query_request.attachments:
+        validate_attachments_metadata(query_request.attachments)
+
     agent = Agent(
         client,
         model=model_id,
-        instructions="You are a helpful assistant",
+        instructions=system_prompt,
         input_shields=available_shields if available_shields else [],
         tools=[],
     )
     session_id = agent.create_session("chat_session")
     response = agent.create_turn(
-        messages=[UserMessage(role="user", content=prompt)],
+        messages=[UserMessage(role="user", content=query_request.query)],
         session_id=session_id,
+        documents=query_request.get_documents(),
         stream=False,
     )
 
     return str(response.output_message.content)
+
+
+def validate_attachments_metadata(attachments: list[Attachment]) -> None:
+    """Validate the attachments metadata provided in the request.
+
+    Raises HTTPException if any attachment has an improper type or content type.
+    """
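+    # Both checks below compare against the allow-lists defined in
+    # src/constants.py (ATTACHMENT_TYPES and ATTACHMENT_CONTENT_TYPES).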
+    for attachment in attachments:
+        if attachment.attachment_type not in constants.ATTACHMENT_TYPES:
+            message = (
+                f"Attachment with improper type {attachment.attachment_type} detected"
+            )
+            logger.error(message)
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail={
+                    "response": constants.UNABLE_TO_PROCESS_RESPONSE,
+                    "cause": message,
+                },
+            )
+        if attachment.content_type not in constants.ATTACHMENT_CONTENT_TYPES:
+            message = f"Attachment with improper content type {attachment.content_type} detected"
+            logger.error(message)
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail={
+                    "response": constants.UNABLE_TO_PROCESS_RESPONSE,
+                    "cause": message,
+                },
+            )
diff --git a/src/constants.py b/src/constants.py
new file mode 100644
index 000000000..5699962e8
--- /dev/null
+++ b/src/constants.py
@@ -0,0 +1,21 @@
+UNABLE_TO_PROCESS_RESPONSE = "Unable to process this request"
+
+# Supported attachment types
+ATTACHMENT_TYPES = frozenset(
+    {
+        "alert",
+        "api object",
+        "configuration",
+        "error message",
+        "event",
+        "log",
+        "stack trace",
+    }
+)
+
+# Supported attachment content types
+ATTACHMENT_CONTENT_TYPES = frozenset(
+    {"text/plain", "application/json", "application/yaml", "application/xml"}
+)
+
+DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant"
diff --git a/src/models/requests.py b/src/models/requests.py
new file mode 100644
index 000000000..bddd6b796
--- /dev/null
+++ b/src/models/requests.py
@@ -0,0 +1,127 @@
+from pydantic import BaseModel, model_validator
+from llama_stack_client.types.agents.turn_create_params import Document
+from typing import Optional, Self
+
+
+class Attachment(BaseModel):
+    """Model representing an attachment that can be sent from UI as part of query.
+
+    List of attachments can be optional part of 'query' request.
+
+    Attributes:
+        attachment_type: The attachment type, like "log", "configuration" etc.
+        content_type: The content type as defined in MIME standard
+        content: The actual attachment content
+
+    YAML attachments with **kind** and **metadata/name** attributes will
+    be handled as resources with specified name:
+    ```
+    kind: Pod
+    metadata:
+        name: private-reg
+    ```
+    """
+
+    attachment_type: str
+    content_type: str
+    content: str
+
+    # provides examples for /docs endpoint
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "attachment_type": "log",
+                    "content_type": "text/plain",
+                    "content": "this is attachment",
+                },
+                {
+                    "attachment_type": "configuration",
+                    "content_type": "application/yaml",
+                    "content": "kind: Pod\n metadata:\n name: private-reg",
+                },
+                {
+                    "attachment_type": "configuration",
+                    "content_type": "application/yaml",
+                    "content": "foo: bar",
+                },
+            ]
+        }
+    }
+
+
+# TODO(lucasagomes): add media_type when needed, current implementation
+# does not support streaming response, so this is not used
+class QueryRequest(BaseModel):
+    """Model representing a request for the LLM (Language Model).
+
+    Attributes:
+        query: The query string.
+        conversation_id: The optional conversation ID (UUID).
+        provider: The optional provider.
+        model: The optional model.
+        system_prompt: The optional system prompt.
+        attachments: The optional attachments.
+
+    Example:
+        ```python
+        query_request = QueryRequest(query="Tell me about Kubernetes")
+        ```
+    """
+
+    query: str
+    conversation_id: Optional[str] = None
+    provider: Optional[str] = None
+    model: Optional[str] = None
+    system_prompt: Optional[str] = None
+    attachments: Optional[list[Attachment]] = None
+
+    # provides examples for /docs endpoint
+    model_config = {
+        "extra": "forbid",
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "query": "write a deployment yaml for the mongodb image",
+                    "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
+                    "provider": "openai",
+                    "model": "model-name",
+                    "system_prompt": "You are a helpful assistant",
+                    "attachments": [
+                        {
+                            "attachment_type": "log",
+                            "content_type": "text/plain",
+                            "content": "this is attachment",
+                        },
+                        {
+                            "attachment_type": "configuration",
+                            "content_type": "application/yaml",
+                            "content": "kind: Pod\n metadata:\n name: private-reg",
+                        },
+                        {
+                            "attachment_type": "configuration",
+                            "content_type": "application/yaml",
+                            "content": "foo: bar",
+                        },
+                    ],
+                }
+            ]
+        },
+    }
+
+    def get_documents(self) -> list[Document]:
+        """Returns the list of documents from the attachments."""
+        if not self.attachments:
+            return []
+        return [
+            Document(content=att.content, mime_type=att.content_type)
+            for att in self.attachments
+        ]
+
+    @model_validator(mode="after")
+    def validate_provider_and_model(self) -> Self:
+        """Perform validation on the provider and model."""
+        if self.model and not self.provider:
+            raise ValueError("Provider must be specified if model is specified")
+        if self.provider and not self.model:
+            raise ValueError("Model must be specified if provider is specified")
+        return self
diff --git a/src/models/responses.py b/src/models/responses.py
index 5aad89ce9..b13982690 100644
--- a/src/models/responses.py
+++ b/src/models/responses.py
@@ -1,5 +1,5 @@
 from pydantic import BaseModel
-from typing import Any
+from typing import Any, Optional
 
 
 class ModelsResponse(BaseModel):
@@ -8,12 +8,40 @@ class ModelsResponse(BaseModel):
     models: list[dict[str, Any]]
 
 
+# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
+# we are keeping it simple. The missing fields are:
+# - referenced_documents: The optional URLs and titles for the documents used
+#   to generate the response.
+# - truncated: Set to True if conversation history was truncated to be within context window.
+# - input_tokens: Number of tokens sent to LLM
+# - output_tokens: Number of tokens received from LLM
+# - available_quotas: Quota available as measured by all configured quota limiters
+# - tool_calls: List of tool requests.
+# - tool_results: List of tool results.
+# See LLMResponse in ols-service for more details.
 class QueryResponse(BaseModel):
-    """Model representing LLM response to a query."""
+    """Model representing LLM response to a query.
 
-    query: str
+    Attributes:
+        conversation_id: The optional conversation ID (UUID).
+        response: The response.
+    """
+
+    conversation_id: Optional[str] = None
     response: str
 
+    # provides examples for /docs endpoint
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
+                    "response": "Operator Lifecycle Manager (OLM) helps users install...",
+                }
+            ]
+        }
+    }
+
 
 class InfoResponse(BaseModel):
     """Model representing a response to a info request.
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
new file mode 100644
index 000000000..db9698e7d
--- /dev/null
+++ b/tests/unit/app/endpoints/test_query.py
@@ -0,0 +1,174 @@
+from fastapi import HTTPException, status
+import pytest
+
+from app.endpoints.query import (
+    query_endpoint_handler,
+    select_model_id,
+    retrieve_response,
+    validate_attachments_metadata,
+)
+from models.requests import QueryRequest, Attachment
+from llama_stack_client.types import UserMessage  # type: ignore
+
+
+def test_query_endpoint_handler(mocker):
+    """Test the query endpoint handler."""
+    mock_client = mocker.Mock()
+    mock_client.models.list.return_value = [
+        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
+        mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"),
+    ]
+
+    mocker.patch(
+        "app.endpoints.query.configuration",
+        return_value=mocker.Mock(),
+    )
+    mocker.patch("app.endpoints.query.get_llama_stack_client", return_value=mock_client)
+    mocker.patch("app.endpoints.query.retrieve_response", return_value="LLM answer")
+    mocker.patch("app.endpoints.query.select_model_id", return_value="fake_model_id")
+
+    query_request = QueryRequest(query="What is OpenStack?")
+
+    response = query_endpoint_handler(None, query_request)
+
+    assert response.response == "LLM answer"
+
+
+def test_select_model_id(mocker):
+    """Test the select_model_id function."""
+    mock_client = mocker.Mock()
+    mock_client.models.list.return_value = [
+        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
+        mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"),
+    ]
+
+    query_request = QueryRequest(
+        query="What is OpenStack?", model="model1", provider="provider1"
+    )
+
+    model_id = select_model_id(mock_client, query_request)
+
+    assert model_id == "model1"
+
+
+def test_select_model_id_no_model(mocker):
+    """Test the select_model_id function when no model is specified."""
+    mock_client = mocker.Mock()
+    mock_client.models.list.return_value = [
+        mocker.Mock(
+            identifier="not_llm_type", model_type="embedding", provider_id="provider1"
+        ),
+        mocker.Mock(
+            identifier="first_model", model_type="llm", provider_id="provider1"
+        ),
+        mocker.Mock(
+            identifier="second_model", model_type="llm", provider_id="provider2"
+        ),
+    ]
+
+    query_request = QueryRequest(query="What is OpenStack?")
+
+    model_id = select_model_id(mock_client, query_request)
+
+    # Assert that the first available LLM model is returned
+    assert model_id == "first_model"
+
+
+def test_select_model_id_invalid_model(mocker):
+    """Test the select_model_id function with an invalid model."""
+    mock_client = mocker.Mock()
+    mock_client.models.list.return_value = [
+        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
+    ]
+
+    query_request = QueryRequest(
+        query="What is OpenStack?", model="invalid_model", provider="provider1"
+    )
+
+    with pytest.raises(Exception) as exc_info:
+        select_model_id(mock_client, query_request)
+
+    assert (
+        "Model invalid_model from provider provider1 not found in available models"
+        in str(exc_info.value)
+    )
+
+
+def test_validate_attachments_metadata():
+    """Test the validate_attachments_metadata function."""
+    attachments = [
+        Attachment(
+            attachment_type="log",
+            content_type="text/plain",
+            content="this is attachment",
+        ),
+        Attachment(
+            attachment_type="configuration",
+            content_type="application/yaml",
+            content="kind: Pod\n metadata:\n name: private-reg",
+        ),
+    ]
+
+    # If no exception is raised, the test passes
+    validate_attachments_metadata(attachments)
+
+
+def test_validate_attachments_metadata_invalid_type():
+    """Test the validate_attachments_metadata function with an invalid attachment type."""
+    attachments = [
+        Attachment(
+            attachment_type="invalid_type",
+            content_type="text/plain",
+            content="this is attachment",
+        ),
+    ]
+
+    with pytest.raises(HTTPException) as exc_info:
+        validate_attachments_metadata(attachments)
+    assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
+    assert (
+        "Attachment with improper type invalid_type detected"
+        in exc_info.value.detail["cause"]
+    )
+
+
+def test_validate_attachments_metadata_invalid_content_type():
+    """Test the validate_attachments_metadata function with an invalid attachment content type."""
+    attachments = [
+        Attachment(
+            attachment_type="log",
+            content_type="text/invalid_content_type",
+            content="this is attachment",
+        ),
+    ]
+
+    with pytest.raises(HTTPException) as exc_info:
+        validate_attachments_metadata(attachments)
+    assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
+    assert (
+        "Attachment with improper content type text/invalid_content_type detected"
+        in exc_info.value.detail["cause"]
+    )
+
+
+def test_retrieve_response(mocker):
+    """Test the retrieve_response function."""
+    mock_agent = mocker.Mock()
+    mock_agent.create_turn.return_value.output_message.content = "LLM answer"
+    mock_client = mocker.Mock()
+    mock_client.shields.list.return_value = []
+
+    mocker.patch("app.endpoints.query.Agent", return_value=mock_agent)
+
+    query_request = QueryRequest(query="What is OpenStack?")
+    model_id = "fake_model_id"
+
+    response = retrieve_response(mock_client, model_id, query_request)
+
+    assert response == "LLM answer"
+    mock_agent.create_turn.assert_called_once_with(
+        messages=[UserMessage(content="What is OpenStack?", role="user", context=None)],
+        session_id=mocker.ANY,
+        documents=[],
+        stream=False,
+    )
diff --git a/tests/unit/models/test_requests.py b/tests/unit/models/test_requests.py
new file mode 100644
index 000000000..a9db072f1
--- /dev/null
+++ b/tests/unit/models/test_requests.py
@@ -0,0 +1,125 @@
+import pytest
+
+from models.requests import QueryRequest, Attachment
+
+
+class TestAttachment:
+    """Test cases for the Attachment model."""
+
+    def test_constructor(self) -> None:
+        """Test the Attachment with custom values."""
+        a = Attachment(
+            attachment_type="configuration",
+            content_type="application/yaml",
+            content="kind: Pod\n metadata:\n name: private-reg",
+        )
+        assert a.attachment_type == "configuration"
+        assert a.content_type == "application/yaml"
+        assert a.content == "kind: Pod\n metadata:\n name: private-reg"
+
+
+class TestQueryRequest:
+    """Test cases for the QueryRequest model."""
+
+    def test_constructor(self) -> None:
+        """Test the QueryRequest constructor."""
+        qr = QueryRequest(query="Tell me about Kubernetes")
+
+        assert qr.query == "Tell me about Kubernetes"
+        assert qr.conversation_id is None
+        assert qr.provider is None
+        assert qr.model is None
+        assert qr.system_prompt is None
+        assert qr.attachments is None
+
+    def test_with_attachments(self) -> None:
+        """Test the QueryRequest with attachments."""
+        attachments = [
+            Attachment(
+                attachment_type="log",
+                content_type="text/plain",
+                content="this is attachment",
+            ),
+            Attachment(
+                attachment_type="configuration",
+                content_type="application/yaml",
+                content="kind: Pod\n metadata:\n name: private-reg",
+            ),
+        ]
+        qr = QueryRequest(
+            query="Tell me about Kubernetes",
+            attachments=attachments,
+        )
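+        # Attachments should be stored on the model unchanged.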
+        assert len(qr.attachments) == 2
+        assert qr.attachments[0].attachment_type == "log"
+        assert qr.attachments[0].content_type == "text/plain"
+        assert qr.attachments[0].content == "this is attachment"
+        assert qr.attachments[1].attachment_type == "configuration"
+        assert qr.attachments[1].content_type == "application/yaml"
+        assert (
+            qr.attachments[1].content == "kind: Pod\n metadata:\n name: private-reg"
+        )
+
+    def test_with_optional_fields(self) -> None:
+        """Test the QueryRequest with optional fields."""
+        qr = QueryRequest(
+            query="Tell me about Kubernetes",
+            conversation_id="123e4567-e89b-12d3-a456-426614174000",
+            provider="OpenAI",
+            model="gpt-3.5-turbo",
+            system_prompt="You are a helpful assistant",
+        )
+        assert qr.query == "Tell me about Kubernetes"
+        assert qr.conversation_id == "123e4567-e89b-12d3-a456-426614174000"
+        assert qr.provider == "OpenAI"
+        assert qr.model == "gpt-3.5-turbo"
+        assert qr.system_prompt == "You are a helpful assistant"
+        assert qr.attachments is None
+
+    def test_get_documents(self) -> None:
+        """Test the get_documents method."""
+        attachments = [
+            Attachment(
+                attachment_type="log",
+                content_type="text/plain",
+                content="this is attachment",
+            ),
+            Attachment(
+                attachment_type="configuration",
+                content_type="application/yaml",
+                content="kind: Pod\n metadata:\n name: private-reg",
+            ),
+        ]
+        qr = QueryRequest(
+            query="Tell me about Kubernetes",
+            attachments=attachments,
+        )
+        documents = qr.get_documents()
+        assert len(documents) == 2
+        assert documents[0]["content"] == "this is attachment"
+        assert documents[0]["mime_type"] == "text/plain"
+        assert documents[1]["content"] == "kind: Pod\n metadata:\n name: private-reg"
+        assert documents[1]["mime_type"] == "application/yaml"
+
+    def test_validate_provider_and_model(self) -> None:
+        """Test the validate_provider_and_model method."""
+        qr = QueryRequest(
+            query="Tell me about Kubernetes",
+            provider="OpenAI",
+            model="gpt-3.5-turbo",
+        )
+        validated_qr = qr.validate_provider_and_model()
+        assert validated_qr.provider == "OpenAI"
+        assert validated_qr.model == "gpt-3.5-turbo"
+
+        # Test with missing provider
+        with pytest.raises(
+            ValueError, match="Provider must be specified if model is specified"
+        ):
+            QueryRequest(query="Tell me about Kubernetes", model="gpt-3.5-turbo")
+
+        # Test with missing model
+        with pytest.raises(
+            ValueError, match="Model must be specified if provider is specified"
+        ):
+            QueryRequest(query="Tell me about Kubernetes", provider="OpenAI")
diff --git a/tests/unit/models/test_responses.py b/tests/unit/models/test_responses.py
new file mode 100644
index 000000000..b9db5f29a
--- /dev/null
+++ b/tests/unit/models/test_responses.py
@@ -0,0 +1,20 @@
+from models.responses import QueryResponse
+
+
+class TestQueryResponse:
+    """Test cases for the QueryResponse model."""
+
+    def test_constructor(self) -> None:
+        """Test the QueryResponse constructor."""
+        qr = QueryResponse(
+            conversation_id="123e4567-e89b-12d3-a456-426614174000",
+            response="LLM answer",
+        )
+        assert qr.conversation_id == "123e4567-e89b-12d3-a456-426614174000"
+        assert qr.response == "LLM answer"
+
+    def test_optional_conversation_id(self) -> None:
+        """Test the QueryResponse with default conversation ID."""
+        qr = QueryResponse(response="LLM answer")
+        assert qr.conversation_id is None
+        assert qr.response == "LLM answer"