Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clients/python/llmengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.0.0b28"
__version__ = "0.0.0b29"

import os
from typing import Sequence
Expand Down
45 changes: 30 additions & 15 deletions clients/python/llmengine/api_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,23 @@
from llmengine.errors import parse_error

SPELLBOOK_API_URL = "https://api.spellbook.scale.com/llm-engine/"
LLM_ENGINE_BASE_PATH = os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL)
DEFAULT_TIMEOUT: int = 10

base_path = None
api_key = None


def set_base_path(path):
global base_path
base_path = path


def get_base_path() -> str:
if base_path is not None:
return base_path
return os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: since we're changing things with this var, should we rename SPELLBOOK as well?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to keep the Spellbook URL up anyway, and the new APIs also don't support the LLM Engine mirrors that we need, so I think keeping this as Spellbook is still accurate? :)



def set_api_key(key):
global api_key
api_key = key
Expand All @@ -33,7 +44,7 @@ def get_api_key() -> str:
def assert_self_hosted(func):
@wraps(func)
def inner(*args, **kwargs):
if SPELLBOOK_API_URL == LLM_ENGINE_BASE_PATH:
if SPELLBOOK_API_URL == get_base_path():
raise ValueError("This feature is only available for self-hosted users.")
return func(*args, **kwargs)

Expand All @@ -43,16 +54,17 @@ def inner(*args, **kwargs):
class APIEngine:
@classmethod
def validate_api_key(cls):
if SPELLBOOK_API_URL == LLM_ENGINE_BASE_PATH and not get_api_key():
if SPELLBOOK_API_URL == get_base_path() and not get_api_key():
raise ValueError(
"You must set SCALE_API_KEY in your environment to to use the LLM Engine API."
)

@classmethod
def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]:
base_path = get_base_path()
api_key = get_api_key()
response = requests.get(
urljoin(LLM_ENGINE_BASE_PATH, resource_name),
urljoin(base_path, resource_name),
timeout=timeout,
headers={"x-api-key": api_key},
auth=(api_key, ""),
Expand All @@ -66,9 +78,10 @@ def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]:
def put(
cls, resource_name: str, data: Optional[Dict[str, Any]], timeout: int
) -> Dict[str, Any]:
base_path = get_base_path()
api_key = get_api_key()
response = requests.put(
urljoin(LLM_ENGINE_BASE_PATH, resource_name),
urljoin(base_path, resource_name),
json=data,
timeout=timeout,
headers={"x-api-key": api_key},
Expand All @@ -81,9 +94,10 @@ def put(

@classmethod
def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]:
base_path = get_base_path()
api_key = get_api_key()
response = requests.delete(
urljoin(LLM_ENGINE_BASE_PATH, resource_name),
urljoin(base_path, resource_name),
timeout=timeout,
headers={"x-api-key": api_key},
auth=(api_key, ""),
Expand All @@ -95,9 +109,10 @@ def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]:

@classmethod
def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
base_path = get_base_path()
api_key = get_api_key()
response = requests.post(
urljoin(LLM_ENGINE_BASE_PATH, resource_name),
urljoin(base_path, resource_name),
json=data,
timeout=timeout,
headers={"x-api-key": api_key},
Expand All @@ -112,9 +127,10 @@ def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Di
def post_stream(
cls, resource_name: str, data: Dict[str, Any], timeout: int
) -> Iterator[Dict[str, Any]]:
base_path = get_base_path()
api_key = get_api_key()
response = requests.post(
urljoin(LLM_ENGINE_BASE_PATH, resource_name),
urljoin(base_path, resource_name),
json=data,
timeout=timeout,
headers={"x-api-key": api_key},
Expand Down Expand Up @@ -144,9 +160,10 @@ def post_stream(
def post_file(
cls, resource_name: str, files: Dict[str, BufferedReader], timeout: int
) -> Dict[str, Any]:
base_path = get_base_path()
api_key = get_api_key()
response = requests.post(
urljoin(LLM_ENGINE_BASE_PATH, resource_name),
urljoin(base_path, resource_name),
files=files,
timeout=timeout,
headers={"x-api-key": api_key},
Expand All @@ -161,15 +178,14 @@ def post_file(
async def apost_sync(
cls, resource_name: str, data: Dict[str, Any], timeout: int
) -> Dict[str, Any]:
base_path = get_base_path()
api_key = get_api_key()
async with ClientSession(
timeout=ClientTimeout(timeout),
headers={"x-api-key": api_key},
auth=BasicAuth(api_key, ""),
) as session:
async with session.post(
urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data
) as resp:
async with session.post(urljoin(base_path, resource_name), json=data) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.read())
payload = await resp.json()
Expand All @@ -179,15 +195,14 @@ async def apost_sync(
async def apost_stream(
cls, resource_name: str, data: Dict[str, Any], timeout: int
) -> AsyncIterable[Dict[str, Any]]:
base_path = get_base_path()
api_key = get_api_key()
async with ClientSession(
timeout=ClientTimeout(timeout),
headers={"x-api-key": api_key},
auth=BasicAuth(api_key, ""),
) as session:
async with session.post(
urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data
) as resp:
async with session.post(urljoin(base_path, resource_name), json=data) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.read())
async for byte_payload in resp.content:
Expand Down
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "scale-llm-engine"
version = "0.0.0.beta28"
version = "0.0.0.beta29"
description = "Scale LLM Engine Python client"
license = "Apache-2.0"
authors = ["Phil Chen <phil.chen@scale.com>"]
Expand Down
2 changes: 1 addition & 1 deletion clients/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
setup(
name="scale-llm-engine",
python_requires=">=3.7",
version="0.0.0.beta28",
version="0.0.0.beta29",
packages=find_packages(),
)