diff --git a/pyproject.toml b/pyproject.toml index 004fa3f0..a603847b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,41 +67,33 @@ name = "pruna_internal" url = "https://prunaai.pythonanywhere.com/simple/" explicit = true -[[tool.uv.index]] -name = "intel-pytorch-extension" -url = "https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/" -explicit = true - [tool.uv] index-strategy = "first-index" +exclude-newer = "1 week" # protection against compromised dependencies +# trusted dev wheels that are missing an upload date +exclude-newer-package = { gptqmodel = false, "stable-fast-pruna" = false } conflicts = [ [{ extra = "awq" }, { extra = "vbench" }], [{ extra = "vllm" }, { extra = "vbench" }], - [{ extra = "intel" }, { extra = "awq" }], [{ extra = "gptq" }, { extra = "awq" }], - # intel is incompatible with all stable-fast variants and vllm - [{ extra = "intel" }, { extra = "stable-fast" }, { extra = "stable-fast-extraindex" }], - [{ extra = "intel" }, { extra = "full" }, { extra = "stable-fast-extraindex" }], - [{ extra = "intel" }, { extra = "vllm" }], [{ extra = "kvpress" }, { extra = "vbench" }], ] [tool.uv.sources] gptqmodel = { index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'" } -intel-extension-for-pytorch = { index = "intel-pytorch-extension" } stable-fast-pruna = { index = "pruna_internal", extra = "stable-fast-extraindex" } [project] name = "pruna" -version = "0.3.2" +version = "0.3.3" description = "Smash your AI models" authors = [ {name = "Pruna AI", email = "hello@pruna.ai"} ] license = {file = "LICENSE"} readme = "README.md" -requires-python = ">=3.10,<3.13" +requires-python = ">=3.10,<3.14" keywords = ["AI", "machine learning", "model optimization", "pruning"] classifiers = [ "Development Status :: 4 - Beta", @@ -246,12 +238,6 @@ lmharness = [ "lm-eval>=0.4.0" ] -# Intel extension is tightly coupled with the torch version -intel = [ - "intel-extension-for-pytorch>=2.7.0", - "torch>=2.7.0,<2.9.0", - "torchvision>=0.22.0,<0.24.0", -] kvpress = [ "kvpress>=0.5.2", ] diff --git a/src/pruna/evaluation/benchmarks.py b/src/pruna/evaluation/benchmarks.py index e3f58164..36e50c5e 100644 --- a/src/pruna/evaluation/benchmarks.py +++ b/src/pruna/evaluation/benchmarks.py @@ -66,19 +66,19 @@ class BenchmarkRegistry: paper (see reference URL). All entries verified from paper evaluation sections (ar5iv/HTML or PDF) as of verification pass: - - Parti Prompts (2206.10789 §5.2, §5.4): human side-by-side only on P222. - - DrawBench (2205.11487 §4.3): human raters only; COCO uses FID + CLIP. + - Parti Prompts (2206.10789 ?5.2, ?5.4): human side-by-side only on P222. + - DrawBench (2205.11487 ?4.3): human raters only; COCO uses FID + CLIP. - GenAI Bench (2406.13743): VQAScore only (web/PWC; ar5iv failed). - VBench (2311.17982): 16 dimension-specific methods; no single Pruna metric. - - COCO (2205.11487 §4.1): FID and CLIP score for fidelity and alignment. - - ImageNet (1409.0575 §4): top-1/top-5 classification accuracy. - - WikiText (1609.07843 §5): perplexity on validation/test. - - GenEval (2310.11513 §3.2): Mask2Former + CLIP color pipeline, binary score. + - COCO (2205.11487 ?4.1): FID and CLIP score for fidelity and alignment. + - ImageNet (1409.0575 ?4): top-1/top-5 classification accuracy. + - WikiText (1609.07843 ?5): perplexity on validation/test. + - GenEval (2310.11513 ?3.2): Mask2Former + CLIP color pipeline, binary score. - HPS (2306.09341): HPS v2 scoring model (CLIP fine-tuned on HPD v2). - - ImgEdit (2505.20275 §4.2): GPT-4o 1–5 ratings and ImgEdit-Judge. - - Long Text Bench (2507.22058 §4): Text Accuracy (OCR, Qwen2.5-VL-7B). - - GEditBench (2504.17761 §4.2): VIEScore (SQ, PQ, O via GPT-4.1/Qwen2.5-VL). - - OneIG (2506.07977 §4.1): per-dimension metrics (semantic alignment, ED, etc.). + - ImgEdit (2505.20275 ?4.2): GPT-4o 15 ratings and ImgEdit-Judge. + - Long Text Bench (2507.22058 ?4): Text Accuracy (OCR, Qwen2.5-VL-7B). + - GEditBench (2504.17761 ?4.2): VIEScore (SQ, PQ, O via GPT-4.1/Qwen2.5-VL). + - OneIG (2506.07977 ?4.1): per-dimension metrics (semantic alignment, ED, etc.). - DPG (2403.05135): DSG-style graph score, mPLUG-large adjudicator. """ @@ -195,7 +195,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "MS-COCO for text-to-image evaluation (Imagen, 2205.11487). Paper reports " "FID for fidelity and CLIP score for image-text alignment." ), - metrics=["fid", "clip_score"], # §4.1: FID + CLIP score + metrics=["fid", "clip_score"], # ?4.1: FID + CLIP score task_type="text_to_image", reference="https://arxiv.org/abs/2205.11487", ), @@ -256,7 +256,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "Text-to-image benchmark for long, detailed prompts. Evaluates model ability to " "handle complex multi-clause descriptions and maintain coherence across long instructions." ), - metrics=[], # Paper uses text_score/TIT-Score; not in Pruna + metrics=[], # Paper uses word accuracy (X-Omni); not wired to text_score yet task_type="text_to_image", reference="https://arxiv.org/abs/2507.22058", ), @@ -299,6 +299,13 @@ def list(cls, task_type: str | None = None) -> list[str]: task_type="text_to_image", reference="https://arxiv.org/abs/2506.07977", ), + Benchmark( + name="OneIG Text Rendering", + description="OneIG subset: text and graphics painted into the image.", + metrics=["oneig_text_score"], + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), Benchmark( name="DPG", description=( diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index 49cfe904..d17ea6dd 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -22,11 +22,12 @@ from pruna.evaluation.metrics.metric_evalharness import LMEvalMetric from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric -from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric +from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric as RapidataMetric from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric +from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper from pruna.evaluation.metrics.vlm_base import ( BaseVLM, @@ -56,8 +57,10 @@ "AestheticLAION", "LMEvalMetric", "OneIGAlignmentMetric", + "OneIGTextScoreMetric", "QAAccuracyMetric", "RapidataMetric", + "TextScoreMetric", "BaseVLM", "LitellmVLM", "StatefulVLMMeanScoresMetric", diff --git a/src/pruna/evaluation/metrics/metric_oneig_alignment.py b/src/pruna/evaluation/metrics/metric_oneig_alignment.py index 0f372f4f..a8827dd7 100644 --- a/src/pruna/evaluation/metrics/metric_oneig_alignment.py +++ b/src/pruna/evaluation/metrics/metric_oneig_alignment.py @@ -151,8 +151,6 @@ class OneIGAlignmentMetric(QAAccuracyMetric): (default ``2 x 2``), score **one question per VLM call** across all cells, apply dependency masking per cell, then average cell scores. - Scoring semantics - ----------------- OneIG Q_D probes are phrased so **Yes = aligned**. Each call requests :meth:`~pruna.evaluation.metrics.vlm_base.BaseVLM.score` with expected answer ``"Yes"`` (probability of Yes). Low scores act as semantic **No** for dependency @@ -178,11 +176,9 @@ class OneIGAlignmentMetric(QAAccuracyMetric): api_key : str | None, optional API key for litellm. call_type : str, optional - Call type for the metric. - aggregation : str, optional - Unused; kept for registry compatibility with :class:`QAAccuracyMetric`. + Call type for the metric (``"single"`` or ``"pairwise"``). **kwargs : Any - Additional keyword arguments for :class:`QAAccuracyMetric`. + Forwarded to :class:`QAAccuracyMetric` (e.g. ``aggregation``). Examples -------- @@ -199,7 +195,6 @@ class OneIGAlignmentMetric(QAAccuracyMetric): def __init__( self, - *args: Any, grid_size: tuple[int, int] = (2, 2), vlm: Any | None = None, vlm_type: Literal["litellm", "transformers"] = "transformers", @@ -212,7 +207,6 @@ def __init__( **kwargs: Any, ) -> None: super().__init__( - *args, vlm=vlm, vlm_type=vlm_type, model_name=model_name, @@ -220,10 +214,11 @@ def __init__( structured_output=structured_output, device=device, api_key=api_key, - call_type=call_type if call_type is not None else "y_gt", + call_type=call_type, **kwargs, ) self.grid_size = (int(grid_size[0]), int(grid_size[1])) + self.metric_units = type(self).metric_units def _score_sample(self, image: Any, aux: dict[str, Any]) -> float: if not isinstance(image, Image.Image): diff --git a/src/pruna/evaluation/metrics/metric_qa_accuracy.py b/src/pruna/evaluation/metrics/metric_qa_accuracy.py index f954c0eb..ba5ed118 100644 --- a/src/pruna/evaluation/metrics/metric_qa_accuracy.py +++ b/src/pruna/evaluation/metrics/metric_qa_accuracy.py @@ -55,8 +55,6 @@ class QAAccuracyMetric(StatefulVLMMeanScoresMetric): Parameters ---------- - *args : Any - Additional positional arguments. vlm : BaseVLM | None, optional Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. vlm_type : {"litellm", "transformers"}, optional @@ -76,8 +74,10 @@ class QAAccuracyMetric(StatefulVLMMeanScoresMetric): API key for litellm. call_type : str, optional Call type for the metric. + aggregation : {"mean", "all_or_nothing"}, optional + Per-image score aggregation (keyword-only). Default is ``"mean"``. **kwargs : Any - Supports ``aggregation``: ``"mean"`` or ``"all_or_nothing"``. + Additional keyword arguments forwarded to the parent class. Raises ------ @@ -111,7 +111,6 @@ class QAAccuracyMetric(StatefulVLMMeanScoresMetric): def __init__( self, - *args, vlm: BaseVLM | None = None, vlm_type: Literal["litellm", "transformers"] = "litellm", model_name: str | None = None, @@ -119,7 +118,7 @@ def __init__( structured_output: bool = True, device: str | torch.device | None = None, api_key: str | None = None, - call_type: str = SINGLE, + call_type: str | None = None, *, aggregation: str = "mean", **kwargs: Any, @@ -139,7 +138,7 @@ def __init__( structured_output=structured_output, device=device, api_key=api_key, - call_type=call_type, + call_type=call_type if call_type is not None else SINGLE, ) def _extract_questions(self, gt: Any, n: int) -> list[list[str]]: diff --git a/src/pruna/evaluation/metrics/metric_text_score.py b/src/pruna/evaluation/metrics/metric_text_score.py new file mode 100644 index 00000000..b56a6957 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_text_score.py @@ -0,0 +1,372 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text rendering via OCR: mean Levenshtein (``text_score`` / ``ocr_levenshtein``). + +OneIG composite: ``oneig_text_score`` / ``ocr_text_score``. +""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Literal + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.metric_text_score_utils import ( + levenshtein, + normalize_text_simple, + oneig_mean_text_score, + oneig_per_sample_contributions, +) +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import TextOutput, _process_images, get_text_from_response + +OCR_PROMPT = ( + "Extract all text visible in this image. Include logos, stylized fonts, handwritten text, " + "and non-standard typography. Return only the extracted text, exactly as it appears—no preamble, " + "explanation, or markdown. Preserve words, numbers, punctuation, and spacing. " + "IMPORTANT: Do NOT correct spelling errors or typos. If a word is misspelled in the image " + "(e.g. 'Teclhology' instead of 'Technology'), reproduce it exactly as it appears, including the misspelling. " + "If no text is recognized, reply with exactly: No text recognized" +) + + +class _BaseVLMOCRTextMetric(StatefulMetric): + """ + Shared VLM OCR over rendered images with ground truth in ``text_content``. + + Subclasses implement how OCR and GT strings are scored and aggregated. + + Parameters + ---------- + *args : Any + Additional positional arguments (unused; registry compatibility). + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {'litellm', 'transformers'}, optional + VLM backend. Default is ``'litellm'``. + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + + Examples + -------- + OCR metrics call ``get_vlm`` directly (not ``StatefulVLMMeanScoresMetric``). Same + ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`: + + .. code-block:: python + + import torch + + from pruna.evaluation.metrics import TextScoreMetric + + hosted = TextScoreMetric(vlm_type="litellm", model_name="openai/gpt-4o") + local = TextScoreMetric( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, + ) + + Use ``OneIGTextScoreMetric`` the same way for ``oneig_text_score`` / ``ocr_text_score``. + """ + + default_call_type: str = "y_gt" + + def __init__( + self, + *args: Any, + vlm: BaseVLM | None = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: dict | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + structured_output=structured_output, + **(vlm_kwargs or {}), + ) + self.response_format = TextOutput if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.higher_is_better = type(self).higher_is_better + + @abstractmethod + def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None: + """Update metric state from one ground-truth / OCR pair.""" + + @abstractmethod + def _compute_result_value(self) -> float: + """Return the scalar reported as ``MetricResult.result``.""" + + def update(self, x: list[Any] | torch.Tensor, gt: list[str], outputs: torch.Tensor) -> None: + """ + Run OCR on outputs and score against ``text_content`` (or string list) auxiliaries. + + Parameters + ---------- + x : List[Any] | torch.Tensor + Batch prompts or metadata. + gt : list of dict or list of str + Auxiliaries with ``'text_content'`` as a string, a list of strings (joined with + newlines), or plain strings per batch item. + outputs : torch.Tensor + Rendered images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + auxiliaries = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], (list, tuple)) else [{}] * len(images) + for i, image in enumerate(images): + responses = self.vlm.generate([image], [OCR_PROMPT], response_format=self.response_format) + raw = responses[0] if responses else "" + ocr_text = get_text_from_response(raw) + aux = auxiliaries[i] if i < len(auxiliaries) else {} + text_gt = aux.get("text_content") if isinstance(aux, dict) else (aux if isinstance(aux, str) else None) + if isinstance(text_gt, list): + text_gt = "\n".join(str(x) for x in text_gt) + if text_gt is None: + raise ValueError( + f"{self.metric_name} requires 'text_content' in auxiliaries. " + "Use a benchmark that provides it (e.g. LongTextBench, OneIG)." + ) + self._accumulate_sample(text_gt, ocr_text) + + def compute(self) -> MetricResult: + """ + Aggregate batched contributions into a single metric value. + + Returns + ------- + MetricResult + Named result with ``higher_is_better`` taken from the class. + """ + value = self._compute_result_value() + return MetricResult(self.metric_name, self.__dict__, float(value)) + + +@MetricRegistry.register("ocr_levenshtein") +@MetricRegistry.register("text_score") +class TextScoreMetric(_BaseVLMOCRTextMetric): + """ + OCR then mean normalized character accuracy in [0, 1] (higher is better). + + Registry: ``ocr_levenshtein`` (descriptive) and ``text_score`` (legacy). + + Uses light normalization only (not the full OneIG preprocess). See + :class:`OneIGTextScoreMetric` for the OneIG composite ``ocr_text_score``. + + Parameters + ---------- + *args : Any + Additional positional arguments (unused; registry compatibility). + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {'litellm', 'transformers'}, optional + VLM backend. Default is ``'litellm'``. + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`. + """ + + scores: list[float] + higher_is_better: bool = True + metric_name: str = "text_score" + + def __init__( + self, + *args: Any, + vlm: BaseVLM | None = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: dict[str, Any] | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__( + *args, + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + vlm_kwargs=vlm_kwargs, + structured_output=structured_output, + device=device, + api_key=api_key, + call_type=call_type, + **kwargs, + ) + self.add_state("scores", []) + + def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None: + norm_gt = normalize_text_simple(text_gt) + norm_ocr = normalize_text_simple(ocr_text) + ed = levenshtein(norm_ocr, norm_gt) + denom = max(float(len(norm_gt)), 1.0) + self.scores.append(1.0 - min(1.0, ed / denom)) + + def _compute_result_value(self) -> float: + if not self.scores: + return 0.0 + return float(np.mean(self.scores)) + + +@MetricRegistry.register("ocr_text_score") +@MetricRegistry.register("oneig_text_score") +class OneIGTextScoreMetric(_BaseVLMOCRTextMetric): + """ + OCR then OneIG-style composite text score (higher is better). + + Registry: ``ocr_text_score`` (descriptive) and ``oneig_text_score`` (protocol). + + Aggregates edit distance, completion rate, and word/char accuracy like + ``OneIG-Benchmark/scripts/text/text_score.py``. + + Parameters + ---------- + *args : Any + Additional positional arguments (forwarded to :class:`_BaseVLMOCRTextMetric`). + language_mode : {'EN', 'ZH'}, optional + Selects ``MAX_EDIT_DISTANCE`` (100 vs 50) for the composite. + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {'litellm', 'transformers'}, optional + VLM backend. Default is ``'litellm'``. + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`. + """ + + edit_distances: list[float] + completion_ratios: list[float] + match_counts: list[int] + gt_totals: list[int] + + higher_is_better: bool = True + metric_name: str = "oneig_text_score" + + def __init__( + self, + *args: Any, + language_mode: Literal["EN", "ZH"] = "EN", + vlm: BaseVLM | None = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: dict[str, Any] | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__( + *args, + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + vlm_kwargs=vlm_kwargs, + structured_output=structured_output, + device=device, + api_key=api_key, + call_type=call_type, + **kwargs, + ) + self.language_mode = language_mode + self.add_state("edit_distances", []) + self.add_state("completion_ratios", []) + self.add_state("match_counts", []) + self.add_state("gt_totals", []) + + def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None: + ed, cr, mcount, gtot = oneig_per_sample_contributions(text_gt, ocr_text) + self.edit_distances.append(ed) + self.completion_ratios.append(cr) + self.match_counts.append(mcount) + self.gt_totals.append(gtot) + + def _compute_result_value(self) -> float: + *_, text_score = oneig_mean_text_score( + self.edit_distances, + self.completion_ratios, + self.match_counts, + self.gt_totals, + self.language_mode, + ) + return text_score diff --git a/src/pruna/evaluation/metrics/metric_text_score_utils.py b/src/pruna/evaluation/metrics/metric_text_score_utils.py new file mode 100644 index 00000000..ce8eb6e0 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_text_score_utils.py @@ -0,0 +1,272 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for text rendering metrics (simple Levenshtein vs OneIG-style composite). + +OneIG-style preprocessing and aggregation follow +`OneIG-Benchmark/scripts/text/text_utils.py` and `text_score.py` (Apache-2.0). +""" + +from __future__ import annotations + +import re +from collections import Counter +from typing import Literal + +# Known OneIG/Qwen OCR boilerplate (see OneIG ``clean_and_remove_hallucinations``). +_OCR_HALLUCINATION_KEYWORDS = ("addCriterion", "No text recognized.", "No text recognized") + + +def normalize_text_simple(s: str) -> str: + """ + Normalize text for the legacy ``text_score`` metric (light cleanup + spacing). + + Parameters + ---------- + s : str + Raw string. + + Returns + ------- + str + Normalized string. + """ + cleaned = re.sub( + r"[^\u4e00-\u9fffa-zA-Z0-9\sàâäéèêëîïôöùûüçÀÂÄÉÈÊËÎÏÔÖÙÛÜÇ]", + "", + s or "", + ) + return re.sub(r"\s+", " ", cleaned).strip() + + +def levenshtein(s1: str, s2: str) -> float: + """ + Symmetric Levenshtein edit distance. + + Parameters + ---------- + s1 : str + First string. + s2 : str + Second string. + + Returns + ------- + float + Edit distance. + """ + if len(s1) < len(s2): + return levenshtein(s2, s1) + prev = list(range(len(s2) + 1)) + for i, c1 in enumerate(s1): + curr = [i + 1] + for j, c2 in enumerate(s2): + curr.append(min(prev[j] + (c1 != c2), prev[j + 1] + 1, curr[-1] + 1)) + prev = curr + return float(prev[-1]) + + +def contains_chinese(text: str) -> bool: + """ + Return True if ``text`` contains CJK unified ideographs. + + Parameters + ---------- + text : str + Input text. + + Returns + ------- + bool + Whether Chinese characters are present. + """ + return bool(re.search(r"[\u4e00-\u9fff]", text)) + + +def preprocess_string_oneig(s: str) -> str: + """ + OneIG ``preprocess_string``: charset filter, Chinese vs whitespace normalization. + + Parameters + ---------- + s : str + Raw string. + + Returns + ------- + str + Preprocessed string (ground truth or OCR). + """ + cleaned = normalize_text_simple(s) + if contains_chinese(cleaned): + # Spaces between CJK characters are a common Qwen OCR artifact. + cleaned = re.sub(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", cleaned) + pattern = re.compile( + r"[\u4e00-\u9fffa-zA-Z0-9àâäéèêëîïôöùûüçÀÂÄÉÈÊËÎÏÔÖÙÛÜÇ]", + ) + return "".join(pattern.findall(cleaned)).strip() + return cleaned + + +def clean_oneig_ocr_hallucinations(text: str) -> str: + """ + Remove known OCR boilerplate substrings (OneIG ``clean_and_remove_hallucinations``). + + Parameters + ---------- + text : str + Raw OCR output. + + Returns + ------- + str + Cleaned OCR text. + """ + out = text or "" + for keyword in _OCR_HALLUCINATION_KEYWORDS: + out = out.replace(f"\n{keyword}", "").replace(f"{keyword}\n", "").replace(keyword, "") + return out + + +def calculate_char_match_ratio( + text_gt: str, + ocr_str: str, +) -> tuple[int, float, int]: + """ + OneIG overlap stats: character multiset for ZH, word multiset for EN. + + Parameters + ---------- + text_gt : str + Preprocessed ground truth. + ocr_str : str + Preprocessed OCR. + + Returns + ------- + total_match_count : int + Overlap count used in WAC numerator aggregation. + ratio : float + Per-sample ratio (mean of ratios is not used in the official aggregate). + gt_total : int + Denominator term: ``sum(gt_counter.values())`` for WAC aggregation. + """ + if contains_chinese(text_gt): + gt_counter: Counter[str] = Counter(text_gt) + ocr_counter: Counter[str] = Counter(ocr_str) + total_match_count = int(sum((gt_counter & ocr_counter).values())) + ratio = total_match_count / len(text_gt) if len(text_gt) > 0 else 0.0 + return total_match_count, ratio, int(sum(gt_counter.values())) + + words_gt = text_gt.split() + words_ocr = ocr_str.split() + gt_counter = Counter(words_gt) + ocr_counter = Counter(words_ocr) + total_match_count = int(sum((gt_counter & ocr_counter).values())) + total_gt_count = len(words_gt) + ratio = total_match_count / total_gt_count if total_gt_count > 0 else 0.0 + return total_match_count, ratio, int(sum(gt_counter.values())) + + +def max_edit_distance_for_language(language_mode: Literal["EN", "ZH"]) -> int: + """ + OneIG ``MAX_EDIT_DISTANCE`` (100 for English, 50 for Chinese benchmark split). + + Parameters + ---------- + language_mode : {'EN', 'ZH'} + Benchmark language mode. + + Returns + ------- + int + Cap used in the composite text score. + """ + return 50 if language_mode == "ZH" else 100 + + +def oneig_per_sample_contributions(text_gt: str, ocr_raw: str) -> tuple[float, float, int, int]: + """ + Per-sample terms for OneIG aggregation (ED, CR, WAC numerator/denominator parts). + + Parameters + ---------- + text_gt : str + Ground-truth text (dataset field). + ocr_raw : str + Raw OCR string from the VLM. + + Returns + ------- + edit_distance : float + Levenshtein distance after OneIG preprocess. + completion_ratio : float + 1.0 if distance is zero, else 0.0. + match_count : int + Overlap count for WAC. + gt_total : int + Ground-truth token count term for WAC denominator. + """ + ocr_clean = clean_oneig_ocr_hallucinations(ocr_raw) + gt_pre = preprocess_string_oneig(text_gt) + ocr_pre = preprocess_string_oneig(ocr_clean) + ed = levenshtein(ocr_pre, gt_pre) + cr = 1.0 if ed == 0.0 else 0.0 + match_count, _, gt_total = calculate_char_match_ratio(gt_pre, ocr_pre) + return ed, cr, match_count, gt_total + + +def oneig_mean_text_score( + edit_distances: list[float], + completion_ratios: list[float], + match_counts: list[int], + gt_totals: list[int], + language_mode: Literal["EN", "ZH"], +) -> tuple[float, float, float, float]: + """ + Aggregate OneIG ED, CR, WAC and composite text score (higher is better). + + Parameters + ---------- + edit_distances : list of float + Per-sample edit distances. + completion_ratios : list of float + Per-sample completion indicators. + match_counts : list of int + Per-sample WAC numerators. + gt_totals : list of int + Per-sample WAC denominator terms. + language_mode : {'EN', 'ZH'} + Selects ``MAX_EDIT_DISTANCE``. + + Returns + ------- + ed_mean : float + Mean edit distance. + cr_mean : float + Mean completion ratio. + wac : float + Micro-averaged WAC: ``sum(match_counts) / sum(gt_totals)``. + text_score : float + Composite: ``1 - min(MAX_ED, ED) * (1 - CR) * (1 - WAC) / MAX_ED``. + """ + cap = float(max_edit_distance_for_language(language_mode)) + if not edit_distances: + return 0.0, 0.0, 0.0, 0.0 + ed_mean = float(sum(edit_distances) / len(edit_distances)) + cr_mean = float(sum(completion_ratios) / len(completion_ratios)) + denom = float(sum(gt_totals)) + wac = float(sum(match_counts) / denom) if denom > 0.0 else 0.0 + text_score = 1.0 - min(cap, ed_mean) * (1.0 - cr_mean) * (1.0 - wac) / cap + return ed_mean, cr_mean, wac, text_score diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index 4d329d86..b2c16f00 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -50,6 +50,26 @@ ) from pruna.logging.logger import pruna_logger +_PRUNA_TASK_ROUTING_KWARGS: tuple[str, ...] = ( + "vlm_type", + "model_name", + "structured_output", + "vlm_kwargs", + "api_key", +) + + +def _strip_task_routing_kwargs(kwargs: dict[str, Any]) -> None: + """ + Drop kwargs :class:`~pruna.evaluation.task.Task` passes when building mixed metric lists. + + Torchmetrics classes often end with ``**kwargs`` and would otherwise accept bogus keys + until a lower layer raises. Stripping here keeps :class:`TorchMetricWrapper` the single + choke point between Pruna routing and torchmetrics constructors. + """ + for key in _PRUNA_TASK_ROUTING_KWARGS: + kwargs.pop(key, None) + def default_update(metric: Metric, *args, **kwargs) -> None: """ @@ -124,9 +144,7 @@ def arniqa_update(metric: ARNIQA, preds: Any) -> None: def ssim_update( - metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, - preds: Any, - target: Any + metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, preds: Any, target: Any ) -> None: """ Update handler for SSIM or MS-SSIM metric. @@ -152,29 +170,22 @@ class TorchMetrics(Enum): """ Enumeration of torchmetrics metrics for evaluation. - This enum provides a tuple per member (metric_factory, update_fn, call_type): - metric_factory builds the metric (typically a torchmetrics class, or - functools.partial when some constructor arguments are fixed); update_fn is - an optional custom update handler; call_type describes how inputs are paired - for the metric. + Each member value is a ``(metric_factory, update_fn, call_type)`` tuple. Parameters ---------- value : tuple - Tuple holding metric_factory, update_fn, and call_type as described above. + ``(metric_factory, update_fn, call_type)`` for this enum member. names : str - The name of the enum member. + Enum member name. module : str - The module where the enum is defined. + Defining module name. qualname : str - The qualified name of the enum. + Qualified name of the enum class. type : type - The type of the enum. + Enum metaclass type. start : int - The start index for auto-numbering enum values. - boundary : enum.FlagBoundary or None - Boundary handling mode used by the Enum functional API for Flag and - IntFlag enums. + Auto-numbering start index for functional API enums. """ fid = (FrechetInceptionDistance, fid_update, "gt_y") @@ -246,6 +257,7 @@ def __new__(cls, metric_name: str, call_type: str = "", **kwargs) -> StatefulMet if metric_name == "clip_score" and call_type.startswith(PAIRWISE): from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore + _strip_task_routing_kwargs(kwargs) return PairwiseClipScore(**kwargs) return super().__new__(cls) @@ -259,6 +271,7 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None: If the metric name is not supported. """ self.metric_name = metric_name + _strip_task_routing_kwargs(kwargs) super().__init__(kwargs.pop("device", None)) try: self.metric = TorchMetrics[metric_name](**kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 80d54825..6dff757b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,14 @@ +import os from typing import Any import pytest +if os.environ.get("PRUNA_CI_CPU_ONLY") == "1": + import torch + + if hasattr(torch.backends, "mps"): + torch.backends.mps.is_available = lambda: False # type: ignore[method-assign] + # import all fixtures to make them avaliable for pytest from .fixtures import * # noqa: F403, F401 diff --git a/tests/evaluation/test_text_metrics.py b/tests/evaluation/test_text_metrics.py index a5931bae..d566390d 100644 --- a/tests/evaluation/test_text_metrics.py +++ b/tests/evaluation/test_text_metrics.py @@ -120,20 +120,3 @@ def test_oneig_alignment_all_padding_questions_yields_zero_without_vlm() -> None assert metric.compute().result == 0.0 mock_vlm.score.assert_not_called() - -def test_to_oneig_record_strips_null_questions_and_dependencies() -> None: - """Null-valued Q_D entries are filtered out at record construction time.""" - row = {"category": "Anime_Stylization", "id": "001", "class": "None", "prompt_en": "a cat"} - questions_by_key = { - "anime_001": { - "questions": {"1": "Is there a cat?", "21": None}, - "dependencies": {"1": [0], "21": None}, - } - } - record = _to_oneig_record(row, questions_by_key, {}, {}) - assert "21" not in record["questions"] - assert "21" not in record["dependencies"] - assert record["questions"] == {"1": "Is there a cat?"} - assert record["dependencies"] == {"1": [0]} - - diff --git a/tests/evaluation/test_vlm_base_infrastructure.py b/tests/evaluation/test_vlm_base_infrastructure.py index a4eaa139..b6ac9b1c 100644 --- a/tests/evaluation/test_vlm_base_infrastructure.py +++ b/tests/evaluation/test_vlm_base_infrastructure.py @@ -1,50 +1,12 @@ -"""Tests for VLM metrics (VQA, ImageEditScore, QAAccuracy, TextScore, VieScore) and vlm_utils helpers.""" +"""Tests for VLM base classes and vlm_utils (infrastructure PR only).""" from unittest.mock import MagicMock, patch import pytest import torch -from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric -from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric -from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric -from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric -from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric -from pruna.evaluation.metrics.metric_vqa import VQAMetric -from pruna.evaluation.metrics.result import MetricResult -from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm -from pruna.evaluation.metrics.vlm_utils import ( - FloatOutput, - VLM_AUX_IMAGE_BYTES_KEY_ORDER, - get_score_from_response, - yes_no_first_token_id_groups, -) - -from ._vlm_batch_snapshot_helpers import ( - BenchmarkVlmBatchOutcome, - pred_tensor_from_auxiliaries, - safe_json_for_snapshot, - vlm_benchmark_batch_to_json_record, -) - -SMOL_VLM = "HuggingFaceTB/SmolVLM-256M-Instruct" - -_ALL_VLM = ( - VQAMetric, - ImageEditScoreMetric, - QAAccuracyMetric, - OneIGAlignmentMetric, - TextScoreMetric, - OneIGTextScoreMetric, - VieScoreMetric, -) - -_SLOW_SMOL_SUBSET = ( - VQAMetric, - OneIGAlignmentMetric, - ImageEditScoreMetric, - VieScoreMetric, -) +from pruna.evaluation.metrics.vlm_base import BaseVLM, LitellmVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import FloatOutput, get_score_from_response, yes_no_first_token_id_groups @pytest.mark.parametrize( @@ -64,115 +26,6 @@ def test_get_score_from_response(raw: object, expected: float) -> None: assert get_score_from_response(raw) == pytest.approx(expected) -def _dummy_image(batch: int = 1, size: int = 224) -> torch.Tensor: - return torch.rand(batch, 3, size, size) - - -def _update_metric(metric: object, prompts: list, images: torch.Tensor) -> None: - if isinstance(metric, OneIGAlignmentMetric): - metric.update( - prompts, - [ - { - "questions": {"1": "Is there a cat?", "2": "Is it sleeping?"}, - "dependencies": {"1": [0], "2": [1]}, - } - ], - images, - ) - elif isinstance(metric, QAAccuracyMetric): - metric.update( - prompts, - [{"questions": {"1": "Is there a cat?"}}], - images, - ) - elif isinstance(metric, (TextScoreMetric, OneIGTextScoreMetric)): - metric.update(prompts, ["cat"], images) - else: - metric.update(prompts, images, images) - - -@pytest.mark.cpu -@pytest.mark.slow -@pytest.mark.parametrize("metric_cls", _SLOW_SMOL_SUBSET) -def test_vlm_metrics_transformers_smolvlm(metric_cls: type) -> None: - """Smoke-test a subset with local SmolVLM (full matrix covered by litellm mock).""" - metric = metric_cls( - vlm_type="transformers", - model_name=SMOL_VLM, - device="cpu", - structured_output=True, - ) - images = _dummy_image(batch=1) - prompts = ["a cat"] - _update_metric(metric, prompts, images) - result = metric.compute() - assert result.name == metric.metric_name - assert isinstance(result.result, float) - if metric.higher_is_better: - assert 0.0 <= result.result <= 1.0 - else: - assert result.result >= 0.0 - - -@pytest.mark.cpu -@pytest.mark.parametrize("metric_cls", _ALL_VLM) -def test_vlm_metrics_litellm_mocked(metric_cls: type) -> None: - """Each VLM metric runs end-to-end with mocked litellm.""" - pytest.importorskip("litellm") - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - if metric_cls in (VQAMetric, QAAccuracyMetric, OneIGAlignmentMetric): - mock_response.choices[0].message.content = '{"answer": "Yes"}' - else: - mock_response.choices[0].message.content = '{"score": 8}' - - with patch("litellm.completion") as mock_completion: - mock_completion.return_value = mock_response - - metric = metric_cls( - vlm_type="litellm", - model_name="gpt-4o", - device="cpu", - structured_output=True, - ) - images = _dummy_image(batch=1) - prompts = ["a cat"] - _update_metric(metric, prompts, images) - result = metric.compute() - - assert result.name == metric.metric_name - assert isinstance(result.result, float) - assert mock_completion.called - - -@pytest.mark.cpu -def test_vlm_metrics_empty_compute_returns_zero() -> None: - """No updates → compute is 0.0 (same for all stateful VLM metrics).""" - metric = VQAMetric( - vlm_type="transformers", - model_name=SMOL_VLM, - device="cpu", - structured_output=True, - ) - assert metric.compute().result == 0.0 - - -@pytest.mark.cpu -def test_vlm_metrics_custom_vlm() -> None: - """Custom VLM passed to VQAMetric is used instead of the default litellm backend.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ["Yes"] - mock_vlm.score.return_value = [1.0] - - metric = VQAMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", structured_output=True) - images = _dummy_image(batch=1) - prompts = ["a cat"] - metric.update(prompts, images, images) - assert metric.compute().result == 1.0 - mock_vlm.score.assert_called() - - @pytest.mark.cpu def test_get_vlm_returns_custom() -> None: """get_vlm returns the provided VLM instance unchanged.""" @@ -200,286 +53,15 @@ def test_get_vlm_requires_model_name_without_vlm() -> None: get_vlm(vlm=None, vlm_type="litellm") -@pytest.mark.cpu -@pytest.mark.parametrize( - "metric_cls, expected_name, expected_result", - [ - (TextScoreMetric, "text_score", 1.0), - (OneIGTextScoreMetric, "oneig_text_score", 1.0), - ], -) -def test_text_metrics_list_str_gt(metric_cls: type, expected_name: str, expected_result: float) -> None: - """Text metrics accept plain string ground-truth and return the expected score.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ["hello world"] - - metric = metric_cls(vlm=mock_vlm, vlm_type="litellm", device="cpu") - images = _dummy_image(batch=1) - metric.update(["a prompt"], ["hello world"], images) - result = metric.compute() - - assert result.result == expected_result - assert result.name == expected_name - mock_vlm.generate.assert_called_once() - - -@pytest.mark.cpu -def test_text_score_result_in_zero_one_range() -> None: - """TextScoreMetric must return a normalized score in [0, 1], not raw edit distance.""" - mock_vlm = MagicMock(spec=BaseVLM) - # VLM OCR returns something very different from ground truth (high edit distance) - mock_vlm.generate.return_value = ["completely wrong text abcdefghijklmnop"] - - metric = TextScoreMetric(vlm=mock_vlm, device="cpu") - images = _dummy_image(batch=1) - metric.update(["prompt"], ["hello"], images) - result = metric.compute() - - assert 0.0 <= result.result <= 1.0, f"TextScoreMetric must return [0,1], got {result.result}" - assert result.result < 0.5, f"Very different strings should score below 0.5, got {result.result}" - - -@pytest.mark.cpu -def test_text_score_perfect_match_is_one() -> None: - """TextScoreMetric: identical OCR and GT -> score 1.0.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ["hello world"] - - metric = TextScoreMetric(vlm=mock_vlm, device="cpu") - images = _dummy_image(batch=1) - metric.update(["prompt"], ["hello world"], images) - result = metric.compute() - - assert result.result == 1.0, f"Perfect match should give 1.0, got {result.result}" - assert result.higher_is_better is True - - -@pytest.mark.cpu -def test_text_score_registry_aliases() -> None: - """Registry aliases ocr_levenshtein and ocr_text_score resolve to the correct metric classes.""" - from pruna.evaluation.metrics.registry import MetricRegistry - - lev = MetricRegistry.get_metric("ocr_levenshtein", device="cpu", model_name="openai/gpt-4o") - comp = MetricRegistry.get_metric("ocr_text_score", device="cpu", model_name="openai/gpt-4o") - assert type(lev).__name__ == "TextScoreMetric" - assert type(comp).__name__ == "OneIGTextScoreMetric" - assert lev.metric_name == "text_score" - assert comp.metric_name == "oneig_text_score" - - -@pytest.mark.cpu -def test_oneig_text_score_utils_golden_composite() -> None: - """oneig_mean_text_score returns expected component values for a known input.""" - from pruna.evaluation.metrics.metric_text_score_utils import oneig_mean_text_score - - ed, cr, wac, composite = oneig_mean_text_score( - edit_distances=[10.0], - completion_ratios=[0.0], - match_counts=[2], - gt_totals=[4], - language_mode="EN", - ) - assert ed == 10.0 - assert cr == 0.0 - assert wac == 0.5 - assert composite == pytest.approx(0.95) - - _, _, _, zh = oneig_mean_text_score( - edit_distances=[30.0], - completion_ratios=[0.0], - match_counts=[0], - gt_totals=[1], - language_mode="ZH", - ) - assert zh == pytest.approx(0.4) - - -@pytest.mark.cpu -def test_qa_accuracy_all_or_nothing_partial_fail() -> None: - """all_or_nothing: if any question scores 0, the image score is 0.0 (not a partial mean).""" - mock_vlm = MagicMock(spec=BaseVLM) - # First question Yes (1.0), second question No (0.0) → mean=0.5, all_or_nothing=0.0 - mock_vlm.score.return_value = [1.0, 0.0] - - metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") - metric.update( - ["a prompt"], - [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], - _dummy_image(batch=1), - ) - result = metric.compute() - assert result.result == 0.0, f"Expected 0.0 for all_or_nothing with one No, got {result.result}" - - -@pytest.mark.cpu -def test_qa_accuracy_all_or_nothing_all_yes() -> None: - """all_or_nothing: all Yes → score 1.0.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.score.return_value = [1.0, 1.0] - - metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") - metric.update( - ["a prompt"], - [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], - _dummy_image(batch=1), - ) - result = metric.compute() - assert result.result == 1.0, f"Expected 1.0 for all_or_nothing with all Yes, got {result.result}" - - -@pytest.mark.cpu -def test_qa_accuracy_invalid_aggregation_raises() -> None: - """qa_accuracy rejects aggregation values other than mean / all_or_nothing.""" - mock_vlm = MagicMock(spec=BaseVLM) - with pytest.raises(ValueError, match="aggregation"): - QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="median") - - -@pytest.mark.cpu -def test_vie_score_tie_uses_source_from_gt_and_two_image_sc() -> None: - """With ``source_image_bytes`` in gt, VieScore calls two-image SC then PQ on the edited image.""" - from io import BytesIO - - from PIL import Image - - buf = BytesIO() - Image.new("RGB", (8, 8), color=(0, 0, 200)).save(buf, format="PNG") - src_bytes = buf.getvalue() - - mock_vlm = MagicMock() - mock_vlm.generate_with_image_lists.return_value = ['{"score": [8.0, 8.0], "reasoning": "ok"}'] - mock_vlm.generate.return_value = ['{"score": [9.0, 9.0], "reasoning": "ok"}'] - - metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) - pred = _dummy_image(batch=1) - metric.update( - ["make the sky purple"], - [{"source_image_bytes": src_bytes}], - pred, - ) - result = metric.compute() - - assert mock_vlm.generate_with_image_lists.called - assert mock_vlm.generate.called - assert 0.0 <= result.result <= 1.0 - - -@pytest.mark.cpu -def test_vie_score_uses_get_score_from_response() -> None: - """VieScoreMetric ``t2i`` path parses JSON ``score`` lists via ``viescore_min_scores_0_10``.""" - mock_vlm = MagicMock(spec=BaseVLM) - # LitellmVLM returns model_dump_json() for structured outputs → JSON string (two SC + two PQ sub-scores) - mock_vlm.generate.return_value = ['{"score": [8.0, 8.0], "reasoning": ""}'] - - metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) - metric.update(["a cat on a sofa"], _dummy_image(batch=1), _dummy_image(batch=1)) - result = metric.compute() - - # min(SC)=8, min(PQ)=8 → sqrt(8 * 8) / 10 = 0.8 - assert abs(result.result - 0.8) < 0.01, f"Expected ~0.8, got {result.result}" - - -@pytest.mark.cpu -def test_img_edit_score_negative_response_clamped() -> None: - """img_edit_score must be non-negative even when the VLM generates a negative JSON score. - - Regression for: Outlines constrained decoding can emit {"score": -10} despite the - FloatOutput JSON schema specifying minimum=0, because Outlines does not enforce numeric - bounds during token sampling. The fix is max(0.0, ...) in get_score_from_response. - """ - mock_vlm = MagicMock(spec=BaseVLM) - # Simulate Outlines generating a negative value (the bug scenario) - mock_vlm.generate.return_value = ['{"score": -10.0}'] - - metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) - metric.update(["replace the boot with a mug"], torch.zeros(1), _dummy_image(batch=1)) - result = metric.compute() - - assert result.result >= 0.0, f"img_edit_score must be >= 0, got {result.result}" - - -@pytest.mark.cpu -def test_qa_accuracy_all_or_nothing_ambiguous_score() -> None: - """all_or_nothing: score exactly 0.5 (ambiguous) is treated as No → result 0.0.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.score.return_value = [0.5] - - metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") - metric.update( - ["a prompt"], - [{"questions": {"1": "Is there a cat?"}}], - _dummy_image(batch=1), - ) - result = metric.compute() - assert result.result == 0.0, f"Score 0.5 should be treated as No (ambiguous), got {result.result}" - - -@pytest.mark.cpu -@pytest.mark.slow -def test_yes_no_token_ids_smolvlm_nonempty() -> None: - """SmolVLM tokenizer must yield non-empty disjoint yes/no prefix ids for VQAScore scoring.""" - pytest.importorskip("transformers") - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") - yes_ids, no_ids = yes_no_first_token_id_groups(tok) - assert len(yes_ids) > 0, "SmolVLM tokenizer has no 'Yes'-prefix token ids" - assert len(no_ids) > 0, "SmolVLM tokenizer has no 'No'-prefix token ids" - assert not (set(yes_ids) & set(no_ids)), "yes_ids and no_ids must be disjoint" - - -@pytest.mark.cpu -def test_img_edit_score_uses_prompt_from_x() -> None: - """img_edit_score must score the edited image against the instruction from x, not gt.""" - mock_vlm = MagicMock(spec=BaseVLM) - mock_vlm.generate.return_value = ['{"score": 9}'] - - metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu") - pred = _dummy_image(batch=1) - metric.update( - ["replace the cat with a dog"], # x = instruction - pred, # gt = unused for y_x - pred, # outputs = edited image - ) - result = metric.compute() - - call_args = mock_vlm.generate.call_args - prompt_sent = call_args[0][1][0] # second positional arg = prompts list, first item - assert "replace the cat with a dog" in prompt_sent, f"Instruction not in VLM prompt. Got: {prompt_sent}" - assert abs(result.result - 0.9) < 0.01, f"Expected ~0.9, got {result.result}" - - -@pytest.mark.cpu -def test_vie_score_geditbench_gap_documented() -> None: - """VieScoreMetric infers text--image editing from ``source_image_bytes`` in aux (no ``task_type``). - - This test fails if a ``task_type`` parameter is added to ``__init__`` without updating - GEditBench integration tests and benchmark copy accordingly. - """ - import inspect - - sig = inspect.signature(VieScoreMetric.__init__) - assert "task_type" not in sig.parameters, ( - "VieScoreMetric now has task_type — update GEditBench docs and e2e tests, then remove this sentinel." - ) - - @pytest.mark.cpu def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None: """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first.""" pytest.importorskip("litellm") import math - from unittest.mock import MagicMock, patch import numpy as np from PIL import Image - from pruna.evaluation.metrics.vlm_base import LitellmVLM - - # Simulate top_logprobs for first output token: - # "Yes" → logprob=-2.303 (p≈0.10), " yes" → logprob=-2.996 (p≈0.05) → total p_yes≈0.15 - # "No" → logprob=-1.609 (p≈0.20), " no" → logprob=-2.303 (p≈0.10) → total p_no≈0.30 - # normalized: p_yes/(p_yes+p_no) ≈ 0.15/0.45 ≈ 0.333 def make_top_logprob(token, logprob): t = MagicMock() t.token = token @@ -510,175 +92,17 @@ def make_top_logprob(token, logprob): img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes") - # Should be ~0.333 (p_yes=0.15 / (p_yes+p_no)=0.45), not just 0.10 (first match) assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}" @pytest.mark.cpu @pytest.mark.slow -def test_vqa_probability_score_normalized() -> None: - """P(Yes) from TransformersVLM.score use_probability=True is in [0, 1].""" +def test_yes_no_token_ids_smolvlm_nonempty() -> None: + """SmolVLM tokenizer yields non-empty yes/no prefix id groups.""" pytest.importorskip("transformers") - import numpy as np - from PIL import Image - - from pruna.evaluation.metrics.vlm_base import TransformersVLM - - vlm = TransformersVLM( - model_name="HuggingFaceTB/SmolVLM-256M-Instruct", - device="cpu", - use_outlines=False, - ) - img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) - scores = vlm.score([img], ["Is there a cat?"], ["Yes"], use_probability=True) - assert len(scores) == 1 - assert 0.0 <= scores[0] <= 1.0, f"P(Yes) must be in [0, 1], got {scores[0]}" - - -# --------------------------------------------------------------------------- -# vlm_benchmark_batch_to_json_record serialization tests -# --------------------------------------------------------------------------- - - -def test_vlm_benchmark_batch_to_json_record_serializes_batch() -> None: - """Record includes prompts, pred shape, and metric fields.""" - mr = MetricResult(name="qa_accuracy", params={}, result=0.25, higher_is_better=True) - outcome = BenchmarkVlmBatchOutcome( - result=mr, - prompts=["prompt"], - auxiliaries=[{"path": "/tmp/x.png"}], - pred=torch.zeros(1, 3, 8, 8), - ) - rec = vlm_benchmark_batch_to_json_record( - outcome, - benchmark_key="GenEval", - benchmark_name="GenEval", - metric_name="qa_accuracy", - vlm_type="transformers", - model_name="m", - device="cpu", - ) - assert rec["inputs"]["prompts"] == ["prompt"] - assert rec["pred"]["shape"] == [1, 3, 8, 8] - assert rec["metric_result"]["result"] == 0.25 - - -def test_safe_json_handles_bytes_without_expanding() -> None: - """Bytes values in aux (e.g. source_image_bytes) are summarized, not expanded to str repr.""" - result = safe_json_for_snapshot({"source_image_bytes": b"\xff\xd8\xff" * 1000, "name": "test"}) - assert result["source_image_bytes"] == {"bytes_len": 3000} - assert result["name"] == "test" - - -def test_vlm_benchmark_batch_to_json_record_preserves_null_question_slots() -> None: - """Padded ``None`` question labels stay JSON null, not the string ``"None"``.""" - mr = MetricResult(name="oneig_alignment", params={}, result=1.0, higher_is_better=True) - outcome = BenchmarkVlmBatchOutcome( - result=mr, - prompts=["p"], - auxiliaries=[{"questions": {"1": "Are there boys?", "21": None}, "subset": "Anime_Stylization"}], - pred=torch.zeros(1, 3, 8, 8), - ) - rec = vlm_benchmark_batch_to_json_record( - outcome, - benchmark_key="OneIGAnimeStylization", - benchmark_name="OneIG Anime Stylization", - metric_name="oneig_alignment", - vlm_type="transformers", - model_name="m", - device="cpu", - ) - qs = rec["inputs"]["auxiliary_0"]["questions"] - assert qs["1"] == "Are there boys?" - assert qs["21"] is None - - -# --------------------------------------------------------------------------- -# pred_tensor_from_auxiliaries (test helper, wraps pil_rgb_from_aux_image_bytes) tests -# --------------------------------------------------------------------------- - - -def _make_jpeg_bytes(h: int = 32, w: int = 32) -> bytes: - """Return a tiny JPEG-encoded RGB image as bytes (test helper).""" - import io - - import numpy as np - from PIL import Image - - arr = (np.random.rand(h, w, 3) * 255).astype("uint8") - buf = io.BytesIO() - Image.fromarray(arr).save(buf, format="JPEG") - return buf.getvalue() - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_uses_source_image_bytes() -> None: - """pred_tensor_from_auxiliaries decodes source_image_bytes into a float tensor in [0, 1].""" - src_bytes = _make_jpeg_bytes() - aux = [{"source_image_bytes": src_bytes, "category": "background_change"}] - pred = pred_tensor_from_auxiliaries(aux, size=64) - - assert pred.shape == (1, 3, 64, 64), f"Expected (1,3,64,64), got {pred.shape}" - assert pred.min() >= 0.0 and pred.max() <= 1.0, "Pixel values must be in [0, 1]" - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_falls_back_to_noise_without_source_image() -> None: - """pred_tensor_from_auxiliaries returns random noise when no source_image_bytes is present.""" - aux = [{"category": "single_object"}] - pred = pred_tensor_from_auxiliaries(aux, size=32) - assert pred.shape == (1, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_mixed_batch() -> None: - """Batch with one source image and one missing falls back per-item.""" - src_bytes = _make_jpeg_bytes() - aux = [ - {"source_image_bytes": src_bytes, "category": "color_alter"}, - {"category": "style_change"}, # no source image - ] - pred = pred_tensor_from_auxiliaries(aux, size=32) - assert pred.shape == (2, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_generic_bytes_scan() -> None: - """pred_tensor_from_auxiliaries discovers image bytes under an unknown field name (generic scan).""" - src_bytes = _make_jpeg_bytes() - aux = [{"my_custom_image_bytes": src_bytes, "category": "motion_change"}] - pred = pred_tensor_from_auxiliaries(aux, size=32) - assert pred.shape == (1, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_known_names_take_priority() -> None: - """Known field names are resolved before the generic bytes scan.""" - src_bytes_known = _make_jpeg_bytes(16, 16) - src_bytes_unknown = _make_jpeg_bytes(32, 32) - first_known = VLM_AUX_IMAGE_BYTES_KEY_ORDER[0] - aux = [{"other_bytes": src_bytes_unknown, first_known: src_bytes_known}] - pred = pred_tensor_from_auxiliaries(aux, size=16) - # Should use the known key (16x16 image → 16x16 crop); generic scan would pick 32x32 - assert pred.shape == (1, 3, 16, 16) - - -@pytest.mark.cpu -def test_pred_from_auxiliaries_require_source_image_raises_when_missing() -> None: - """require_source_image=True raises ValueError instead of silently returning noise.""" - aux = [{"category": "replace"}] # no image bytes - with pytest.raises(ValueError, match="require_source_image=True"): - pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) - + from transformers import AutoTokenizer -@pytest.mark.cpu -def test_pred_from_auxiliaries_require_source_image_succeeds_when_present() -> None: - """require_source_image=True succeeds and decodes bytes when source_image_bytes is present.""" - src_bytes = _make_jpeg_bytes() - aux = [{"source_image_bytes": src_bytes, "category": "replace"}] - pred = pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) - assert pred.shape == (1, 3, 32, 32) - assert pred.min() >= 0.0 and pred.max() <= 1.0 + tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert yes_ids + assert no_ids