diff --git a/README.md b/README.md index a3ab108c6b..9e282b3e10 100644 --- a/README.md +++ b/README.md @@ -176,9 +176,10 @@ Also see [architecture](docs/ARCHITECTURE.md). | Provider | Status | Provider | Status | |----------|--------|----------|--------| | OpenAI | ✅ | Azure OpenAI | ✅ | -| Anthropic Claude | ✅ | Google Gemini | ✅ | -| AWS Bedrock | ✅ | Mistral AI | ✅ | -| Ollama (local) | ✅ | Anyscale | ✅ | +| OpenAI Compatible | ✅ | Anthropic Claude | ✅ | +| AWS Bedrock | ✅ | Google Gemini | ✅ | +| Ollama (local) | ✅ | Mistral AI | ✅ | +| Anyscale | ✅ | | | ### Vector Databases diff --git a/frontend/public/icons/adapter-icons/OpenAICompatible.png b/frontend/public/icons/adapter-icons/OpenAICompatible.png new file mode 100644 index 0000000000..ec23189d9a Binary files /dev/null and b/frontend/public/icons/adapter-icons/OpenAICompatible.png differ diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/base1.py b/unstract/sdk1/src/unstract/sdk1/adapters/base1.py index 666aa50c7b..7bc58296c8 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/base1.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/base1.py @@ -346,6 +346,32 @@ def validate_model(adapter_metadata: dict[str, "Any"]) -> str: return f"openai/{model}" +class OpenAICompatibleLLMParameters(BaseChatCompletionParameters): + """See https://docs.litellm.ai/docs/providers/openai_compatible/.""" + + api_key: str | None = None + api_base: str + + @staticmethod + def validate(adapter_metadata: dict[str, "Any"]) -> dict[str, "Any"]: + adapter_metadata["model"] = OpenAICompatibleLLMParameters.validate_model( + adapter_metadata + ) + api_key = adapter_metadata.get("api_key") + if isinstance(api_key, str) and not api_key.strip(): + adapter_metadata["api_key"] = None + return OpenAICompatibleLLMParameters(**adapter_metadata).model_dump() + + @staticmethod + def validate_model(adapter_metadata: dict[str, "Any"]) -> str: + model = str(adapter_metadata.get("model", "")).strip() + if not model: + raise ValueError("model is required for the OpenAI Compatible adapter.") + if model.startswith("custom_openai/"): + return model + return f"custom_openai/{model}" + + class AzureOpenAILLMParameters(BaseChatCompletionParameters): """See https://docs.litellm.ai/docs/providers/azure/#completion---using-azure_ad_token-api_base-api_version.""" diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py index c23a33390a..1da3590f51 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py @@ -8,6 +8,7 @@ from unstract.sdk1.adapters.llm1.bedrock import AWSBedrockLLMAdapter from unstract.sdk1.adapters.llm1.ollama import OllamaLLMAdapter from unstract.sdk1.adapters.llm1.openai import OpenAILLMAdapter +from unstract.sdk1.adapters.llm1.openai_compatible import OpenAICompatibleLLMAdapter from unstract.sdk1.adapters.llm1.vertexai import VertexAILLMAdapter adapters: dict[str, dict[str, Any]] = {} @@ -22,5 +23,6 @@ "AzureOpenAILLMAdapter", "OllamaLLMAdapter", "OpenAILLMAdapter", + "OpenAICompatibleLLMAdapter", "VertexAILLMAdapter", ] diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py new file mode 100644 index 0000000000..3cb3ceafc4 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py @@ -0,0 +1,46 @@ +from typing import Any + +from unstract.sdk1.adapters.base1 import BaseAdapter, OpenAICompatibleLLMParameters +from unstract.sdk1.adapters.enums import AdapterTypes + +DESCRIPTION = ( + "Adapter for servers that implement the OpenAI Chat Completions API " + "(vLLM, LM Studio, self-hosted gateways, and third-party providers). " + "Use OpenAI for the official OpenAI service." +) + + +class OpenAICompatibleLLMAdapter(OpenAICompatibleLLMParameters, BaseAdapter): + @staticmethod + def get_id() -> str: + return "openaicompatible|b6d10f33-2c41-49fc-a8c2-58d2b247fc09" + + @staticmethod + def get_metadata() -> dict[str, Any]: + return { + "name": "OpenAI Compatible", + "version": "1.0.0", + "adapter": OpenAICompatibleLLMAdapter, + "description": DESCRIPTION, + "is_active": True, + } + + @staticmethod + def get_name() -> str: + return "OpenAI Compatible" + + @staticmethod + def get_description() -> str: + return DESCRIPTION + + @staticmethod + def get_provider() -> str: + return "custom_openai" + + @staticmethod + def get_icon() -> str: + return "/icons/adapter-icons/OpenAICompatible.png" + + @staticmethod + def get_adapter_type() -> AdapterTypes: + return AdapterTypes.LLM diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/custom_openai.json b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/custom_openai.json new file mode 100644 index 0000000000..8767ffdcf4 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/custom_openai.json @@ -0,0 +1,60 @@ +{ + "title": "OpenAI Compatible", + "type": "object", + "required": [ + "adapter_name", + "api_base" + ], + "properties": { + "adapter_name": { + "type": "string", + "title": "Name", + "default": "", + "description": "Provide a unique name for this adapter instance. Example: compatible-gateway-1" + }, + "api_key": { + "type": [ + "string", + "null" + ], + "title": "API Key", + "format": "password", + "description": "API key for your OpenAI-compatible endpoint. Leave empty if the endpoint does not require one." + }, + "model": { + "type": "string", + "title": "Model", + "description": "The model name expected by your OpenAI-compatible endpoint. Examples: gateway-model, gpt-4o-mini, openai/gpt-4o" + }, + "api_base": { + "type": "string", + "format": "url", + "title": "API Base", + "description": "Base URL for the OpenAI-compatible endpoint. Examples: https://gateway.example.com/v1, https://llm.example.net/openai/v1" + }, + "max_tokens": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Maximum Output Tokens", + "default": 4096, + "description": "Maximum number of output tokens to limit LLM replies. Leave it empty to use the provider default." + }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 5, + "description": "The maximum number of times to retry a request if it fails." + }, + "timeout": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Timeout", + "default": 900, + "description": "Timeout in seconds." + } + } +} diff --git a/unstract/sdk1/tests/test_openai_compatible_adapter.py b/unstract/sdk1/tests/test_openai_compatible_adapter.py new file mode 100644 index 0000000000..addbf080ce --- /dev/null +++ b/unstract/sdk1/tests/test_openai_compatible_adapter.py @@ -0,0 +1,172 @@ +import json +from functools import lru_cache +from importlib import import_module +from unittest.mock import MagicMock, patch + +from unstract.sdk1.adapters.base1 import OpenAICompatibleLLMParameters +from unstract.sdk1.adapters.constants import Common +from unstract.sdk1.adapters.llm1 import adapters +from unstract.sdk1.adapters.llm1.openai_compatible import OpenAICompatibleLLMAdapter + +OPENAI_COMPATIBLE_DESCRIPTION = ( + "Adapter for servers that implement the OpenAI Chat Completions API " + "(vLLM, LM Studio, self-hosted gateways, and third-party providers). " + "Use OpenAI for the official OpenAI service." +) + + +@lru_cache(maxsize=1) +def _load_llm_module() -> object: + import sys + from types import ModuleType + + # Stub python-magic so importing LLM does not depend on libmagic + # being available in the test environment. sys.modules entries set + # here must persist (no patch.dict) so litellm and other lazy-loaded + # modules stay resolvable across tests. + sys.modules.setdefault("magic", ModuleType("magic")) + return import_module("unstract.sdk1.llm") + + +def _load_llm_class() -> type: + return _load_llm_module().LLM + + +def test_openai_compatible_adapter_is_registered() -> None: + adapter_id = OpenAICompatibleLLMAdapter.get_id() + + assert adapter_id in adapters + assert adapters[adapter_id][Common.MODULE] is OpenAICompatibleLLMAdapter + + +def test_openai_compatible_validate_prefixes_model() -> None: + validated = OpenAICompatibleLLMParameters.validate( + { + "api_base": "https://gateway.example.com/v1", + "api_key": "test-key", + "model": "gateway-model", + } + ) + + assert validated["model"] == "custom_openai/gateway-model" + + +def test_openai_compatible_validate_preserves_prefixed_model() -> None: + validated = OpenAICompatibleLLMParameters.validate( + { + "api_base": "https://gateway.example.com/v1", + "model": "custom_openai/openai/gpt-4o", + } + ) + + assert validated["model"] == "custom_openai/openai/gpt-4o" + assert validated["api_key"] is None + + +def test_openai_compatible_validate_normalizes_blank_api_key_to_none() -> None: + validated = OpenAICompatibleLLMParameters.validate( + { + "api_base": "https://gateway.example.com/v1", + "api_key": " ", + "model": "gateway-model", + } + ) + + assert validated["api_key"] is None + + +def test_openai_compatible_schema_is_loadable() -> None: + schema = json.loads(OpenAICompatibleLLMAdapter.get_json_schema()) + + assert schema["title"] == "OpenAI Compatible" + assert schema["properties"]["api_key"]["type"] == ["string", "null"] + assert "default" not in schema["properties"]["model"] + assert "gateway-model" in schema["properties"]["model"]["description"] + assert "ERNIE" not in schema["properties"]["model"]["description"] + assert "qianfan" not in schema["properties"]["api_base"]["description"].lower() + assert "default" not in schema["properties"]["api_base"] + + +def test_openai_compatible_adapter_uses_distinct_description_and_icon() -> None: + metadata = OpenAICompatibleLLMAdapter.get_metadata() + + assert OpenAICompatibleLLMAdapter.get_description() == OPENAI_COMPATIBLE_DESCRIPTION + assert metadata["description"] == OPENAI_COMPATIBLE_DESCRIPTION + assert OpenAICompatibleLLMAdapter.get_icon() == ( + "/icons/adapter-icons/OpenAICompatible.png" + ) + + +def _build_llm_for_record_usage(llm_cls: type) -> object: + llm = llm_cls.__new__(llm_cls) + llm._platform_api_key = "platform-key" + llm.platform_kwargs = {"run_id": "run-1"} + llm._usage_kwargs = {} + llm._pending_usage = [] + llm.adapter = MagicMock() + llm.adapter.get_provider.return_value = "custom_openai" + return llm + + +def test_record_usage_uses_reported_prompt_tokens_without_estimating() -> None: + llm_module = _load_llm_module() + llm = _build_llm_for_record_usage(llm_module.LLM) + + with ( + patch.object(llm_module.litellm, "token_counter") as mock_token_counter, + patch.object(llm_module.litellm, "cost_per_token", return_value=(0.0, 0.0)), + ): + llm._record_usage( + model="custom_openai/gateway-model", + messages=[{"role": "user", "content": "hello"}], + usage={"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}, + llm_api="complete", + ) + + mock_token_counter.assert_not_called() + assert len(llm._pending_usage) == 1 + assert llm._pending_usage[0]["prompt_tokens"] == 3 + + +def test_record_usage_tolerates_unmapped_models_without_prompt_tokens() -> None: + llm_module = _load_llm_module() + llm = _build_llm_for_record_usage(llm_module.LLM) + + with ( + patch.object( + llm_module.litellm, "token_counter", side_effect=Exception("unmapped") + ), + patch.object(llm_module.litellm, "cost_per_token", return_value=(0.0, 0.0)), + patch.object(llm_module.logger, "warning") as mock_warning, + ): + llm._record_usage( + model="custom_openai/gateway-model", + messages=[{"role": "user", "content": "hello"}], + usage={"completion_tokens": 4, "total_tokens": 7}, + llm_api="complete", + ) + + assert len(llm._pending_usage) == 1 + assert llm._pending_usage[0]["prompt_tokens"] == 0 + assert "litellm.token_counter() fallback failed" in mock_warning.call_args.args[0] + + +def test_record_usage_uses_estimated_prompt_tokens_when_usage_has_none() -> None: + llm_module = _load_llm_module() + llm = _build_llm_for_record_usage(llm_module.LLM) + + with ( + patch.object( + llm_module.litellm, "token_counter", return_value=9 + ) as mock_token_counter, + patch.object(llm_module.litellm, "cost_per_token", return_value=(0.0, 0.0)), + ): + llm._record_usage( + model="custom_openai/gateway-model", + messages=[{"role": "user", "content": "hello"}], + usage={"completion_tokens": 4, "total_tokens": 13}, + llm_api="complete", + ) + + mock_token_counter.assert_called_once() + assert llm._pending_usage[0]["prompt_tokens"] == 9