Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 11 additions & 62 deletions mellea/backends/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
)
from ..core.base import AbstractMelleaTool
from ..formatters import ChatFormatter, TemplateFormatter, granite as granite_formatters
from ..formatters.granite.base.util import _GuidanceLogitsProcessor
from ..helpers import message_to_openai_message, messages_to_docs, send_to_queue
from ..stdlib.components import Intrinsic, Message
from ..stdlib.requirements import ALoraRequirement, LLMaJRequirement
Expand Down Expand Up @@ -202,65 +203,6 @@ def _cleanup_kv_cache(cache_info: HFAloraCacheInfo) -> None:
torch.cuda.empty_cache()


# modified from VLLM v0.9.2 code base
# https://github.com/vllm-project/vllm/blob/v0.9.2/vllm/model_executor/guided_decoding/guidance_logits_processors.py
class _GuidanceLogitsProcessor:
def __init__(self, grammar: str, ll_tokenizer: llguidance.LLTokenizer) -> None:
self.grammar = grammar
self.vocab_size: int = ll_tokenizer.vocab_size
self.ll_tokenizer: llguidance.LLTokenizer = ll_tokenizer
self.ll_matchers: list[llguidance.LLMatcher] = []
self.bitmasks: list[torch.Tensor] = []
self.new_sampling: bool = False
self.batch_size: int = -1

def __call__(
self, batch_input_ids: torch.Tensor, batch_scores: torch.Tensor
) -> torch.Tensor:
i_batch, _ = batch_input_ids.shape
s_batch, _ = batch_scores.shape
assert i_batch == s_batch

# s_batch, s_vocab = batch_scores.shape
# assert s_vocab == self.vocab_size
#
# NOTE: somehow, this does not hold. s_vocab is not same as either of
# * self._tokenizer._tokenizer.get_vocab_size(with_added_tokens=True) == self.vocab_size == ll_tokenizer.vocab_size
# * self._tokenizer._tokenizer.get_vocab_size(with_added_tokens=False)

if self.batch_size != i_batch:
self.batch_size = i_batch
self.bitmasks = [
llguidance.torch.allocate_token_bitmask(1, self.vocab_size) # type: ignore[attr-defined]
for _ in range(self.batch_size)
]

self.ll_matchers = [
llguidance.LLMatcher(self.ll_tokenizer, self.grammar)
for _ in range(self.batch_size)
]

for input_ids, scores, ll_matcher, bitmask in zip(
batch_input_ids, batch_scores, self.ll_matchers, self.bitmasks
):
if self.new_sampling and len(input_ids) > 0:
ll_matcher.consume_token( # type: ignore[attr-defined]
input_ids.tolist()[-1]
)
err = ll_matcher.get_error() # type: ignore[attr-defined]
if err:
MelleaLogger.get_logger().warning("Error in LLMatcher: %s", err)

llguidance.torch.fill_next_token_bitmask(ll_matcher, bitmask, 0)
llguidance.torch.apply_token_bitmask_inplace(
scores, bitmask.to(scores.device)
) # type: ignore[attr-defined]

self.new_sampling = True

return batch_scores


class LocalHFBackend(FormatterBackend, AdapterMixin):
"""The LocalHFBackend uses Huggingface's transformers library for inference, and uses a Formatter to convert `Component`s into prompts. This backend also supports Activated LoRAs (ALoras)](https://arxiv.org/pdf/2504.12397).

Expand Down Expand Up @@ -362,13 +304,17 @@ def __init__(
case _:
self._tokenizer, self._model, self._device = custom_config

# Preemptively fix vocab size discrepancies between the tokenizer and model if needed.
n_vocab = max(
self._tokenizer.vocab_size, len(self._tokenizer), self._model.vocab_size
)
self._llguidance_tokenizer: llguidance.LLTokenizer = (
llguidance.hf.from_tokenizer(self._tokenizer) # type:ignore
llguidance.hf.from_tokenizer(self._tokenizer, n_vocab=n_vocab) # type:ignore
)
assert (
self._llguidance_tokenizer.vocab_size
== self._tokenizer._tokenizer.get_vocab_size(with_added_tokens=True)
), "vocab size mismatch between llguidance and huggingface tokenizers ... wtf?"
), "vocab size mismatch between llguidance and huggingface tokenizers"

self._use_caches = use_caches
self._cache = (
Expand Down Expand Up @@ -667,7 +613,10 @@ async def _generate_from_intrinsic(

generate_input, other_input = (
granite_formatters.base.util.chat_completion_request_to_transformers_inputs( # type: ignore
rewritten, self._tokenizer, self._model
rewritten,
self._tokenizer,
self._model,
ll_tokenizer=self._llguidance_tokenizer,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The util.py fallback path (line 313–321) has a good comment explaining why n_vocab matters — llguidance defaults to the tokeniser's reported size, which can be smaller than model.vocab_size on models with resized embeddings. But constructing _llguidance_tokenizer here without n_vocab means that guard is bypassed whenever the pre-built instance is passed through. Worth noting the old xgrammar path did compute max(tokenizer.vocab_size, len(tokenizer), model.vocab_size) explicitly, so this is a regression for that case.

Possible suggestion —

n_vocab = max(self._tokenizer.vocab_size, len(self._tokenizer), self._model.vocab_size)
self._llguidance_tokenizer = llguidance.hf.from_tokenizer(self._tokenizer, n_vocab=n_vocab)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the preemptive fix to this section of the code as well.

)

Expand Down
129 changes: 95 additions & 34 deletions mellea/formatters/granite/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
# Third Party
import pydantic

from ....core.utils import MelleaLogger

if TYPE_CHECKING:
import llguidance
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase

# First Party
Comment thread
ajbozarth marked this conversation as resolved.
Expand Down Expand Up @@ -112,11 +116,77 @@ def load_transformers_lora(local_or_remote_path: str) -> tuple:
return model, tokenizer


class _GuidanceLogitsProcessor:
"""A HuggingFace logits processor that enforces an llguidance grammar."""

# Modified from VLLM v0.9.2 code base
# https://github.com/vllm-project/vllm/blob/v0.9.2/vllm/model_executor/guided_decoding/guidance_logits_processors.py

def __init__(self, grammar: str, ll_tokenizer: llguidance.LLTokenizer) -> None:
"""Initialize the processor with a compiled grammar and an llguidance tokenizer."""
with import_optional("llguidance"):
# Callers will have already had to import llguidance. Ensure it here.
import llguidance
import llguidance.torch

self.grammar = grammar
self.vocab_size: int = ll_tokenizer.vocab_size
self.ll_tokenizer: llguidance.LLTokenizer = ll_tokenizer
self.ll_matchers: list[llguidance.LLMatcher] = []
self.bitmasks: list[torch.Tensor] = []
self.new_sampling: bool = False
self.batch_size: int = -1

def __call__(
self, batch_input_ids: torch.Tensor, batch_scores: torch.Tensor
) -> torch.Tensor:
"""Apply the grammar's allowed-token bitmask to ``batch_scores`` in place."""
# Guaranteed to be imported by class __init__.
import llguidance
import llguidance.torch

i_batch, _ = batch_input_ids.shape
s_batch, _ = batch_scores.shape
if i_batch != s_batch:
raise RuntimeError(
f"batch size mismatch: input_ids={i_batch}, scores={s_batch}"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert can be stripped with python -O and gives a bare AssertionError when it fires. A RuntimeError with a message would be more helpful here:

if i_batch != s_batch:
    raise RuntimeError(f"batch size mismatch: input_ids={i_batch}, scores={s_batch}")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed.

if self.batch_size != i_batch:
self.batch_size = i_batch
self.bitmasks = [
llguidance.torch.allocate_token_bitmask(1, self.vocab_size) # type: ignore[attr-defined]
for _ in range(self.batch_size)
]
self.ll_matchers = [
llguidance.LLMatcher(self.ll_tokenizer, self.grammar)
for _ in range(self.batch_size)
]

for input_ids, scores, ll_matcher, bitmask in zip(
batch_input_ids, batch_scores, self.ll_matchers, self.bitmasks
):
if self.new_sampling and len(input_ids) > 0:
ll_matcher.consume_token(input_ids.tolist()[-1]) # type: ignore[attr-defined]
err = ll_matcher.get_error() # type: ignore[attr-defined]
if err:
MelleaLogger.get_logger().warning("Error in LLMatcher: %s", err)

llguidance.torch.fill_next_token_bitmask(ll_matcher, bitmask, 0)
llguidance.torch.apply_token_bitmask_inplace( # type: ignore[attr-defined]
scores, bitmask.to(scores.device)
)

self.new_sampling = True
return batch_scores


def chat_completion_request_to_transformers_inputs(
request: dict,
tokenizer: PreTrainedTokenizerBase | None = None,
model: PreTrainedModel | None = None,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type signature says tokenizer: PreTrainedTokenizerBase and model: PreTrainedModel (required, non-None), but the body still has if tokenizer is None and ll_tokenizer is None: (line ~313) and if tokenizer is None or model is None: (line ~338). Either revert to | None = None to match the runtime contract (where ll_tokenizer can stand in for tokenizer), or drop the None checks. As written, mypy and runtime disagree.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this after my most recent update. I will fix this in the next patch. @frreiss, when you review this PR, can you please comment here? This function originally listed these parameters as optional even though they were required in the implementation. I moved towards forcing them to be not None (and if you agree, will fix the None checks in the function body).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I recall correctly, older versions of this function were able to get by with either a model or a tokenizer, depending on what features the chat completion request enabled. With the current proliferation of corner-case-covering code, I think that making those parameters optional is no longer practical.

constrained_decoding_prefix: str | None = None,
ll_tokenizer: llguidance.LLTokenizer | None = None,
) -> tuple[dict, dict]:
"""Translate an OpenAI-style chat completion request.

Expand All @@ -125,24 +195,25 @@ def chat_completion_request_to_transformers_inputs(

Args:
request: Request as parsed JSON or equivalent dataclass.
tokenizer: HuggingFace tokenizer for the model. Only required if the request
uses constrained decoding.
model: HuggingFace model object. Only required if the request uses constrained
decoding.
tokenizer: HuggingFace tokenizer.
model: HuggingFace model object. Used for `model.device` placement and
when `constrained_decoding_prefix` is set.
constrained_decoding_prefix: Optional generation prefix to append to the prompt.
ll_tokenizer: Pre-built `llguidance.LLTokenizer`. Only used when the request
uses constrained decoding; if not provided, one is constructed from
`tokenizer`. Pass an existing instance to avoid the construction cost.

Returns:
Tuple of `(generate_input, other_input)` where `generate_input` contains
kwargs to pass directly to `generate()` and `other_input` contains
additional parameters for `generate_with_transformers`.

Raises:
ImportError: If `torch`, `transformers`, or `xgrammar` packages
ImportError: If `torch`, `transformers`, or `llguidance` packages
are not installed (the latter only when constrained decoding is used).
TypeError: If `tokenizer.apply_chat_template()` returns an unexpected type.
ValueError: If padding or end-of-sequence token IDs cannot be determined
from the tokenizer, or if a constrained-decoding request is made
without passing a `tokenizer` or `model` argument.
from the tokenizer.
"""
with import_optional("torch"):
# Third Party
Expand Down Expand Up @@ -191,7 +262,7 @@ def chat_completion_request_to_transformers_inputs(

# generate() will fail with many different creative error messages if tokens aren't
# on the right device.
input_tokens = input_tokens.to(model.device) # type: ignore[union-attr]
input_tokens = input_tokens.to(model.device)
generate_input["input_tokens"] = input_tokens

# The generate() method sometimes needs to know what is the integer ID
Expand Down Expand Up @@ -234,33 +305,23 @@ def chat_completion_request_to_transformers_inputs(
):
# Constrained decoding in Hugging Face requires using a third-party library
# to create a callback function to be invoked from inside generate()
with import_optional("xgrammar"):
with import_optional("llguidance"):
# Third Party
import xgrammar as xgr # type: ignore[import-not-found]
if tokenizer is None:
raise ValueError(
"Request specifies constrained decoding, but no "
"tokenizer object was passed to this function."
)
if model is None:
raise ValueError(
"Request specifies constrained decoding, but no "
"tokenizer object was passed to this function."
)

# Different parts of a Hugging Face model will have different opinions about
# the number of tokens in the tokenizer's vocabulary, because of course they do.
# Gather together all the possibilities and pick the biggest one.
vocab_size = max(tokenizer.vocab_size, len(tokenizer), model.vocab_size)

tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size
)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(
import llguidance
import llguidance.hf

if ll_tokenizer is None:
# HF model components disagree on vocab size (resized embeddings, added
# special tokens, etc.). Pass the maximum so the bitmask covers every
# token id the model can emit. llguidance defaults to the tokenizer's
# value when n_vocab is None, which can be smaller than model.vocab_size.
n_vocab = max(tokenizer.vocab_size, len(tokenizer), model.vocab_size) # type: ignore[arg-type]
ll_tokenizer = llguidance.hf.from_tokenizer(tokenizer, n_vocab=n_vocab) # type: ignore[arg-type]

grammar = llguidance.LLMatcher.grammar_from_json_schema(
request["extra_body"]["structured_outputs"]["json"]
)
logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
logits_processor = _GuidanceLogitsProcessor(grammar, ll_tokenizer)

# The "logits_processor" argument to generate() must be a list.
generate_input["logits_processor"] = [logits_processor] # type: ignore[assignment]
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ hf = [
"peft>=0.18.1", # Native aLoRA support added in PEFT 0.18.0
"transformers>=4.53.2,<5",
"trl==0.19.1",
"xgrammar==0.1.33", # Necessary for granite_common intrinsics. Pinned due to Issue 990.
"huggingface-hub>=0.33.4",
]

Expand Down
4 changes: 4 additions & 0 deletions test/backends/test_document_rendering_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ async def test_huggingface_renders_documents_in_prompt():
mock_llg.hf.from_tokenizer.return_value = MagicMock(vocab_size=32000)
mock_tokenizer._tokenizer = MagicMock()
mock_tokenizer._tokenizer.get_vocab_size.return_value = 32000
# Needed for the n_vocab preemptive check in LocalHFBackend.__init__
mock_tokenizer.vocab_size = 32000
mock_tokenizer.__len__ = MagicMock(return_value=32000)
mock_model.vocab_size = 32000

from mellea.backends.huggingface import LocalHFBackend

Expand Down
Loading
Loading