Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/examples/aLora/101_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# pytest: skip, huggingface, requires_heavy_ram, llm
# SKIP REASON: Example broken since intrinsics refactor - see issue #385
# pytest: huggingface, requires_heavy_ram, llm

import time

Expand Down
3 changes: 3 additions & 0 deletions docs/examples/aLora/102_example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# pytest: skip, huggingface, requires_heavy_ram, llm
# SKIP REASON: Requires user input; tests same functionality as 101_example.py.

from stembolts_intrinsic import (
async_stembolt_failure_analysis,
stembolt_failure_analysis,
Expand Down
1 change: 1 addition & 0 deletions docs/examples/intrinsics/query_clarification.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pytest: huggingface, requires_heavy_ram, llm
"""
Example usage of the query clarification intrinsic for RAG applications.

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/mini_researcher/researcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pytest: ollama, qualitative, llm
# pytest: ollama, qualitative, llm, slow

from collections.abc import Callable
from functools import cache
Expand Down
3 changes: 2 additions & 1 deletion mellea/backends/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,8 @@ async def post_processing(
mot._meta["hf_output"] = full_output

# The ModelOutputThunk must be computed by this point.
assert mot.value is not None
if mot.value is None:
return

# Store KV cache in LRU separately (not in mot._meta) to enable proper cleanup on eviction.
# This prevents GPU memory from being held by ModelOutputThunk references.
Expand Down
9 changes: 6 additions & 3 deletions mellea/backends/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,11 @@ async def post_processing(
# OpenAI-like streamed responses potentially give you chunks of tool calls.
# As a result, we have to store data between calls and only then
# check for complete tool calls in the post_processing step.
tool_chunk = extract_model_tool_requests(
tools, mot._meta["litellm_chat_response"]
litellm_response = mot._meta.get("litellm_chat_response")
tool_chunk = (
extract_model_tool_requests(tools, litellm_response)
if litellm_response is not None
else None
)
if tool_chunk is not None:
if mot.tool_calls is None:
Expand All @@ -457,7 +460,7 @@ async def post_processing(
generate_log.backend = f"litellm::{self.model_id!s}"
generate_log.model_options = mot._model_options
generate_log.date = datetime.datetime.now()
generate_log.model_output = mot._meta["litellm_chat_response"]
generate_log.model_output = mot._meta.get("litellm_chat_response")
generate_log.extra = {
"format": _format,
"tools_available": tools,
Expand Down
6 changes: 5 additions & 1 deletion mellea/backends/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ async def generate_from_raw(
result = None
error = None
if isinstance(response, BaseException):
FancyLogger.get_logger().warning(
f"generate_from_raw: request {i} failed with "
f"{type(response).__name__}: {response}"
)
result = ModelOutputThunk(value="")
error = response
else:
Expand Down Expand Up @@ -596,7 +600,7 @@ async def post_processing(
generate_log.backend = f"ollama::{self._get_ollama_model_id()}"
generate_log.model_options = mot._model_options
generate_log.date = datetime.datetime.now()
generate_log.model_output = mot._meta["chat_response"]
generate_log.model_output = mot._meta.get("chat_response")
generate_log.extra = {
"format": _format,
"thinking": mot._model_options.get(ModelOption.THINKING, None),
Expand Down
23 changes: 14 additions & 9 deletions mellea/backends/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,9 +575,13 @@ async def post_processing(
# check for complete tool calls in the post_processing step.
# Use the choice format for tool extraction (backward compatibility)
choice_response = mot._meta.get(
"oai_chat_response_choice", mot._meta["oai_chat_response"]
"oai_chat_response_choice", mot._meta.get("oai_chat_response")
)
tool_chunk = (
extract_model_tool_requests(tools, choice_response)
if choice_response is not None
else None
)
tool_chunk = extract_model_tool_requests(tools, choice_response)
if tool_chunk is not None:
if mot.tool_calls is None:
mot.tool_calls = {}
Expand All @@ -592,7 +596,7 @@ async def post_processing(
generate_log.model_options = mot._model_options
generate_log.date = datetime.datetime.now()
# Store the full response (includes usage info)
generate_log.model_output = mot._meta["oai_chat_response"]
generate_log.model_output = mot._meta.get("oai_chat_response")
generate_log.extra = {
"format": _format,
"thinking": thinking,
Expand All @@ -613,12 +617,13 @@ async def post_processing(
record_token_usage,
)

response = mot._meta["oai_chat_response"]
# response is a dict from model_dump(), extract usage if present
usage = response.get("usage") if isinstance(response, dict) else None
if usage:
record_token_usage(span, usage)
record_response_metadata(span, response)
response = mot._meta.get("oai_chat_response")
if response is not None:
# response is a dict from model_dump(), extract usage if present
usage = response.get("usage") if isinstance(response, dict) else None
if usage:
record_token_usage(span, usage)
record_response_metadata(span, response)
# Close the span now that async operation is complete
end_backend_span(span)
# Clean up the span reference
Expand Down
3 changes: 2 additions & 1 deletion mellea/backends/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,8 @@ async def post_processing(
):
"""Called when generation is done."""
# The ModelOutputThunk must be computed by this point.
assert mot.value is not None
if mot.value is None:
return

# Only scan for tools if we are not doing structured output and tool calls were provided to the model.
if _format is None and tool_calls:
Expand Down
9 changes: 7 additions & 2 deletions mellea/backends/watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,12 @@ async def post_processing(
# OpenAI streamed responses give you chunks of tool calls.
# As a result, we have to store data between calls and only then
# check for complete tool calls in the post_processing step.
tool_chunk = extract_model_tool_requests(tools, mot._meta["oai_chat_response"])
oai_response = mot._meta.get("oai_chat_response")
tool_chunk = (
extract_model_tool_requests(tools, oai_response)
if oai_response is not None
else None
)
if tool_chunk is not None:
if mot.tool_calls is None:
mot.tool_calls = {}
Expand Down Expand Up @@ -509,7 +514,7 @@ async def post_processing(
generate_log.backend = f"watsonx::{self.model_id!s}"
generate_log.model_options = mot._model_options
generate_log.date = datetime.datetime.now()
generate_log.model_output = mot._meta["oai_chat_response"]
generate_log.model_output = mot._meta.get("oai_chat_response")
generate_log.extra = {
"format": _format,
"tools_available": tools,
Expand Down
8 changes: 6 additions & 2 deletions mellea/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,14 @@ async def astream(self) -> str:
elif isinstance(chunks[-1], Exception):
# Mark as computed so post_process runs in finally block
self._computed = True
# Store exception to re-raise after cleanup
exception_to_raise = chunks[-1]
# Remove the exception from chunks so _process doesn't receive it
exception_to_raise = chunks.pop()

for chunk in chunks:
# Belt-and-suspenders: skip non-chunk objects that should
# have been removed above (exceptions, sentinel None).
if chunk is None or isinstance(chunk, Exception):
continue
assert self._process is not None
await self._process(self, chunk)

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ dev = [
"python-semantic-release~=7.32",
"nbmake>=1.5.5",
"langchain-core>=1.2.7", # Necessary for mypy and some tool tests
"sentencepiece==0.2.1", # Necessary for test_huggingface_tools test because of Mistral model
]

notebook = [
Expand Down Expand Up @@ -243,7 +244,7 @@ markers = [
"requires_gpu: Tests requiring GPU",
"requires_heavy_ram: Tests requiring 48GB+ RAM",
"qualitative: Non-deterministic quality tests",
"slow: Tests taking >5 minutes (e.g., dataset loading)",
"slow: Tests taking >1 minute (e.g., multi-step pipelines like researcher)",

# Composite markers
"llm: Tests that make LLM calls (needs at least Ollama)",
Expand All @@ -255,7 +256,6 @@ addopts = [
# Run qualitative tests by default (use -m "not qualitative" for fast tests)
"--cov=mellea",
"--cov=cli",
"--cov-report=term",
"--cov-report=html",
"--cov-report=json",
# Set timeout to 15 minutes for full test suite
Expand Down
3 changes: 0 additions & 3 deletions test/backends/test_litellm_watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ async def test_generate_from_raw(session) -> None:


@pytest.mark.qualitative
@pytest.mark.xfail(
reason="litellm has a bug with watsonx; once that is fixed, this should pass."
)
async def test_multiple_async_funcs(session) -> None:
"""If this test passes, remove the _has_potential_event_loop_errors func from litellm."""
session.chat(
Expand Down
30 changes: 19 additions & 11 deletions test/backends/test_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class Email(pydantic.BaseModel):
output = session.instruct(
"Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ",
format=Email,
model_options={ModelOption.MAX_NEW_TOKENS: 2**8},
model_options={ModelOption.MAX_NEW_TOKENS: 2**10},
)
print("Formatted output:")
email = Email.model_validate_json(
Expand All @@ -102,18 +102,22 @@ class Email(pydantic.BaseModel):


@pytest.mark.qualitative
@pytest.mark.timeout(150)
async def test_generate_from_raw(session) -> None:
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

results = await session.backend.generate_from_raw(
actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
actions=[CBlock(value=prompt) for prompt in prompts],
ctx=session.ctx,
model_options={ModelOption.CONTEXT_WINDOW: 2048},
)

assert len(results) == len(prompts)
assert results[0].value is not None
assert all(r.value for r in results), (
f"One or more requests returned empty (possible backend timeout): {[r.value for r in results]}"
)


@pytest.mark.xfail(reason="ollama sometimes fails generated structured outputs")
async def test_generate_from_raw_with_format(session) -> None:
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]

Expand All @@ -125,17 +129,21 @@ class Answer(pydantic.BaseModel):
actions=[CBlock(value=prompt) for prompt in prompts],
ctx=session.ctx,
format=Answer,
model_options={ModelOption.CONTEXT_WINDOW: 2048},
)

assert len(results) == len(prompts)
assert all(r.value for r in results), (
f"One or more requests returned empty (possible backend timeout): {[r.value for r in results]}"
)

random_result = results[0]
try:
Answer.model_validate_json(random_result.value)
except pydantic.ValidationError as e:
assert False, (
f"formatting directive failed for {random_result.value}: {e.json()}"
)
for result in results:
try:
Answer.model_validate_json(result.value)
except pydantic.ValidationError as e:
assert False, (
f"formatting directive failed for {result.value}: {e.json()}"
)


async def test_async_parallel_requests(session) -> None:
Expand Down
2 changes: 1 addition & 1 deletion test/backends/test_openai_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class Email(pydantic.BaseModel):
output = m_session.instruct(
"Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ",
format=Email,
model_options={ModelOption.MAX_NEW_TOKENS: 2**8},
model_options={ModelOption.MAX_NEW_TOKENS: 2**10},
)
print("Formatted output:")
email = Email.model_validate_json(
Expand Down
21 changes: 14 additions & 7 deletions test/backends/test_openai_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,24 @@ def vllm_process():

yield process

except Exception as e:
pytest.skip(
f"vLLM process not available: {e}. May need to install with: pip install mellea[vllm]",
allow_module_level=True,
)

# --- Teardown (always runs) ---
finally:
try:
os.killpg(process.pid, signal.SIGTERM) # kill the session group
process.wait(timeout=30)
except Exception:
if process is not None:
try:
os.killpg(process.pid, signal.SIGKILL)
os.killpg(process.pid, signal.SIGTERM) # kill the session group
process.wait(timeout=30)
except Exception:
pass
process.wait()
try:
os.killpg(process.pid, signal.SIGKILL)
except Exception:
pass
process.wait()


@pytest.fixture(scope="module")
Expand Down
Loading
Loading