Skip to content

Commit ad97e52

Browse files
authored
feat: Factory and plugin-capability for Layout and Table models (#2637)
* feat: Scaffolding for layout and table model plugin factory Signed-off-by: Christoph Auer <[email protected]> * Add missing files Signed-off-by: Christoph Auer <[email protected]> * Add base options classes for layout and table Signed-off-by: Christoph Auer <[email protected]> --------- Signed-off-by: Christoph Auer <[email protected]>
1 parent dcb57bf commit ad97e52

File tree

11 files changed

+346
-171
lines changed

11 files changed

+346
-171
lines changed

docling/datamodel/pipeline_options.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,14 @@ class TableFormerMode(str, Enum):
5959
ACCURATE = "accurate"
6060

6161

62-
class TableStructureOptions(BaseModel):
62+
class BaseTableStructureOptions(BaseOptions):
63+
"""Base options for table structure models."""
64+
65+
66+
class TableStructureOptions(BaseTableStructureOptions):
6367
"""Options for the table structure."""
6468

69+
kind: ClassVar[str] = "docling_tableformer"
6570
do_cell_matching: bool = (
6671
True
6772
# True: Matches predictions back to PDF cells. Can break table output if PDF cells
@@ -308,19 +313,25 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
308313
)
309314

310315

311-
class LayoutOptions(BaseModel):
312-
"""Options for layout processing."""
316+
class BaseLayoutOptions(BaseOptions):
317+
"""Base options for layout models."""
313318

314-
create_orphan_clusters: bool = True # Whether to create clusters for orphaned cells
315319
keep_empty_clusters: bool = (
316320
False # Whether to keep clusters that contain no text cells
317321
)
318-
model_spec: LayoutModelConfig = DOCLING_LAYOUT_HERON
319322
skip_cell_assignment: bool = (
320323
False # Skip cell-to-cluster assignment for VLM-only processing
321324
)
322325

323326

327+
class LayoutOptions(BaseLayoutOptions):
328+
"""Options for layout processing."""
329+
330+
kind: ClassVar[str] = "docling_layout_default"
331+
create_orphan_clusters: bool = True # Whether to create clusters for orphaned cells
332+
model_spec: LayoutModelConfig = DOCLING_LAYOUT_HERON
333+
334+
324335
class AsrPipelineOptions(PipelineOptions):
325336
asr_options: Union[InlineAsrOptions] = asr_model_specs.WHISPER_TINY
326337

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from collections.abc import Iterable, Sequence
5+
from typing import Type
6+
7+
from docling.datamodel.base_models import LayoutPrediction, Page
8+
from docling.datamodel.document import ConversionResult
9+
from docling.datamodel.pipeline_options import BaseLayoutOptions
10+
from docling.models.base_model import BaseModelWithOptions, BasePageModel
11+
12+
13+
class BaseLayoutModel(BasePageModel, BaseModelWithOptions, ABC):
14+
"""Shared interface for layout models."""
15+
16+
@classmethod
17+
@abstractmethod
18+
def get_options_type(cls) -> Type[BaseLayoutOptions]:
19+
"""Return the options type supported by this layout model."""
20+
21+
@abstractmethod
22+
def predict_layout(
23+
self,
24+
conv_res: ConversionResult,
25+
pages: Sequence[Page],
26+
) -> Sequence[LayoutPrediction]:
27+
"""Produce layout predictions for the provided pages."""
28+
29+
def __call__(
30+
self,
31+
conv_res: ConversionResult,
32+
page_batch: Iterable[Page],
33+
) -> Iterable[Page]:
34+
pages = list(page_batch)
35+
predictions = self.predict_layout(conv_res, pages)
36+
37+
for page, prediction in zip(pages, predictions):
38+
page.predictions.layout = prediction
39+
yield page

docling/models/base_table_model.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from collections.abc import Iterable, Sequence
5+
from typing import Type
6+
7+
from docling.datamodel.base_models import Page, TableStructurePrediction
8+
from docling.datamodel.document import ConversionResult
9+
from docling.datamodel.pipeline_options import BaseTableStructureOptions
10+
from docling.models.base_model import BaseModelWithOptions, BasePageModel
11+
12+
13+
class BaseTableStructureModel(BasePageModel, BaseModelWithOptions, ABC):
14+
"""Shared interface for table structure models."""
15+
16+
enabled: bool
17+
18+
@classmethod
19+
@abstractmethod
20+
def get_options_type(cls) -> Type[BaseTableStructureOptions]:
21+
"""Return the options type supported by this table model."""
22+
23+
@abstractmethod
24+
def predict_tables(
25+
self,
26+
conv_res: ConversionResult,
27+
pages: Sequence[Page],
28+
) -> Sequence[TableStructurePrediction]:
29+
"""Produce table structure predictions for the provided pages."""
30+
31+
def __call__(
32+
self,
33+
conv_res: ConversionResult,
34+
page_batch: Iterable[Page],
35+
) -> Iterable[Page]:
36+
if not getattr(self, "enabled", True):
37+
yield from page_batch
38+
return
39+
40+
pages = list(page_batch)
41+
predictions = self.predict_tables(conv_res, pages)
42+
43+
for page, prediction in zip(pages, predictions):
44+
page.predictions.tablestructure = prediction
45+
yield page

docling/models/factories/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import logging
22
from functools import lru_cache
33

4+
from docling.models.factories.layout_factory import LayoutFactory
45
from docling.models.factories.ocr_factory import OcrFactory
56
from docling.models.factories.picture_description_factory import (
67
PictureDescriptionFactory,
78
)
9+
from docling.models.factories.table_factory import TableStructureFactory
810

911
logger = logging.getLogger(__name__)
1012

@@ -25,3 +27,21 @@ def get_picture_description_factory(
2527
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
2628
logger.info("Registered picture descriptions: %r", factory.registered_kind)
2729
return factory
30+
31+
32+
@lru_cache
33+
def get_layout_factory(allow_external_plugins: bool = False) -> LayoutFactory:
34+
factory = LayoutFactory()
35+
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
36+
logger.info("Registered layout engines: %r", factory.registered_kind)
37+
return factory
38+
39+
40+
@lru_cache
41+
def get_table_structure_factory(
42+
allow_external_plugins: bool = False,
43+
) -> TableStructureFactory:
44+
factory = TableStructureFactory()
45+
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
46+
logger.info("Registered table structure engines: %r", factory.registered_kind)
47+
return factory
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from docling.models.base_layout_model import BaseLayoutModel
2+
from docling.models.factories.base_factory import BaseFactory
3+
4+
5+
class LayoutFactory(BaseFactory[BaseLayoutModel]):
6+
def __init__(self, *args, **kwargs):
7+
super().__init__("layout_engines", *args, **kwargs)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from docling.models.base_table_model import BaseTableStructureModel
2+
from docling.models.factories.base_factory import BaseFactory
3+
4+
5+
class TableStructureFactory(BaseFactory[BaseTableStructureModel]):
6+
def __init__(self, *args, **kwargs):
7+
super().__init__("table_structure_engines", *args, **kwargs)

docling/models/layout_model.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import logging
33
import warnings
4-
from collections.abc import Iterable
4+
from collections.abc import Sequence
55
from pathlib import Path
66
from typing import List, Optional, Union
77

@@ -15,7 +15,7 @@
1515
from docling.datamodel.layout_model_specs import DOCLING_LAYOUT_V2, LayoutModelConfig
1616
from docling.datamodel.pipeline_options import LayoutOptions
1717
from docling.datamodel.settings import settings
18-
from docling.models.base_model import BasePageModel
18+
from docling.models.base_layout_model import BaseLayoutModel
1919
from docling.models.utils.hf_model_download import download_hf_model
2020
from docling.utils.accelerator_utils import decide_device
2121
from docling.utils.layout_postprocessor import LayoutPostprocessor
@@ -25,7 +25,7 @@
2525
_log = logging.getLogger(__name__)
2626

2727

28-
class LayoutModel(BasePageModel):
28+
class LayoutModel(BaseLayoutModel):
2929
TEXT_ELEM_LABELS = [
3030
DocItemLabel.TEXT,
3131
DocItemLabel.FOOTNOTE,
@@ -86,6 +86,10 @@ def __init__(
8686
num_threads=accelerator_options.num_threads,
8787
)
8888

89+
@classmethod
90+
def get_options_type(cls) -> type[LayoutOptions]:
91+
return LayoutOptions
92+
8993
@staticmethod
9094
def download_models(
9195
local_dir: Optional[Path] = None,
@@ -145,11 +149,13 @@ def draw_clusters_and_cells_side_by_side(
145149
out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png"
146150
combined_image.save(str(out_file), format="png")
147151

148-
def __call__(
149-
self, conv_res: ConversionResult, page_batch: Iterable[Page]
150-
) -> Iterable[Page]:
151-
# Convert to list to allow multiple iterations
152-
pages = list(page_batch)
152+
def predict_layout(
153+
self,
154+
conv_res: ConversionResult,
155+
pages: Sequence[Page],
156+
) -> Sequence[LayoutPrediction]:
157+
# Convert to list to ensure predictable iteration
158+
pages = list(pages)
153159

154160
# Separate valid and invalid pages
155161
valid_pages = []
@@ -167,12 +173,6 @@ def __call__(
167173
valid_pages.append(page)
168174
valid_page_images.append(page_image)
169175

170-
_log.debug(f"{len(pages)=}")
171-
if pages:
172-
_log.debug(f"{pages[0].page_no}-{pages[-1].page_no}")
173-
_log.debug(f"{len(valid_pages)=}")
174-
_log.debug(f"{len(valid_page_images)=}")
175-
176176
# Process all valid pages with batch prediction
177177
batch_predictions = []
178178
if valid_page_images:
@@ -182,11 +182,14 @@ def __call__(
182182
)
183183

184184
# Process each page with its predictions
185+
layout_predictions: list[LayoutPrediction] = []
185186
valid_page_idx = 0
186187
for page in pages:
187188
assert page._backend is not None
188189
if not page._backend.is_valid():
189-
yield page
190+
existing_prediction = page.predictions.layout or LayoutPrediction()
191+
page.predictions.layout = existing_prediction
192+
layout_predictions.append(existing_prediction)
190193
continue
191194

192195
page_predictions = batch_predictions[valid_page_idx]
@@ -233,11 +236,14 @@ def __call__(
233236
np.mean([c.confidence for c in processed_cells if c.from_ocr])
234237
)
235238

236-
page.predictions.layout = LayoutPrediction(clusters=processed_clusters)
239+
prediction = LayoutPrediction(clusters=processed_clusters)
240+
page.predictions.layout = prediction
237241

238242
if settings.debug.visualize_layout:
239243
self.draw_clusters_and_cells_side_by_side(
240244
conv_res, page, processed_clusters, mode_prefix="postprocessed"
241245
)
242246

243-
yield page
247+
layout_predictions.append(prediction)
248+
249+
return layout_predictions

docling/models/plugins/defaults.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,23 @@ def picture_description():
2828
PictureDescriptionApiModel,
2929
]
3030
}
31+
32+
33+
def layout_engines():
34+
from docling.models.layout_model import LayoutModel
35+
36+
return {
37+
"layout_engines": [
38+
LayoutModel,
39+
]
40+
}
41+
42+
43+
def table_structure_engines():
44+
from docling.models.table_structure_model import TableStructureModel
45+
46+
return {
47+
"table_structure_engines": [
48+
TableStructureModel,
49+
]
50+
}

0 commit comments

Comments
 (0)