-
Notifications
You must be signed in to change notification settings - Fork 122
feat: consolidate to llguidance from xgrammar #1077
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
372d940
c16cc6b
92269d9
2e19382
d6a5798
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,11 @@ | |
| # Third Party | ||
| import pydantic | ||
|
|
||
| from ....core.utils import MelleaLogger | ||
|
|
||
| if TYPE_CHECKING: | ||
| import llguidance | ||
| import torch | ||
| from transformers import PreTrainedModel, PreTrainedTokenizerBase | ||
|
|
||
| # First Party | ||
|
ajbozarth marked this conversation as resolved.
|
||
|
|
@@ -112,11 +116,77 @@ def load_transformers_lora(local_or_remote_path: str) -> tuple: | |
| return model, tokenizer | ||
|
|
||
|
|
||
| class _GuidanceLogitsProcessor: | ||
| """A HuggingFace logits processor that enforces an llguidance grammar.""" | ||
|
|
||
| # Modified from VLLM v0.9.2 code base | ||
| # https://github.com/vllm-project/vllm/blob/v0.9.2/vllm/model_executor/guided_decoding/guidance_logits_processors.py | ||
|
|
||
| def __init__(self, grammar: str, ll_tokenizer: llguidance.LLTokenizer) -> None: | ||
| """Initialize the processor with a compiled grammar and an llguidance tokenizer.""" | ||
| with import_optional("llguidance"): | ||
| # Callers will have already had to import llguidance. Ensure it here. | ||
| import llguidance | ||
| import llguidance.torch | ||
|
|
||
| self.grammar = grammar | ||
| self.vocab_size: int = ll_tokenizer.vocab_size | ||
| self.ll_tokenizer: llguidance.LLTokenizer = ll_tokenizer | ||
| self.ll_matchers: list[llguidance.LLMatcher] = [] | ||
| self.bitmasks: list[torch.Tensor] = [] | ||
| self.new_sampling: bool = False | ||
| self.batch_size: int = -1 | ||
|
|
||
| def __call__( | ||
| self, batch_input_ids: torch.Tensor, batch_scores: torch.Tensor | ||
| ) -> torch.Tensor: | ||
| """Apply the grammar's allowed-token bitmask to ``batch_scores`` in place.""" | ||
| # Guaranteed to be imported by class __init__. | ||
| import llguidance | ||
| import llguidance.torch | ||
|
|
||
| i_batch, _ = batch_input_ids.shape | ||
| s_batch, _ = batch_scores.shape | ||
| if i_batch != s_batch: | ||
| raise RuntimeError( | ||
| f"batch size mismatch: input_ids={i_batch}, scores={s_batch}" | ||
| ) | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
if i_batch != s_batch:
raise RuntimeError(f"batch size mismatch: input_ids={i_batch}, scores={s_batch}")
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed. |
||
| if self.batch_size != i_batch: | ||
| self.batch_size = i_batch | ||
| self.bitmasks = [ | ||
| llguidance.torch.allocate_token_bitmask(1, self.vocab_size) # type: ignore[attr-defined] | ||
| for _ in range(self.batch_size) | ||
| ] | ||
| self.ll_matchers = [ | ||
| llguidance.LLMatcher(self.ll_tokenizer, self.grammar) | ||
| for _ in range(self.batch_size) | ||
| ] | ||
|
|
||
| for input_ids, scores, ll_matcher, bitmask in zip( | ||
| batch_input_ids, batch_scores, self.ll_matchers, self.bitmasks | ||
| ): | ||
| if self.new_sampling and len(input_ids) > 0: | ||
| ll_matcher.consume_token(input_ids.tolist()[-1]) # type: ignore[attr-defined] | ||
| err = ll_matcher.get_error() # type: ignore[attr-defined] | ||
| if err: | ||
| MelleaLogger.get_logger().warning("Error in LLMatcher: %s", err) | ||
|
|
||
| llguidance.torch.fill_next_token_bitmask(ll_matcher, bitmask, 0) | ||
| llguidance.torch.apply_token_bitmask_inplace( # type: ignore[attr-defined] | ||
| scores, bitmask.to(scores.device) | ||
| ) | ||
|
|
||
| self.new_sampling = True | ||
| return batch_scores | ||
|
|
||
|
|
||
| def chat_completion_request_to_transformers_inputs( | ||
| request: dict, | ||
| tokenizer: PreTrainedTokenizerBase | None = None, | ||
| model: PreTrainedModel | None = None, | ||
| tokenizer: PreTrainedTokenizerBase, | ||
| model: PreTrainedModel, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Type signature says
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for catching this after my most recent update. I will fix this in the next patch. @frreiss, when you review this PR, can you please comment here? This function originally listed these parameters as optional even though they were required in the implementation. I moved towards forcing them to be not None (and if you agree, will fix the None checks in the function body).
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I recall correctly, older versions of this function were able to get by with either a |
||
| constrained_decoding_prefix: str | None = None, | ||
| ll_tokenizer: llguidance.LLTokenizer | None = None, | ||
| ) -> tuple[dict, dict]: | ||
| """Translate an OpenAI-style chat completion request. | ||
|
|
||
|
|
@@ -125,24 +195,25 @@ def chat_completion_request_to_transformers_inputs( | |
|
|
||
| Args: | ||
| request: Request as parsed JSON or equivalent dataclass. | ||
| tokenizer: HuggingFace tokenizer for the model. Only required if the request | ||
| uses constrained decoding. | ||
| model: HuggingFace model object. Only required if the request uses constrained | ||
| decoding. | ||
| tokenizer: HuggingFace tokenizer. | ||
| model: HuggingFace model object. Used for `model.device` placement and | ||
| when `constrained_decoding_prefix` is set. | ||
| constrained_decoding_prefix: Optional generation prefix to append to the prompt. | ||
| ll_tokenizer: Pre-built `llguidance.LLTokenizer`. Only used when the request | ||
| uses constrained decoding; if not provided, one is constructed from | ||
| `tokenizer`. Pass an existing instance to avoid the construction cost. | ||
|
|
||
| Returns: | ||
| Tuple of `(generate_input, other_input)` where `generate_input` contains | ||
| kwargs to pass directly to `generate()` and `other_input` contains | ||
| additional parameters for `generate_with_transformers`. | ||
|
|
||
| Raises: | ||
| ImportError: If `torch`, `transformers`, or `xgrammar` packages | ||
| ImportError: If `torch`, `transformers`, or `llguidance` packages | ||
| are not installed (the latter only when constrained decoding is used). | ||
| TypeError: If `tokenizer.apply_chat_template()` returns an unexpected type. | ||
| ValueError: If padding or end-of-sequence token IDs cannot be determined | ||
| from the tokenizer, or if a constrained-decoding request is made | ||
| without passing a `tokenizer` or `model` argument. | ||
| from the tokenizer. | ||
| """ | ||
| with import_optional("torch"): | ||
| # Third Party | ||
|
|
@@ -191,7 +262,7 @@ def chat_completion_request_to_transformers_inputs( | |
|
|
||
| # generate() will fail with many different creative error messages if tokens aren't | ||
| # on the right device. | ||
| input_tokens = input_tokens.to(model.device) # type: ignore[union-attr] | ||
| input_tokens = input_tokens.to(model.device) | ||
| generate_input["input_tokens"] = input_tokens | ||
|
|
||
| # The generate() method sometimes needs to know what is the integer ID | ||
|
|
@@ -234,33 +305,23 @@ def chat_completion_request_to_transformers_inputs( | |
| ): | ||
| # Constrained decoding in Hugging Face requires using a third-party library | ||
| # to create a callback function to be invoked from inside generate() | ||
| with import_optional("xgrammar"): | ||
| with import_optional("llguidance"): | ||
| # Third Party | ||
| import xgrammar as xgr # type: ignore[import-not-found] | ||
| if tokenizer is None: | ||
| raise ValueError( | ||
| "Request specifies constrained decoding, but no " | ||
| "tokenizer object was passed to this function." | ||
| ) | ||
| if model is None: | ||
| raise ValueError( | ||
| "Request specifies constrained decoding, but no " | ||
| "tokenizer object was passed to this function." | ||
| ) | ||
|
|
||
| # Different parts of a Hugging Face model will have different opinions about | ||
| # the number of tokens in the tokenizer's vocabulary, because of course they do. | ||
| # Gather together all the possibilities and pick the biggest one. | ||
| vocab_size = max(tokenizer.vocab_size, len(tokenizer), model.vocab_size) | ||
|
|
||
| tokenizer_info = xgr.TokenizerInfo.from_huggingface( | ||
| tokenizer, vocab_size=vocab_size | ||
| ) | ||
| grammar_compiler = xgr.GrammarCompiler(tokenizer_info) | ||
| compiled_grammar = grammar_compiler.compile_json_schema( | ||
| import llguidance | ||
| import llguidance.hf | ||
|
|
||
| if ll_tokenizer is None: | ||
| # HF model components disagree on vocab size (resized embeddings, added | ||
| # special tokens, etc.). Pass the maximum so the bitmask covers every | ||
| # token id the model can emit. llguidance defaults to the tokenizer's | ||
| # value when n_vocab is None, which can be smaller than model.vocab_size. | ||
| n_vocab = max(tokenizer.vocab_size, len(tokenizer), model.vocab_size) # type: ignore[arg-type] | ||
| ll_tokenizer = llguidance.hf.from_tokenizer(tokenizer, n_vocab=n_vocab) # type: ignore[arg-type] | ||
|
|
||
| grammar = llguidance.LLMatcher.grammar_from_json_schema( | ||
| request["extra_body"]["structured_outputs"]["json"] | ||
| ) | ||
| logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar) | ||
| logits_processor = _GuidanceLogitsProcessor(grammar, ll_tokenizer) | ||
|
|
||
| # The "logits_processor" argument to generate() must be a list. | ||
| generate_input["logits_processor"] = [logits_processor] # type: ignore[assignment] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
util.pyfallback path (line 313–321) has a good comment explaining whyn_vocabmatters — llguidance defaults to the tokeniser's reported size, which can be smaller thanmodel.vocab_sizeon models with resized embeddings. But constructing_llguidance_tokenizerhere withoutn_vocabmeans that guard is bypassed whenever the pre-built instance is passed through. Worth noting the old xgrammar path did computemax(tokenizer.vocab_size, len(tokenizer), model.vocab_size)explicitly, so this is a regression for that case.Possible suggestion —
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added the preemptive fix to this section of the code as well.