diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index dfae78cf7..09dd8526c 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b28" +__version__ = "0.0.0b29" import os from typing import Sequence diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index 3abf86d65..a1b955be0 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -12,12 +12,23 @@ from llmengine.errors import parse_error SPELLBOOK_API_URL = "https://api.spellbook.scale.com/llm-engine/" -LLM_ENGINE_BASE_PATH = os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL) DEFAULT_TIMEOUT: int = 10 +base_path = None api_key = None +def set_base_path(path): + global base_path + base_path = path + + +def get_base_path() -> str: + if base_path is not None: + return base_path + return os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL) + + def set_api_key(key): global api_key api_key = key @@ -33,7 +44,7 @@ def get_api_key() -> str: def assert_self_hosted(func): @wraps(func) def inner(*args, **kwargs): - if SPELLBOOK_API_URL == LLM_ENGINE_BASE_PATH: + if SPELLBOOK_API_URL == get_base_path(): raise ValueError("This feature is only available for self-hosted users.") return func(*args, **kwargs) @@ -43,16 +54,17 @@ def inner(*args, **kwargs): class APIEngine: @classmethod def validate_api_key(cls): - if SPELLBOOK_API_URL == LLM_ENGINE_BASE_PATH and not get_api_key(): + if SPELLBOOK_API_URL == get_base_path() and not get_api_key(): raise ValueError( "You must set SCALE_API_KEY in your environment to to use the LLM Engine API." ) @classmethod def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.get( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), timeout=timeout, headers={"x-api-key": api_key}, auth=(api_key, ""), @@ -66,9 +78,10 @@ def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]: def put( cls, resource_name: str, data: Optional[Dict[str, Any]], timeout: int ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.put( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), json=data, timeout=timeout, headers={"x-api-key": api_key}, @@ -81,9 +94,10 @@ def put( @classmethod def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.delete( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), timeout=timeout, headers={"x-api-key": api_key}, auth=(api_key, ""), @@ -95,9 +109,10 @@ def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: @classmethod def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.post( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), json=data, timeout=timeout, headers={"x-api-key": api_key}, @@ -112,9 +127,10 @@ def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Di def post_stream( cls, resource_name: str, data: Dict[str, Any], timeout: int ) -> Iterator[Dict[str, Any]]: + base_path = get_base_path() api_key = get_api_key() response = requests.post( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), json=data, timeout=timeout, headers={"x-api-key": api_key}, @@ -144,9 +160,10 @@ def post_stream( def post_file( cls, resource_name: str, files: Dict[str, BufferedReader], timeout: int ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.post( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), files=files, timeout=timeout, headers={"x-api-key": api_key}, @@ -161,15 +178,14 @@ def post_file( async def apost_sync( cls, resource_name: str, data: Dict[str, Any], timeout: int ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() async with ClientSession( timeout=ClientTimeout(timeout), headers={"x-api-key": api_key}, auth=BasicAuth(api_key, ""), ) as session: - async with session.post( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data - ) as resp: + async with session.post(urljoin(base_path, resource_name), json=data) as resp: if resp.status != 200: raise parse_error(resp.status, await resp.read()) payload = await resp.json() @@ -179,15 +195,14 @@ async def apost_sync( async def apost_stream( cls, resource_name: str, data: Dict[str, Any], timeout: int ) -> AsyncIterable[Dict[str, Any]]: + base_path = get_base_path() api_key = get_api_key() async with ClientSession( timeout=ClientTimeout(timeout), headers={"x-api-key": api_key}, auth=BasicAuth(api_key, ""), ) as session: - async with session.post( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data - ) as resp: + async with session.post(urljoin(base_path, resource_name), json=data) as resp: if resp.status != 200: raise parse_error(resp.status, await resp.read()) async for byte_payload in resp.content: diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 8ddec08fb..977196099 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta28" +version = "0.0.0.beta29" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index a33e6a03f..5da0008ab 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta28", + version="0.0.0.beta29", packages=find_packages(), )