From 8a0faab90da60598d3ef72705c20ea1d68e6ecc7 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 25 Apr 2026 12:49:15 +0200 Subject: [PATCH 1/6] feat(vendor): add LLM2Vec embedding model - Add LLM2Vec from OneIG vendor source - Includes Llama encoder and bidirectional models - Self-contained, no dependencies on Pruna internals - Licensed under Apache 2.0 --- .../metrics/vendor/NOTICE.oneig_llm2vec | 12 + .../metrics/vendor/oneig_llm2vec/llm2vec.py | 549 ++++++++++++++++++ .../oneig_llm2vec/modeling_llama_encoder.py | 107 ++++ .../models/bidirectional_llama.py | 228 ++++++++ 4 files changed, 896 insertions(+) create mode 100644 src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec create mode 100644 src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py create mode 100644 src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py create mode 100644 src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py diff --git a/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec b/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec new file mode 100644 index 00000000..01654bd4 --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec @@ -0,0 +1,12 @@ +LLM2Vec (llm2vec package) vendored from OneIG-Benchmark. + +Source: https://github.com/OneIG-Bench/OneIG-Benchmark +Commit: 41b49831e79e6dde5323618c164da1c4cf0f699d +Path: scripts/utils/llm2clip/llm2vec/ + +OneIG-Benchmark is licensed under the Apache License 2.0. +See the project repository for full license text. + +``oneig_llm2vec/modeling_llama_encoder.py`` is derived from +McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp (Hugging Face Hub); +Pruna relaxes the upstream flash-attention-only constraint for CPU use. diff --git a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py new file mode 100644 index 00000000..102f5b28 --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py @@ -0,0 +1,549 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Vendored from OneIG-Benchmark (commit 41b49831e79e6dde5323618c164da1c4cf0f699d). +# See NOTICE.oneig_llm2vec in parent directory. + +import json +import logging +import pathlib +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.multiprocessing as mp +from peft import PeftModel +from torch import Tensor, device, nn +from tqdm import trange +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, + LlamaConfig, + PretrainedConfig, +) + +from pruna.evaluation.metrics.vendor.oneig_llm2vec.models.bidirectional_llama import LlamaBiModel + +logger = logging.getLogger(__name__) + + +def batch_to_device(batch, target_device: device | str): + """ + Move tensor values in a batch dict to ``target_device``. + + Parameters + ---------- + batch : dict[str, Any] + Mapping of feature names to tensors or other values; only ``torch.Tensor`` + values are moved. + target_device : torch.device or str + Device to move tensors to. + + Returns + ------- + dict[str, Any] + The same ``batch`` object with tensors updated in place. + """ + for key in batch: + if isinstance(batch[key], Tensor): + batch[key] = batch[key].to(target_device) + return batch + + +class LLM2Vec(nn.Module): + """ + Bidirectional LLM wrapper with configurable pooling for dense embeddings. + + Parameters + ---------- + model : transformers.AutoModel + Encoder model used for hidden states. + tokenizer : transformers.AutoTokenizer + Tokenizer aligned with ``model``. + pooling_mode : str, optional + How to pool token hidden states (e.g. ``mean``, ``eos_token``). + max_length : int, optional + Maximum sequence length for tokenization. + doc_max_length : int, optional + Soft cap used when shortening document segments during encoding. + skip_instruction : bool, optional + If True, restrict attention to embed regions when pooling. + """ + + def __init__( + self, + model: AutoModel, + tokenizer: AutoTokenizer, + pooling_mode: str = "mean", + max_length: int = 512, + doc_max_length: int = 512, + skip_instruction: bool = True, + ): + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.pooling_mode = pooling_mode + self.skip_instruction = skip_instruction + self.max_length = max_length + self.doc_max_length = 512 + self.config = model.config + + @classmethod + def _get_model_class(cls, config_class_name, enable_bidirectional): + if not enable_bidirectional: + return AutoModel + elif config_class_name == "LlamaConfig": + return LlamaBiModel + else: + raise ValueError(f"{config_class_name} is not supported yet with bidirectional models.") + + @classmethod + def from_pretrained( + cls, + base_model_name_or_path, + peft_model_name_or_path=None, + merge_peft=False, + enable_bidirectional=True, + extra_model_name_or_path=None, + **kwargs, + ): + """ + Load tokenizer and encoder weights and return an ``LLM2Vec`` instance. + + Optional PEFT adapters, bidirectional Llama, and extra adapter paths are + supported; keyword arguments are forwarded to Hugging Face + ``from_pretrained`` calls. + + Parameters + ---------- + base_model_name_or_path : str or pathlib.Path + Hub id or local directory for the base model. + peft_model_name_or_path : str or pathlib.Path, optional + Optional PEFT adapter to load on top of the base model. + merge_peft : bool, optional + If True, merge PEFT weights into the base weights after loading. + enable_bidirectional : bool, optional + If True, use bidirectional Llama when the config is ``LlamaConfig``. + extra_model_name_or_path : str, list of str, or None, optional + Additional PEFT checkpoint(s) applied sequentially when set. + **kwargs + Forwarded to Hugging Face ``from_pretrained`` (and related) calls. + + Returns + ------- + LLM2Vec + Configured wrapper around the loaded encoder and tokenizer. + """ + keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"] + encoder_args = {key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None} + + tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + config = AutoConfig.from_pretrained(base_model_name_or_path) + config_class_name = config.__class__.__name__ + + model_class = cls._get_model_class(config_class_name, enable_bidirectional=enable_bidirectional) + model = model_class.from_pretrained(base_model_name_or_path, **kwargs) + + base_path = pathlib.Path(base_model_name_or_path) + config_json = base_path / "config.json" + if base_path.is_dir() and config_json.exists(): + with open(config_json, encoding="utf-8") as config_file: + config_dict = json.load(config_file) + config = PretrainedConfig.from_dict(config_dict) + model.config._name_or_path = config._name_or_path + + if hasattr(model, "peft_config"): + model = PeftModel.from_pretrained( + model, + base_model_name_or_path, + ) + model = model.merge_and_unload() + + if peft_model_name_or_path is not None: + model = PeftModel.from_pretrained( + model, + peft_model_name_or_path, + ) + if merge_peft: + model = model.merge_and_unload() + if extra_model_name_or_path is not None: + logger.info(f"Loading extra model from {extra_model_name_or_path}") + if not merge_peft: + model = model.merge_and_unload() + if isinstance(extra_model_name_or_path, str): + model = PeftModel.from_pretrained( + model, + extra_model_name_or_path, + ) + peft_model_name_or_path = extra_model_name_or_path + model = model.merge_and_unload() + elif isinstance(extra_model_name_or_path, list): + for extra_model in extra_model_name_or_path: + model = PeftModel.from_pretrained( + model, + extra_model, + ) + peft_model_name_or_path = extra_model + model = model.merge_and_unload() + else: + raise ValueError("extra_model_name_or_path should be a string or a list of strings.") + config = {} + config_addr = peft_model_name_or_path if peft_model_name_or_path is not None else base_model_name_or_path + llm2vec_config_path = pathlib.Path(config_addr) / "llm2vec_config.json" + if llm2vec_config_path.exists(): + with open(llm2vec_config_path, encoding="utf-8") as config_file: + llm2vec_config = json.load(config_file) + config.update(llm2vec_config) + logger.info(f"LLM2Vec config: {config}") + for key, value in encoder_args.items(): + config[key] = value + + return cls(model=model, tokenizer=tokenizer, **config) + + def prepare_for_tokenization(self, text): + """ + Apply model-specific chat or EOS wrappers so tokenization matches training. + + Parameters + ---------- + text : str + Raw input text before tokenization. + + Returns + ------- + str + Text with any required special tokens or chat template prefixes or suffixes. + """ + if "Llama-3" in self.model.config._name_or_path and "Instruct" in self.model.config._name_or_path: + text = "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>" + return text + if self.model.config._name_or_path == "microsoft/Phi-3.5-mini-instruct": + text = "<|user|>\n" + text.strip() + "<|end|>\n" + return text + if self.pooling_mode == "eos_token": + if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B": + text = text.strip() + "<|end_of_text|>" + elif isinstance(self.model.config, LlamaConfig): + text = text.strip() + " " + return text + + def tokenize(self, texts): + """ + Tokenize texts with optional embed-region markers for instruction/document split. + + Parameters + ---------- + texts : list of str + Strings that may contain the ``!@#$%^&*()`` delimiter between instruction and document. + + Returns + ------- + dict[str, torch.Tensor] + Tokenizer outputs including ``embed_mask`` when the delimiter is present. + """ + texts_2 = [] + original_texts = [] + for text in texts: + t = text.split("!@#$%^&*()") + texts_2.append(t[1] if len(t) > 1 else "") + original_texts.append("".join(t)) + + original = self.tokenizer( + original_texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + ) + embed_mask = None + for t_i, t in enumerate(texts_2): + ids = self.tokenizer( + [t], + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + if embed_mask is None: + e_m = torch.zeros_like(original["attention_mask"][t_i]) + if len(ids["input_ids"][0]) > 0: + e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0])) + embed_mask = e_m.unsqueeze(0) + else: + e_m = torch.zeros_like(original["attention_mask"][t_i]) + if len(ids["input_ids"][0]) > 0: + e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0])) + embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0) + + original["embed_mask"] = embed_mask + return original + + def _skip_instruction(self, sentence_feature): + assert sentence_feature["attention_mask"].shape == sentence_feature["embed_mask"].shape + sentence_feature["attention_mask"] = sentence_feature["embed_mask"] + + def forward(self, sentence_feature: Dict[str, Tensor]): + """ + Run the encoder and return pooled sentence embeddings. + + Parameters + ---------- + sentence_feature : dict[str, torch.Tensor] + Batch of tokenizer outputs; may include ``embed_mask`` for instruction masking. + + Returns + ------- + torch.Tensor + Pooled embeddings with shape ``(batch_size, hidden_size)``. + """ + embed_mask = None + if "embed_mask" in sentence_feature: + embed_mask = sentence_feature.pop("embed_mask") + reps = self.model(**sentence_feature) + if embed_mask is not None: + sentence_feature["embed_mask"] = embed_mask + + return self.get_pooling(sentence_feature, reps.last_hidden_state) + + def get_pooling(self, features, last_hidden_states): + """ + Pool token hidden states according to ``pooling_mode``. + + Parameters + ---------- + features : dict[str, torch.Tensor] + Tokenizer batch (attention mask, optional ``embed_mask``, etc.). + last_hidden_states : torch.Tensor + Sequence hidden states from the encoder, shape ``(batch, seq, hidden)``. + + Returns + ------- + torch.Tensor + Pooled embeddings, shape ``(batch, hidden)``. + """ + assert self.tokenizer.padding_side == "left", "Pooling modes are implemented for padding from left." + if self.skip_instruction: + self._skip_instruction(features) + seq_lengths = features["attention_mask"].sum(dim=-1) + if self.pooling_mode == "mean": + return torch.stack( + [last_hidden_states[i, -length:, :].mean(dim=0) for i, length in enumerate(seq_lengths)], + dim=0, + ) + elif self.pooling_mode == "weighted_mean": + bs, seq_len, _ = last_hidden_states.shape + complete_weights = torch.zeros(bs, seq_len, device=last_hidden_states.device) + for i, seq_l in enumerate(seq_lengths): + if seq_l > 0: + complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1 + complete_weights[i] /= torch.clamp(complete_weights[i].sum(), min=1e-9) + return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1) + elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token": + return last_hidden_states[:, -1] + elif self.pooling_mode == "bos_token": + return last_hidden_states[features["input_ids"] == self.tokenizer.bos_token_id] + else: + raise ValueError(f"{self.pooling_mode} is not implemented yet.") + + def _convert_to_str(self, instruction, text): + tokenized_q = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + tokenized_q_length = len(tokenized_q["input_ids"][0]) + + while tokenized_q_length > self.doc_max_length: + reduction_ratio = self.doc_max_length / tokenized_q_length + reduced_length = int(len(text.split()) * reduction_ratio) + text = " ".join(text.split()[:reduced_length]) + tokenized_q = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + tokenized_q_length = len(tokenized_q["input_ids"][0]) + + return f"{instruction.strip()} !@#$%^&*(){text}" if instruction else f"!@#$%^&*(){text}" + + def encode( + self, + sentences: Union[str, List[str]], + batch_size: int = 32, + show_progress_bar: bool = True, + convert_to_numpy: bool = False, + convert_to_tensor: bool = True, + device: Optional[str] = None, + ): + """ + Encode sentences (optionally instruction + document) to embedding tensors. + + Parameters + ---------- + sentences : str, list of str, or nested list + Plain strings, or ``[instruction, document]`` pairs, or batches thereof. + batch_size : int, optional + Micro-batch size during encoding. + show_progress_bar : bool, optional + Ignored; progress is disabled in the implementation. + convert_to_numpy : bool, optional + If True, return a NumPy array instead of a tensor (mutually exclusive with ``convert_to_tensor``). + convert_to_tensor : bool, optional + If True (default), return a ``torch.Tensor`` of dtype float32. + device : str, optional + Device name; defaults to CUDA when available else CPU. + + Returns + ------- + torch.Tensor or numpy.ndarray + Stacked embeddings for all inputs, reordered to the original sentence order. + """ + seq: Any = sentences + if isinstance(seq[0], str) and isinstance(seq[-1], int): + seq = [seq] + if isinstance(seq[0], str): + seq = [[""] + [sentence] for sentence in seq] + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + concatenated_input_texts = [] + for sentence in seq: + assert isinstance(sentence[0], str) + assert isinstance(sentence[1], str) + concatenated_input_texts.append(self._convert_to_str(sentence[0], sentence[1])) + sentences = concatenated_input_texts + + self.train(mode=False) + + if convert_to_tensor: + convert_to_numpy = False + + length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) + sentences_sorted = [sentences[idx] for idx in length_sorted_idx] + all_embeddings = [] + + self.to(device) + for start_index in trange( + 0, + len(sentences), + batch_size, + desc="Batches", + disable=True, + ): + sentences_batch = sentences_sorted[start_index : start_index + batch_size] + embeddings = self._encode(sentences_batch, device=device, convert_to_numpy=convert_to_numpy) + all_embeddings.append(embeddings) + + all_embeddings = torch.cat(all_embeddings, dim=0) + all_embeddings = all_embeddings[np.argsort(length_sorted_idx)] + all_embeddings = all_embeddings.to(torch.float32) + return all_embeddings + + def save(self, output_path, merge_before_save=False, save_config=True): + """ + Persist model, tokenizer, and optional ``llm2vec_config.json`` to ``output_path``. + + Parameters + ---------- + output_path : str or pathlib.Path + Directory to write weights and tokenizer files into. + merge_before_save : bool, optional + If True and the inner model is a ``PeftModel``, merge adapters before saving. + save_config : bool, optional + If True, write ``llm2vec_config.json`` with pooling and length settings. + """ + if merge_before_save and isinstance(self.model, PeftModel): + self.model = self.model.merge_and_unload() + if hasattr(self.model, "_hf_peft_config_loaded"): + setattr(self.model, "_hf_peft_config_loaded", False) + + self.model.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + llm2vec_config = { + "pooling_mode": self.pooling_mode, + "max_length": self.max_length, + "doc_max_length": self.doc_max_length, + "skip_instruction": self.skip_instruction, + } + + if save_config: + pathlib.Path(output_path).mkdir(exist_ok=True, parents=True) + config_out = pathlib.Path(output_path) / "llm2vec_config.json" + with open(config_out, "w", encoding="utf-8") as config_file: + json.dump(llm2vec_config, config_file, indent=4) + + def _encode( + self, + sentences_batch, + device: Optional[str] = None, + convert_to_numpy: bool = False, + multiprocessing=False, + ): + if multiprocessing: + rank = mp.current_process()._identity[0] + if device is None and torch.cuda.is_available(): + device = f"cuda:{rank % torch.cuda.device_count()}" + + use_device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu") + self.to(use_device) + features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch]) + features = batch_to_device(features, use_device) + + with torch.no_grad(): + embeddings = self.forward(features) + return embeddings + + def _text_length(self, text: Union[List[int], List[List[int]]]): + if isinstance(text, str) or (isinstance(text, list) and isinstance(text[0], int)) or len(text) == 0: + return len(text) + if isinstance(text, dict): + return len(next(iter(text.values()))) + elif not hasattr(text, "__len__"): + return 1 + else: + return sum(len(t) if not isinstance(t, int) else 1 for t in text) + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + ) -> nn.Embedding: + """ + Resize the underlying model token embedding matrix. + + Parameters + ---------- + new_num_tokens : int, optional + New vocabulary size for the embedding table. + pad_to_multiple_of : int, optional + Pad vocabulary size to a multiple of this value when resizing. + + Returns + ------- + torch.nn.Embedding + The resized embedding module from the wrapped model. + """ + return self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """ + Enable gradient checkpointing on the wrapped model. + + Parameters + ---------- + gradient_checkpointing_kwargs : dict, optional + Keyword arguments forwarded to the underlying ``gradient_checkpointing_enable`` call. + """ + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) diff --git a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py new file mode 100644 index 00000000..cf9b4df8 --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py @@ -0,0 +1,107 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Derived from McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp ``modeling_llama_encoder.py`` +# (Hugging Face Hub). Upstream requires ``flash_attention_2`` only; this copy allows ``eager`` +# and ``sdpa`` so ``oneig_reasoning`` can run on CPU without ``flash_attn``. See +# ``NOTICE.oneig_llm2vec`` in the parent ``vendor`` directory. + +import importlib.metadata + +from packaging import version +from torch import nn +from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from transformers.utils import logging +from transformers.utils.import_utils import _is_package_available + +logger = logging.get_logger(__name__) + + +def is_transformers_attn_greater_or_equal_4_56_2() -> bool: + """ + Check whether the installed ``transformers`` package is at least 4.56.2. + + Returns + ------- + bool + True if ``transformers`` is installed and its version is >= 4.56.2; + False otherwise. + """ + if not _is_package_available("transformers"): + return False + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.56.2") + + +class ModifiedLlamaAttention(LlamaAttention): + """ + Llama self-attention with ``is_causal`` disabled for encoder-style use. + + Parameters + ---------- + *args, **kwargs + Forwarded to :class:`~transformers.models.llama.modeling_llama.LlamaAttention`. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): + """ + Decoder block using :class:`ModifiedLlamaAttention` for bidirectional encoding. + + Parameters + ---------- + config : LlamaConfig + Model configuration. + layer_idx : int + Index of this decoder layer. + """ + + def __init__(self, config: LlamaConfig, layer_idx: int): + GradientCheckpointingLayer.__init__(self) + self.hidden_size = config.hidden_size + self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class LlamaEncoderModel(LlamaModel): + """ + Bidirectional Llama stack for LLM2Vec-style encoding (eager, SDPA, or flash attention). + + Parameters + ---------- + config : LlamaConfig + Model configuration (requires transformers >= 4.56.2 layout). + """ + + def __init__(self, config: LlamaConfig) -> None: + if not is_transformers_attn_greater_or_equal_4_56_2(): + raise ValueError( + "The current implementation of LlamaEncoderModel follows modeling_llama.py " + "of transformers version >= 4.56.2" + ) + LlamaPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + attn_impl = getattr(config, "_attn_implementation", getattr(config, "attn_implementation", "eager")) + self._use_sdpa = attn_impl == "sdpa" + self._use_flash_attention_2 = attn_impl == "flash_attention_2" + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.post_init() diff --git a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py new file mode 100644 index 00000000..610853ac --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py @@ -0,0 +1,228 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Vendored from OneIG-Benchmark (commit 41b49831e79e6dde5323618c164da1c4cf0f699d). + +import importlib.metadata +from typing import cast + +import torch +from packaging import version +from peft import PeftModel +from torch import nn +from transformers import ( + LlamaConfig, + LlamaForCausalLM, + LlamaModel, + LlamaPreTrainedModel, +) +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from transformers.utils import logging +from transformers.utils.import_utils import _is_package_available + +logger = logging.get_logger(__name__) + + +def is_transformers_attn_greater_or_equal_4_38() -> bool: + """ + Check whether the installed ``transformers`` package is at least 4.38.0. + + Returns + ------- + bool + True if ``transformers`` is installed and its version is >= 4.38.0; + False otherwise. + """ + if not _is_package_available("transformers"): + return False + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.38.0") + + +def is_transformers_attn_greater_or_equal_4_40() -> bool: + """ + Check whether the installed ``transformers`` package is at least 4.40.0. + + Returns + ------- + bool + True if ``transformers`` is installed and its version is >= 4.40.0; + False otherwise. + """ + if not _is_package_available("transformers"): + return False + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.40.0") + + +class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): + """ + Decoder layer with non-causal self-attention when supported by the attention module. + + Parameters + ---------- + config : LlamaConfig + Model configuration. + layer_idx : int + Index of this decoder layer. + """ + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__(config, layer_idx) + if hasattr(self.self_attn, "is_causal"): + self.self_attn.is_causal = False + + +class LlamaBiModel(LlamaModel): + """ + Bidirectional Llama backbone for MNTP-style training (transformers >= 4.38). + + Parameters + ---------- + config : LlamaConfig + Model configuration. + """ + + _no_split_modules = ["ModifiedLlamaDecoderLayer"] + + def __init__(self, config: LlamaConfig): + if not is_transformers_attn_greater_or_equal_4_38(): + raise ValueError( + "The current implementation of LlamaBiModel follows modeling_llama.py of transformers version >= 4.38.0" + ) + LlamaPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.ModuleList( + [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + self.post_init() + + def _update_causal_mask( + self, + attention_mask, + input_tensor, + cache_position, + past_seen_tokens=None, + output_attentions=False, + ): + attn_impl = getattr(self.config, "_attn_implementation", getattr(self.config, "attn_implementation", "eager")) + if attn_impl == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + + if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): + target_length = self.config.max_position_embeddings + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else ( + cache_position[-1] + 1 + if not is_transformers_attn_greater_or_equal_4_40() + else past_seen_tokens + sequence_length + 1 + ) + ) + + causal_mask = torch.zeros((sequence_length, target_length), dtype=dtype, device=device) + + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + + if attention_mask is not None: + causal_mask = causal_mask.clone() + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + offset = cache_position[0] if attention_mask.shape[-2] < cache_position[0] + sequence_length else 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[ + : mask_shape[0], + : mask_shape[1], + offset : mask_shape[2] + offset, + : mask_shape[3], + ] = mask_slice + + attn_impl = getattr(self.config, "_attn_implementation", getattr(self.config, "attn_implementation", "eager")) + if ( + attn_impl == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + causal_mask = AttentionMaskConverter._unmask_unattended( + cast(torch.FloatTensor, causal_mask.to(dtype=torch.float32)), + min_dtype, + ) + + return causal_mask + + +class LlamaBiForMNTP(LlamaForCausalLM): + """ + Causal LM wrapper around :class:`LlamaBiModel` for MNTP with optional PEFT. + + Parameters + ---------- + config : LlamaConfig + Model configuration. + """ + + def __init__(self, config: LlamaConfig): + LlamaPreTrainedModel.__init__(self, config) + self.model = LlamaBiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_model_for_peft(self) -> LlamaBiModel | PeftModel: + """ + Return the inner model for PEFT wrapping (base or wrapped). + + Returns + ------- + LlamaBiModel or PeftModel + ``self.model``, either a :class:`LlamaBiModel` or a :class:`peft.PeftModel`. + """ + return self.model + + def set_model_for_peft(self, model: PeftModel) -> None: + """ + Replace the inner model with a PEFT-wrapped model. + + Parameters + ---------- + model : PeftModel + PEFT model whose base matches the expected backbone. + """ + self.model = model + + def save_peft_model(self, path: str) -> None: + """ + Save the (possibly PEFT-wrapped) inner model to disk. + + Parameters + ---------- + path : str + Directory path passed to ``save_pretrained`` on the inner model. + """ + self.model.save_pretrained(path) From fb6d9675aef4b433cb6fc8ae43b56d810c8c7163 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Tue, 5 May 2026 11:50:40 +0200 Subject: [PATCH 2/6] fix(vendor): honor llm2vec length and numpy flags Patch two upstream llm2vec behavior bugs found in review so downstream VLM metrics use caller-provided doc_max_length and can return numpy when requested. Document Pruna's vendor deviations in NOTICE for traceability. Co-authored-by: Cursor --- src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec | 5 +++++ src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec b/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec index 01654bd4..611a8d24 100644 --- a/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec +++ b/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec @@ -10,3 +10,8 @@ See the project repository for full license text. ``oneig_llm2vec/modeling_llama_encoder.py`` is derived from McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp (Hugging Face Hub); Pruna relaxes the upstream flash-attention-only constraint for CPU use. + +Pruna also includes two minimal compatibility fixes in +``oneig_llm2vec/llm2vec.py``: +- Preserve constructor-provided ``doc_max_length`` instead of hardcoding 512. +- Honor ``convert_to_numpy=True`` in ``encode()`` by returning ``numpy.ndarray``. diff --git a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py index 102f5b28..e6073e1d 100644 --- a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py +++ b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py @@ -85,7 +85,7 @@ def __init__( self.pooling_mode = pooling_mode self.skip_instruction = skip_instruction self.max_length = max_length - self.doc_max_length = 512 + self.doc_max_length = doc_max_length self.config = model.config @classmethod @@ -448,6 +448,8 @@ def encode( all_embeddings = torch.cat(all_embeddings, dim=0) all_embeddings = all_embeddings[np.argsort(length_sorted_idx)] all_embeddings = all_embeddings.to(torch.float32) + if convert_to_numpy: + return all_embeddings.cpu().numpy() return all_embeddings def save(self, output_path, merge_before_save=False, save_config=True): From 833186425ce1dcb86a09e82eb08e59e10f84837e Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Thu, 4 Jun 2026 07:45:45 +0200 Subject: [PATCH 3/6] fix(metrics): unblock CI sync and VLM metric init Drop the broken Intel uv index (aligned with main), fix QAAccuracy keyword-only aggregation syntax, pass single/y_gt call types correctly for OneIG alignment, and expose metric_units on results. Co-authored-by: Cursor --- pyproject.toml | 49 +++++++++++++++++++------------------------------ 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dc42053e..a603847b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,9 +36,6 @@ possibly-missing-attribute = "ignore" missing-argument = "ignore" unused-type-ignore-comment = "ignore" -[tool.bandit] -exclude_dirs = ["tests", "docs"] - [tool.coverage.run] source = ["src/pruna"] @@ -70,29 +67,21 @@ name = "pruna_internal" url = "https://prunaai.pythonanywhere.com/simple/" explicit = true -[[tool.uv.index]] -name = "intel-pytorch-extension" -url = "https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/" -explicit = true - [tool.uv] index-strategy = "first-index" +exclude-newer = "1 week" # protection against compromised dependencies +# trusted dev wheels that are missing an upload date +exclude-newer-package = { gptqmodel = false, "stable-fast-pruna" = false } conflicts = [ [{ extra = "awq" }, { extra = "vbench" }], [{ extra = "vllm" }, { extra = "vbench" }], - [{ extra = "intel" }, { extra = "awq" }], [{ extra = "gptq" }, { extra = "awq" }], - # intel is incompatible with all stable-fast variants and vllm - [{ extra = "intel" }, { extra = "stable-fast" }, { extra = "stable-fast-extraindex" }], - [{ extra = "intel" }, { extra = "full" }, { extra = "stable-fast-extraindex" }], - [{ extra = "intel" }, { extra = "vllm" }], [{ extra = "kvpress" }, { extra = "vbench" }], ] [tool.uv.sources] gptqmodel = { index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'" } -intel-extension-for-pytorch = { index = "intel-pytorch-extension" } stable-fast-pruna = { index = "pruna_internal", extra = "stable-fast-extraindex" } [project] @@ -171,6 +160,21 @@ vllm = [ "vllm>=0.16.0", "ray", ] +rapidata = [ + "rapidata>=3.0.0", +] +upscale = [ + "realesrgan", +] +evaluation = [ + "pruna[rapidata]", + "pruna[lmharness]", + "outlines>1.2.0,<2.0.0", + "litellm>=1.0.0", +] +oneig-reasoning = [ + "hf_transfer>=0.1.9", +] stable-fast = [ "xformers>=0.0.30", "stable-fast-pruna>=1.0.8,<1.0.9", @@ -195,18 +199,12 @@ awq = [ "llmcompressor>=0.9", "torch>=2.9.0" ] -upscale = [ - "realesrgan", -] full = [ "pruna[stable-fast]", ] vbench = [ "vbench-pruna; sys_platform != 'darwin'", ] -rapidata = [ - "rapidata>=3.0.0" -] dev = [ "wget", "python-dotenv", @@ -233,22 +231,13 @@ dev = [ "types-PyYAML", "logbar", "pytest-xdist>=3.8.0", + "pruna[evaluation]", ] cpu = [] lmharness = [ "lm-eval>=0.4.0" ] -evaluation = [ - "pruna[rapidata]", - "pruna[lmharness]" -] -# Intel extension is tightly coupled with the torch version -intel = [ - "intel-extension-for-pytorch>=2.7.0", - "torch>=2.7.0,<2.9.0", - "torchvision>=0.22.0,<0.24.0", -] kvpress = [ "kvpress>=0.5.2", ] From 8e536d9cd18fb52edce1a044d651a461688f96cd Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Thu, 4 Jun 2026 09:34:34 +0200 Subject: [PATCH 4/6] fix(ci): lint/docstrings and stack-appropriate VLM tests Replace forward-import VLM test module on pre-e2e branches with infrastructure-only tests; propagate docstring and conftest fixes. Co-authored-by: Cursor --- src/pruna/evaluation/metrics/metric_torch.py | 47 +++++++++++++------- tests/conftest.py | 7 +++ 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index 4d329d86..b2c16f00 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -50,6 +50,26 @@ ) from pruna.logging.logger import pruna_logger +_PRUNA_TASK_ROUTING_KWARGS: tuple[str, ...] = ( + "vlm_type", + "model_name", + "structured_output", + "vlm_kwargs", + "api_key", +) + + +def _strip_task_routing_kwargs(kwargs: dict[str, Any]) -> None: + """ + Drop kwargs :class:`~pruna.evaluation.task.Task` passes when building mixed metric lists. + + Torchmetrics classes often end with ``**kwargs`` and would otherwise accept bogus keys + until a lower layer raises. Stripping here keeps :class:`TorchMetricWrapper` the single + choke point between Pruna routing and torchmetrics constructors. + """ + for key in _PRUNA_TASK_ROUTING_KWARGS: + kwargs.pop(key, None) + def default_update(metric: Metric, *args, **kwargs) -> None: """ @@ -124,9 +144,7 @@ def arniqa_update(metric: ARNIQA, preds: Any) -> None: def ssim_update( - metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, - preds: Any, - target: Any + metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, preds: Any, target: Any ) -> None: """ Update handler for SSIM or MS-SSIM metric. @@ -152,29 +170,22 @@ class TorchMetrics(Enum): """ Enumeration of torchmetrics metrics for evaluation. - This enum provides a tuple per member (metric_factory, update_fn, call_type): - metric_factory builds the metric (typically a torchmetrics class, or - functools.partial when some constructor arguments are fixed); update_fn is - an optional custom update handler; call_type describes how inputs are paired - for the metric. + Each member value is a ``(metric_factory, update_fn, call_type)`` tuple. Parameters ---------- value : tuple - Tuple holding metric_factory, update_fn, and call_type as described above. + ``(metric_factory, update_fn, call_type)`` for this enum member. names : str - The name of the enum member. + Enum member name. module : str - The module where the enum is defined. + Defining module name. qualname : str - The qualified name of the enum. + Qualified name of the enum class. type : type - The type of the enum. + Enum metaclass type. start : int - The start index for auto-numbering enum values. - boundary : enum.FlagBoundary or None - Boundary handling mode used by the Enum functional API for Flag and - IntFlag enums. + Auto-numbering start index for functional API enums. """ fid = (FrechetInceptionDistance, fid_update, "gt_y") @@ -246,6 +257,7 @@ def __new__(cls, metric_name: str, call_type: str = "", **kwargs) -> StatefulMet if metric_name == "clip_score" and call_type.startswith(PAIRWISE): from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore + _strip_task_routing_kwargs(kwargs) return PairwiseClipScore(**kwargs) return super().__new__(cls) @@ -259,6 +271,7 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None: If the metric name is not supported. """ self.metric_name = metric_name + _strip_task_routing_kwargs(kwargs) super().__init__(kwargs.pop("device", None)) try: self.metric = TorchMetrics[metric_name](**kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 80d54825..6dff757b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,14 @@ +import os from typing import Any import pytest +if os.environ.get("PRUNA_CI_CPU_ONLY") == "1": + import torch + + if hasattr(torch.backends, "mps"): + torch.backends.mps.is_available = lambda: False # type: ignore[method-assign] + # import all fixtures to make them avaliable for pytest from .fixtures import * # noqa: F403, F401 From 33e15517838f7cd3bed44222a8a1c030e07ec4b8 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Thu, 4 Jun 2026 09:55:42 +0200 Subject: [PATCH 5/6] fix(ci): ruff on infra VLM test template Co-authored-by: Cursor --- scripts/test_vlm_base_infrastructure_infra.py | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 scripts/test_vlm_base_infrastructure_infra.py diff --git a/scripts/test_vlm_base_infrastructure_infra.py b/scripts/test_vlm_base_infrastructure_infra.py new file mode 100644 index 00000000..524cd49b --- /dev/null +++ b/scripts/test_vlm_base_infrastructure_infra.py @@ -0,0 +1,121 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for VLM base classes and vlm_utils (infrastructure PR only).""" + +from unittest.mock import MagicMock, patch + +import pytest + +from pruna.evaluation.metrics.vlm_base import BaseVLM, LitellmVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import FloatOutput, get_score_from_response, yes_no_first_token_id_groups + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + (FloatOutput(score=8.0), 0.8), + ({"score": 5.0}, 0.5), + ('{"score": 7.5}', 0.75), + ('{"score": 10}', 1.0), + ("8", 0.8), + ("Score: 7.5 out of 10", 0.75), + ("", 0.0), + ], +) +def test_get_score_from_response(raw: object, expected: float) -> None: + """``get_score_from_response`` maps pydantic, dict, JSON, and text to ``[0, 1]``.""" + assert get_score_from_response(raw) == pytest.approx(expected) + + +@pytest.mark.cpu +def test_get_vlm_returns_custom() -> None: + """get_vlm returns the provided VLM instance unchanged.""" + custom = MagicMock(spec=BaseVLM) + out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o") + assert out is custom + + +@pytest.mark.cpu +def test_yes_no_first_token_id_groups_disjoint() -> None: + """Prefix token ids for Yes vs No should not overlap (avoids double-counting).""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("gpt2") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert yes_ids and no_ids + assert not (set(yes_ids) & set(no_ids)) + + +@pytest.mark.cpu +def test_get_vlm_requires_model_name_without_vlm() -> None: + """get_vlm raises ValueError when no model_name is given and no vlm is provided.""" + with pytest.raises(ValueError, match="model_name"): + get_vlm(vlm=None, vlm_type="litellm") + + +@pytest.mark.cpu +def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None: + """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first.""" + pytest.importorskip("litellm") + import math + + import numpy as np + from PIL import Image + + def make_top_logprob(token, logprob): + t = MagicMock() + t.token = token + t.logprob = logprob + return t + + first_tok = MagicMock() + first_tok.top_logprobs = [ + make_top_logprob("Yes", math.log(0.10)), + make_top_logprob(" yes", math.log(0.05)), + make_top_logprob("No", math.log(0.20)), + make_top_logprob(" no", math.log(0.10)), + make_top_logprob("maybe", math.log(0.55)), + ] + + mock_logprobs = MagicMock() + mock_logprobs.content = [first_tok] + + mock_choice = MagicMock() + mock_choice.logprobs = mock_logprobs + mock_choice.message.content = "Yes" + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + with patch("litellm.completion", return_value=mock_response): + vlm = LitellmVLM(model_name="openai/gpt-4o") + img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) + score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes") + + assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}" + + +@pytest.mark.cpu +@pytest.mark.slow +def test_yes_no_token_ids_smolvlm_nonempty() -> None: + """SmolVLM tokenizer yields non-empty yes/no prefix id groups.""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert yes_ids + assert no_ids From df3be047b08618f1c80197a788ca30773662826b Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Thu, 4 Jun 2026 10:18:07 +0200 Subject: [PATCH 6/6] chore: drop local-only scripts from PR scope Remove verify helper and duplicate infra test template from scripts/; tests live under tests/evaluation/ only. Co-authored-by: Cursor --- scripts/test_vlm_base_infrastructure_infra.py | 121 ------------------ 1 file changed, 121 deletions(-) delete mode 100644 scripts/test_vlm_base_infrastructure_infra.py diff --git a/scripts/test_vlm_base_infrastructure_infra.py b/scripts/test_vlm_base_infrastructure_infra.py deleted file mode 100644 index 524cd49b..00000000 --- a/scripts/test_vlm_base_infrastructure_infra.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2025 - Pruna AI GmbH. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for VLM base classes and vlm_utils (infrastructure PR only).""" - -from unittest.mock import MagicMock, patch - -import pytest - -from pruna.evaluation.metrics.vlm_base import BaseVLM, LitellmVLM, get_vlm -from pruna.evaluation.metrics.vlm_utils import FloatOutput, get_score_from_response, yes_no_first_token_id_groups - - -@pytest.mark.parametrize( - ("raw", "expected"), - [ - (FloatOutput(score=8.0), 0.8), - ({"score": 5.0}, 0.5), - ('{"score": 7.5}', 0.75), - ('{"score": 10}', 1.0), - ("8", 0.8), - ("Score: 7.5 out of 10", 0.75), - ("", 0.0), - ], -) -def test_get_score_from_response(raw: object, expected: float) -> None: - """``get_score_from_response`` maps pydantic, dict, JSON, and text to ``[0, 1]``.""" - assert get_score_from_response(raw) == pytest.approx(expected) - - -@pytest.mark.cpu -def test_get_vlm_returns_custom() -> None: - """get_vlm returns the provided VLM instance unchanged.""" - custom = MagicMock(spec=BaseVLM) - out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o") - assert out is custom - - -@pytest.mark.cpu -def test_yes_no_first_token_id_groups_disjoint() -> None: - """Prefix token ids for Yes vs No should not overlap (avoids double-counting).""" - pytest.importorskip("transformers") - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained("gpt2") - yes_ids, no_ids = yes_no_first_token_id_groups(tok) - assert yes_ids and no_ids - assert not (set(yes_ids) & set(no_ids)) - - -@pytest.mark.cpu -def test_get_vlm_requires_model_name_without_vlm() -> None: - """get_vlm raises ValueError when no model_name is given and no vlm is provided.""" - with pytest.raises(ValueError, match="model_name"): - get_vlm(vlm=None, vlm_type="litellm") - - -@pytest.mark.cpu -def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None: - """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first.""" - pytest.importorskip("litellm") - import math - - import numpy as np - from PIL import Image - - def make_top_logprob(token, logprob): - t = MagicMock() - t.token = token - t.logprob = logprob - return t - - first_tok = MagicMock() - first_tok.top_logprobs = [ - make_top_logprob("Yes", math.log(0.10)), - make_top_logprob(" yes", math.log(0.05)), - make_top_logprob("No", math.log(0.20)), - make_top_logprob(" no", math.log(0.10)), - make_top_logprob("maybe", math.log(0.55)), - ] - - mock_logprobs = MagicMock() - mock_logprobs.content = [first_tok] - - mock_choice = MagicMock() - mock_choice.logprobs = mock_logprobs - mock_choice.message.content = "Yes" - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - - with patch("litellm.completion", return_value=mock_response): - vlm = LitellmVLM(model_name="openai/gpt-4o") - img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) - score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes") - - assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}" - - -@pytest.mark.cpu -@pytest.mark.slow -def test_yes_no_token_ids_smolvlm_nonempty() -> None: - """SmolVLM tokenizer yields non-empty yes/no prefix id groups.""" - pytest.importorskip("transformers") - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") - yes_ids, no_ids = yes_no_first_token_id_groups(tok) - assert yes_ids - assert no_ids