11import copy
22import logging
33import warnings
4- from collections .abc import Iterable
4+ from collections .abc import Sequence
55from pathlib import Path
66from typing import List , Optional , Union
77
1515from docling .datamodel .layout_model_specs import DOCLING_LAYOUT_V2 , LayoutModelConfig
1616from docling .datamodel .pipeline_options import LayoutOptions
1717from docling .datamodel .settings import settings
18- from docling .models .base_model import BasePageModel
18+ from docling .models .base_layout_model import BaseLayoutModel
1919from docling .models .utils .hf_model_download import download_hf_model
2020from docling .utils .accelerator_utils import decide_device
2121from docling .utils .layout_postprocessor import LayoutPostprocessor
2525_log = logging .getLogger (__name__ )
2626
2727
28- class LayoutModel (BasePageModel ):
28+ class LayoutModel (BaseLayoutModel ):
2929 TEXT_ELEM_LABELS = [
3030 DocItemLabel .TEXT ,
3131 DocItemLabel .FOOTNOTE ,
@@ -86,6 +86,10 @@ def __init__(
8686 num_threads = accelerator_options .num_threads ,
8787 )
8888
89+ @classmethod
90+ def get_options_type (cls ) -> type [LayoutOptions ]:
91+ return LayoutOptions
92+
8993 @staticmethod
9094 def download_models (
9195 local_dir : Optional [Path ] = None ,
@@ -145,11 +149,13 @@ def draw_clusters_and_cells_side_by_side(
145149 out_file = out_path / f"{ mode_prefix } _layout_page_{ page .page_no :05} .png"
146150 combined_image .save (str (out_file ), format = "png" )
147151
148- def __call__ (
149- self , conv_res : ConversionResult , page_batch : Iterable [Page ]
150- ) -> Iterable [Page ]:
151- # Convert to list to allow multiple iterations
152- pages = list (page_batch )
152+ def predict_layout (
153+ self ,
154+ conv_res : ConversionResult ,
155+ pages : Sequence [Page ],
156+ ) -> Sequence [LayoutPrediction ]:
157+ # Convert to list to ensure predictable iteration
158+ pages = list (pages )
153159
154160 # Separate valid and invalid pages
155161 valid_pages = []
@@ -167,12 +173,6 @@ def __call__(
167173 valid_pages .append (page )
168174 valid_page_images .append (page_image )
169175
170- _log .debug (f"{ len (pages )= } " )
171- if pages :
172- _log .debug (f"{ pages [0 ].page_no } -{ pages [- 1 ].page_no } " )
173- _log .debug (f"{ len (valid_pages )= } " )
174- _log .debug (f"{ len (valid_page_images )= } " )
175-
176176 # Process all valid pages with batch prediction
177177 batch_predictions = []
178178 if valid_page_images :
@@ -182,11 +182,14 @@ def __call__(
182182 )
183183
184184 # Process each page with its predictions
185+ layout_predictions : list [LayoutPrediction ] = []
185186 valid_page_idx = 0
186187 for page in pages :
187188 assert page ._backend is not None
188189 if not page ._backend .is_valid ():
189- yield page
190+ existing_prediction = page .predictions .layout or LayoutPrediction ()
191+ page .predictions .layout = existing_prediction
192+ layout_predictions .append (existing_prediction )
190193 continue
191194
192195 page_predictions = batch_predictions [valid_page_idx ]
@@ -233,11 +236,14 @@ def __call__(
233236 np .mean ([c .confidence for c in processed_cells if c .from_ocr ])
234237 )
235238
236- page .predictions .layout = LayoutPrediction (clusters = processed_clusters )
239+ prediction = LayoutPrediction (clusters = processed_clusters )
240+ page .predictions .layout = prediction
237241
238242 if settings .debug .visualize_layout :
239243 self .draw_clusters_and_cells_side_by_side (
240244 conv_res , page , processed_clusters , mode_prefix = "postprocessed"
241245 )
242246
243- yield page
247+ layout_predictions .append (prediction )
248+
249+ return layout_predictions
0 commit comments