diff --git a/src/layoutparser/__init__.py b/src/layoutparser/__init__.py index 31f87d9..95cefe2 100644 --- a/src/layoutparser/__init__.py +++ b/src/layoutparser/__init__.py @@ -16,4 +16,8 @@ from .models import ( Detectron2LayoutModel +) + +from .io import ( + load_json ) \ No newline at end of file diff --git a/src/layoutparser/elements.py b/src/layoutparser/elements.py index a7e3637..b29b2e1 100644 --- a/src/layoutparser/elements.py +++ b/src/layoutparser/elements.py @@ -1,26 +1,31 @@ +from typing import List, Union, Dict, Dict, Any from abc import ABC, abstractmethod -from collections.abc import Iterable +from collections.abc import Iterable, MutableSequence from copy import copy, deepcopy from inspect import getmembers, isfunction import warnings import functools + import numpy as np import pandas as pd from PIL import Image from cv2 import getPerspectiveTransform as _getPerspectiveTransform from cv2 import warpPerspective as _warpPerspective -__all__ = ['Interval', 'Rectangle', 'Quadrilateral', 'TextBlock', 'Layout'] +__all__ = ["Interval", "Rectangle", "Quadrilateral", "TextBlock", "Layout"] def _cvt_coordinates_to_points(coords): x_1, y_1, x_2, y_2 = coords - return np.array([[x_1, y_1], # Top Left - [x_2, y_1], # Top Right - [x_2, y_2], # Bottom Right - [x_1, y_2], # Bottom Left - ]) + return np.array( + [ + [x_1, y_1], # Top Left + [x_2, y_1], # Top Right + [x_2, y_2], # Bottom Right + [x_1, y_2], # Bottom Left + ] + ) def _cvt_points_to_coordinates(points): @@ -39,7 +44,7 @@ def _perspective_transformation(M, points, is_inv=False): src_mid = np.hstack([points, np.ones((points.shape[0], 1))]).T # 3x4 dst_mid = np.matmul(M, src_mid) - dst = (dst_mid/dst_mid[-1]).T[:, :2] # 4x2 + dst = (dst_mid / dst_mid[-1]).T[:, :2] # 4x2 return dst @@ -56,28 +61,8 @@ def _vertice_in_polygon(vertice, polygon_points): # If the points are ordered clockwise, the det should <=0 -def _parse_datatype_from_feature_names(feature_names): - - type_feature_map = { - Interval: set(Interval.feature_names), - Rectangle: set(Rectangle.feature_names), - Quadrilateral: set(Quadrilateral.feature_names) - } - - for cls, fnames in type_feature_map.items(): - if set(feature_names) == fnames: - return cls - - raise ValueError( - "\n " - "\n The input feature is incompatible with the designated format." - "\n Please check the tutorials for more details." - "\n " - ) - - def _polygon_area(xs, ys): - """Calculate the area of polygons using + """Calculate the area of polygons using `Shoelace Formula `_. Args: @@ -89,7 +74,7 @@ def _polygon_area(xs, ys): # The formula is equivalent to the original one indicated in the wikipedia # page. - return 0.5*np.abs(np.dot(xs, np.roll(ys, 1)) - np.dot(ys, np.roll(xs, 1))) + return 0.5 * np.abs(np.dot(xs, np.roll(ys, 1)) - np.dot(ys, np.roll(xs, 1))) def mixin_textblock_meta(func): @@ -100,6 +85,7 @@ def wrap(self, *args, **kwargs): self = copy(self) self.block = out return self + return wrap @@ -131,11 +117,11 @@ def wrap(self, other, *args, **kwargs): other = other.block out = func(self, other, *args, **kwargs) return out - return wrap + return wrap -class BaseLayoutElement(): +class BaseLayoutElement: def set(self, inplace=False, **kwargs): obj = self if inplace else copy(self) @@ -152,8 +138,7 @@ def set(self, inplace=False, **kwargs): def __repr__(self): - info_str = ', '.join( - [f'{key}={val}' for key, val in vars(self).items()]) + info_str = ", ".join([f"{key}={val}" for key, val in vars(self).items()]) return f"{self.__class__.__name__}({info_str})" def __eq__(self, other): @@ -165,6 +150,17 @@ def __eq__(self, other): class BaseCoordElement(ABC, BaseLayoutElement): + @property + @abstractmethod + def _name(self) -> str: + """The name of the class""" + pass + + @property + @abstractmethod + def _features(self) -> List[str]: + """A list of features names used for initializing the class object""" + pass ####################################################################### ######################### Layout Properties ######################### @@ -172,26 +168,31 @@ class BaseCoordElement(ABC, BaseLayoutElement): @property @abstractmethod - def width(self): pass + def width(self): + pass @property @abstractmethod - def height(self): pass + def height(self): + pass @property @abstractmethod - def coordinates(self): pass + def coordinates(self): + pass @property @abstractmethod - def points(self): pass + def points(self): + pass @property @abstractmethod - def area(self): pass + def area(self): + pass ####################################################################### - ### Geometric Relations (relative to, condition on, and is in) ### + ### Geometric Relations (relative to, condition on, and is in) ### ####################################################################### @abstractmethod @@ -201,14 +202,14 @@ def condition_on(self, other): generate a new element of the current element in absolute coordinates. Args: - other (:obj:`BaseCoordElement`): + other (:obj:`BaseCoordElement`): The other layout element involved in the geometric operations. Raises: Exception: Raise error when the input type of the other element is invalid. Returns: - :obj:`BaseCoordElement`: + :obj:`BaseCoordElement`: The BaseCoordElement object of the original element in the absolute coordinate system. """ @@ -227,7 +228,7 @@ def relative_to(self, other): Exception: Raise error when the input type of the other element is invalid. Returns: - :obj:`BaseCoordElement`: + :obj:`BaseCoordElement`: The BaseCoordElement object of the original element in the relative coordinate system. """ @@ -236,15 +237,15 @@ def relative_to(self, other): @abstractmethod def is_in(self, other, soft_margin={}, center=False): """ - Identify whether the current element is within another element. + Identify whether the current element is within another element. Args: - other (:obj:`BaseCoordElement`): + other (:obj:`BaseCoordElement`): The other layout element involved in the geometric operations. - soft_margin (:obj:`dict`, `optional`, defaults to `{}`): - Enlarge the other element with wider margins to relax the restrictions. - center (:obj:`bool`, `optional`, defaults to `False`): - The toggle to determine whether the center (instead of the four corners) + soft_margin (:obj:`dict`, `optional`, defaults to `{}`): + Enlarge the other element with wider margins to relax the restrictions. + center (:obj:`bool`, `optional`, defaults to `False`): + The toggle to determine whether the center (instead of the four corners) of the current element is in the other element. Returns: @@ -258,10 +259,9 @@ def is_in(self, other, soft_margin={}, center=False): ####################################################################### @abstractmethod - def pad(self, left=0, right=0, top=0, bottom=0, - safe_mode=True): - """ Pad the layout element on the four sides of the polygon with the user-defined pixels. If - safe_mode is set to True, the function will cut off the excess padding that falls on the negative + def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True): + """Pad the layout element on the four sides of the polygon with the user-defined pixels. If + safe_mode is set to True, the function will cut off the excess padding that falls on the negative side of the coordinates. Args: @@ -284,7 +284,7 @@ def shift(self, shift_distance=0): numeric value, the element will by shifted by the same specified amount on both x and y axis. Args: - shift_distance (:obj:`numeric` or :obj:`Tuple(numeric)` or :obj:`List[numeric]`): + shift_distance (:obj:`numeric` or :obj:`Tuple(numeric)` or :obj:`List[numeric]`): The number of pixels used to shift the element. Returns: @@ -307,6 +307,7 @@ def scale(self, scale_factor=1): """ pass + ####################################################################### ################################# MISC ################################ ####################################################################### @@ -325,54 +326,88 @@ def crop_image(self, image): pass + ####################################################################### + ########################## Import and Export ########################## + ####################################################################### + + def to_dict(self) -> Dict[str, Any]: + """ + Generate a dictionary representation of the current object: + { + "block_type": <"interval", "rectangle", "quadrilateral"> , + "non_empty_block_attr1": value1, + ... + } + """ + + data = { + key: getattr(self, key) + for key in self._features + if getattr(self, key) is not None + } + data["block_type"] = self._name + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BaseCoordElement": + """Initialize an instance based on the dictionary representation + + Args: + data (:obj:`dict`): The dictionary representation of the object + """ + + assert ( + cls._name == data["block_type"] + ), f"Incompatible block types {data['block_type']}" + + return cls(**{f: data[f] for f in cls._features}) + @inherit_docstrings class Interval(BaseCoordElement): """ - This class describes the coordinate system of an interval, a block defined by a pair of start and end point + This class describes the coordinate system of an interval, a block defined by a pair of start and end point on the designated axis and same length as the base canvas on the other axis. Args: - start (:obj:`numeric`): + start (:obj:`numeric`): The coordinate of the start point on the designated axis. - end (:obj:`numeric`): + end (:obj:`numeric`): The end coordinate on the same axis as start. - axis (:obj:`str`, optional`, defaults to 'x'): + axis (:obj:`str`): The designated axis that the end points belong to. - canvas_height (:obj:`numeric`, `optional`, defaults to 0): + canvas_height (:obj:`numeric`, `optional`, defaults to 0): The height of the canvas that the interval is on. - canvas_width (:obj:`numeric`, `optional`, defaults to 0): + canvas_width (:obj:`numeric`, `optional`, defaults to 0): The width of the canvas that the interval is on. """ - name = "_interval" - feature_names = ["x_1", "y_1", "x_2", "y_2", "height", "width"] + _name = "interval" + _features = ["start", "end", "axis", "canvas_height", "canvas_width"] - def __init__(self, start, end, axis='x', - canvas_height=0, canvas_width=0): + def __init__(self, start, end, axis, canvas_height=None, canvas_width=None): assert start <= end, f"Invalid input for start and end. Start must <= end." self.start = start self.end = end - assert axis in [ - 'x', 'y'], f"Invalid axis {axis}. Axis must be in 'x' or 'y'" + assert axis in ["x", "y"], f"Invalid axis {axis}. Axis must be in 'x' or 'y'" self.axis = axis - self.canvas_height = canvas_height - self.canvas_width = canvas_width + self.canvas_height = canvas_height or 0 + self.canvas_width = canvas_width or 0 @property def height(self): """ - Calculate the height of the interval. If the interval is along the x-axis, the height will be the + Calculate the height of the interval. If the interval is along the x-axis, the height will be the height of the canvas, otherwise, it will be the difference between the start and end point. Returns: :obj:`numeric`: Output the numeric value of the height. """ - if self.axis == 'x': + if self.axis == "x": return self.canvas_height else: return self.end - self.start @@ -380,14 +415,14 @@ def height(self): @property def width(self): """ - Calculate the width of the interval. If the interval is along the y-axis, the width will be the + Calculate the width of the interval. If the interval is along the y-axis, the width will be the width of the canvas, otherwise, it will be the difference between the start and end point. Returns: :obj:`numeric`: Output the numeric value of the width. """ - if self.axis == 'y': + if self.axis == "y": return self.canvas_width else: return self.end - self.start @@ -395,15 +430,15 @@ def width(self): @property def coordinates(self): """ - This method considers an interval as a rectangle and calculates the coordinates of the upper left + This method considers an interval as a rectangle and calculates the coordinates of the upper left and lower right corners to define the interval. Returns: - :obj:`Tuple(numeric)`: - Output the numeric values of the coordinates in a Tuple of size four. + :obj:`Tuple(numeric)`: + Output the numeric values of the coordinates in a Tuple of size four. """ - if self.axis == 'x': + if self.axis == "x": coords = (self.start, 0, self.end, self.canvas_height) else: coords = (0, self.start, self.canvas_width, self.end) @@ -413,8 +448,8 @@ def coordinates(self): @property def points(self): """ - Return the coordinates of all four corners of the interval in a clockwise fashion - starting from the upper left. + Return the coordinates of all four corners of the interval in a clockwise fashion + starting from the upper left. Returns: :obj:`Numpy array`: A Numpy array of shape 4x2 containing the coordinates. @@ -431,14 +466,14 @@ def center(self): :obj:`Tuple(numeric)`: Returns of coordinate of the center. """ - return (self.start + self.end) / 2. + return (self.start + self.end) / 2.0 @property def area(self): - """Return the area of the covered region of the interval. + """Return the area of the covered region of the interval. The area is bounded to the canvas. If the interval is put - on a canvas, the area equals to interval width * canvas height - (axis='x') or interval height * canvas width (axis='y'). + on a canvas, the area equals to interval width * canvas height + (axis='x') or interval height * canvas width (axis='y'). Otherwise, the area is zero. """ return self.height * self.width @@ -448,13 +483,13 @@ def put_on_canvas(self, canvas): Set the height and the width of the canvas that the interval is on. Args: - canvas (:obj:`Numpy array` or :obj:`BaseCoordElement` or :obj:`PIL.Image.Image`): - The base element that the interval is on. The numpy array should be the + canvas (:obj:`Numpy array` or :obj:`BaseCoordElement` or :obj:`PIL.Image.Image`): + The base element that the interval is on. The numpy array should be the format of `[height, width]`. Returns: - :obj:`Interval`: - A copy of the current Interval with its canvas height and width set to + :obj:`Interval`: + A copy of the current Interval with its canvas height and width set to those of the input canvas. """ @@ -482,19 +517,11 @@ def condition_on(self, other): elif isinstance(other, Rectangle): - return (self - .put_on_canvas(other) - .to_rectangle() - .condition_on(other) - ) + return self.put_on_canvas(other).to_rectangle().condition_on(other) elif isinstance(other, Quadrilateral): - return (self - .put_on_canvas(other) - .to_quadrilateral() - .condition_on(other) - ) + return self.put_on_canvas(other).to_quadrilateral().condition_on(other) else: raise Exception(f"Invalid input type {other.__class__} for other") @@ -512,19 +539,11 @@ def relative_to(self, other): elif isinstance(other, Rectangle): - return (self - .put_on_canvas(other) - .to_rectangle() - .relative_to(other) - ) + return self.put_on_canvas(other).to_rectangle().relative_to(other) elif isinstance(other, Quadrilateral): - return (self - .put_on_canvas(other) - .to_quadrilateral() - .relative_to(other) - ) + return self.put_on_canvas(other).to_quadrilateral().relative_to(other) else: raise Exception(f"Invalid input type {other.__class__} for other") @@ -547,12 +566,12 @@ def is_in(self, other, soft_margin={}, center=False): x_1, y_1, x_2, y_2 = other.coordinates if center: - if self.axis == 'x': + if self.axis == "x": return x_1 <= self.center <= x_2 else: return y_1 <= self.center <= y_2 else: - if self.axis == 'x': + if self.axis == "x": return x_1 <= self.start <= self.end <= x_2 else: return y_1 <= self.start <= self.end <= y_2 @@ -562,18 +581,20 @@ def is_in(self, other, soft_margin={}, center=False): def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True): - if self.axis == 'x': + if self.axis == "x": start = self.start - left end = self.end + right if top or bottom: warnings.warn( - f"Invalid padding top/bottom for an x axis {self.__class__.__name__}") + f"Invalid padding top/bottom for an x axis {self.__class__.__name__}" + ) else: start = self.start - top end = self.end + bottom if left or right: warnings.warn( - f"Invalid padding right/left for a y axis {self.__class__.__name__}") + f"Invalid padding right/left for a y axis {self.__class__.__name__}" + ) if safe_mode: start = max(0, start) @@ -592,10 +613,12 @@ def shift(self, shift_distance): """ if isinstance(shift_distance, Iterable): - shift_distance = shift_distance[0] if self.axis == 'x' \ - else shift_distance[1] + shift_distance = ( + shift_distance[0] if self.axis == "x" else shift_distance[1] + ) warnings.warn( - f"Input shift for multiple axes. Only use the distance for the {self.axis} axis") + f"Input shift for multiple axes. Only use the distance for the {self.axis} axis" + ) start = self.start + shift_distance end = self.end + shift_distance @@ -613,10 +636,10 @@ def scale(self, scale_factor): """ if isinstance(scale_factor, Iterable): - scale_factor = scale_factor[0] if self.axis == 'x' \ - else scale_factor[1] + scale_factor = scale_factor[0] if self.axis == "x" else scale_factor[1] warnings.warn( - f"Input scale for multiple axes. Only use the factor for the {self.axis} axis") + f"Input scale for multiple axes. Only use the factor for the {self.axis} axis" + ) start = self.start * scale_factor end = self.end * scale_factor @@ -624,10 +647,10 @@ def scale(self, scale_factor): def crop_image(self, image): x_1, y_1, x_2, y_2 = self.put_on_canvas(image).coordinates - return image[int(y_1):int(y_2), int(x_1):int(x_2)] + return image[int(y_1) : int(y_2), int(x_1) : int(x_2)] def to_rectangle(self): - """ + """ Convert the Interval to a Rectangle element. Returns: @@ -647,16 +670,20 @@ def to_quadrilateral(self): @classmethod def from_series(cls, series): series = series.dropna() - if series.get('x_1') and series.get('x_2'): - axis = 'x' - start, end = series.get('x_1'), series.get('x_2') + if series.get("x_1") and series.get("x_2"): + axis = "x" + start, end = series.get("x_1"), series.get("x_2") else: - axis = 'y' - start, end = series.get('y_1'), series.get('y_2') + axis = "y" + start, end = series.get("y_1"), series.get("y_2") - return cls(start, end, axis=axis, - canvas_height=series.get('height') or 0, - canvas_width=series.get('width') or 0) + return cls( + start, + end, + axis=axis, + canvas_height=series.get("height") or 0, + canvas_width=series.get("width") or 0, + ) @inherit_docstrings @@ -671,18 +698,18 @@ class Rectangle(BaseCoordElement): ---- (x_2, y_2) Args: - x_1 (:obj:`numeric`): + x_1 (:obj:`numeric`): x coordinate on the horizontal axis of the upper left corner of the rectangle. - y_1 (:obj:`numeric`): + y_1 (:obj:`numeric`): y coordinate on the vertical axis of the upper left corner of the rectangle. - x_2 (:obj:`numeric`): + x_2 (:obj:`numeric`): x coordinate on the horizontal axis of the lower right corner of the rectangle. - y_2 (:obj:`numeric`): + y_2 (:obj:`numeric`): y coordinate on the vertical axis of the lower right corner of the rectangle. """ - name = "_rectangle" - feature_names = ["x_1", "y_1", "x_2", "y_2"] + _name = "rectangle" + _features = ["x_1", "y_1", "x_2", "y_2"] def __init__(self, x_1, y_1, x_2, y_2): @@ -719,7 +746,7 @@ def coordinates(self): Return the coordinates of the two points that define the rectangle. Returns: - :obj:`Tuple(numeric)`: Output the numeric values of the coordinates in a Tuple of size four. + :obj:`Tuple(numeric)`: Output the numeric values of the coordinates in a Tuple of size four. """ return (self.x_1, self.y_1, self.x_2, self.y_2) @@ -727,8 +754,8 @@ def coordinates(self): @property def points(self): """ - Return the coordinates of all four corners of the rectangle in a clockwise fashion - starting from the upper left. + Return the coordinates of all four corners of the rectangle in a clockwise fashion + starting from the upper left. Returns: :obj:`Numpy array`: A Numpy array of shape 4x2 containing the coordinates. @@ -745,7 +772,7 @@ def center(self): :obj:`Tuple(numeric)`: Returns of coordinate of the center. """ - return (self.x_1 + self.x_2)/2., (self.y_1 + self.y_2)/2. + return (self.x_1 + self.x_2) / 2.0, (self.y_1 + self.y_2) / 2.0 @property def area(self): @@ -758,23 +785,26 @@ def area(self): def condition_on(self, other): if isinstance(other, Interval): - if other.axis == 'x': + if other.axis == "x": dx, dy = other.start, 0 else: dx, dy = 0, other.start - return self.__class__(self.x_1 + dx, self.y_1 + dy, - self.x_2 + dx, self.y_2 + dy) + return self.__class__( + self.x_1 + dx, self.y_1 + dy, self.x_2 + dx, self.y_2 + dy + ) elif isinstance(other, Rectangle): dx, dy, _, _ = other.coordinates - return self.__class__(self.x_1 + dx, self.y_1 + dy, - self.x_2 + dx, self.y_2 + dy) + return self.__class__( + self.x_1 + dx, self.y_1 + dy, self.x_2 + dx, self.y_2 + dy + ) elif isinstance(other, Quadrilateral): - transformed_points = _perspective_transformation(other.perspective_matrix, - self.points, is_inv=True) + transformed_points = _perspective_transformation( + other.perspective_matrix, self.points, is_inv=True + ) return other.__class__(transformed_points, self.height, self.width) @@ -784,23 +814,26 @@ def condition_on(self, other): @support_textblock def relative_to(self, other): if isinstance(other, Interval): - if other.axis == 'x': + if other.axis == "x": dx, dy = other.start, 0 else: dx, dy = 0, other.start - return self.__class__(self.x_1 - dx, self.y_1 - dy, - self.x_2 - dx, self.y_2 - dy) + return self.__class__( + self.x_1 - dx, self.y_1 - dy, self.x_2 - dx, self.y_2 - dy + ) elif isinstance(other, Rectangle): dx, dy, _, _ = other.coordinates - return self.__class__(self.x_1 - dx, self.y_1 - dy, - self.x_2 - dx, self.y_2 - dy) + return self.__class__( + self.x_1 - dx, self.y_1 - dy, self.x_2 - dx, self.y_2 - dy + ) elif isinstance(other, Quadrilateral): - transformed_points = _perspective_transformation(other.perspective_matrix, - self.points, is_inv=False) + transformed_points = _perspective_transformation( + other.perspective_matrix, self.points, is_inv=False + ) return other.__class__(transformed_points, self.height, self.width) @@ -814,28 +847,31 @@ def is_in(self, other, soft_margin={}, center=False): if isinstance(other, Interval): if not center: - if other.axis == 'x': + if other.axis == "x": start, end = self.x_1, self.x_2 else: start, end = self.y_1, self.y_2 return other.start <= start <= end <= other.end else: - c = self.center[0] if other.axis == 'x' else self.center[1] + c = self.center[0] if other.axis == "x" else self.center[1] return other.start <= c <= other.end elif isinstance(other, Rectangle): - x_interval = other.to_interval(axis='x') - y_interval = other.to_interval(axis='y') - return self.is_in(x_interval, center=center) and \ - self.is_in(y_interval, center=center) + x_interval = other.to_interval(axis="x") + y_interval = other.to_interval(axis="y") + return self.is_in(x_interval, center=center) and self.is_in( + y_interval, center=center + ) elif isinstance(other, Quadrilateral): if not center: # This is equivalent to determine all the points of the # rectangle is in the quadrilateral. - is_vertice_in = [_vertice_in_polygon( - vertice, other.points) for vertice in self.points] + is_vertice_in = [ + _vertice_in_polygon(vertice, other.points) + for vertice in self.points + ] return all(is_vertice_in) else: center = np.array(self.center) @@ -844,8 +880,7 @@ def is_in(self, other, soft_margin={}, center=False): else: raise Exception(f"Invalid input type {other.__class__} for other") - def pad(self, left=0, right=0, top=0, bottom=0, - safe_mode=True): + def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True): x_1 = self.x_1 - left y_1 = self.y_1 - top @@ -864,8 +899,9 @@ def shift(self, shift_distance=0): shift_x = shift_distance shift_y = shift_distance else: - assert len( - shift_distance) == 2, "shift_distance should have 2 elements, one for x dimension and one for y dimension" + assert ( + len(shift_distance) == 2 + ), "shift_distance should have 2 elements, one for x dimension and one for y dimension" shift_x, shift_y = shift_distance x_1 = self.x_1 + shift_x @@ -880,8 +916,9 @@ def scale(self, scale_factor=1): scale_x = scale_factor scale_y = scale_factor else: - assert len( - scale_factor) == 2, "scale_factor should have 2 elements, one for x dimension and one for y dimension" + assert ( + len(scale_factor) == 2 + ), "scale_factor should have 2 elements, one for x dimension and one for y dimension" scale_x, scale_y = scale_factor x_1 = self.x_1 * scale_x @@ -892,10 +929,10 @@ def scale(self, scale_factor=1): def crop_image(self, image): x_1, y_1, x_2, y_2 = self.coordinates - return image[int(y_1):int(y_2), int(x_1):int(x_2)] + return image[int(y_1) : int(y_2), int(x_1) : int(x_2)] - def to_interval(self, axis='x', **kwargs): - if axis == 'x': + def to_interval(self, axis, **kwargs): + if axis == "x": start, end = self.x_1, self.x_2 else: start, end = self.y_1, self.y_2 @@ -914,7 +951,7 @@ def from_series(cls, series): @inherit_docstrings class Quadrilateral(BaseCoordElement): """ - This class describes the coodinate system of a four-sided polygon. A quadrilateral is defined by + This class describes the coodinate system of a four-sided polygon. A quadrilateral is defined by the coordinates of its 4 corners in a clockwise order starting with the upper left corner (as shown below):: points[0] -...- points[1] @@ -926,8 +963,10 @@ class Quadrilateral(BaseCoordElement): points[3] -...- points[2] Args: - points (:obj:`Numpy array`): - The array of 4 corner coordinates of size 4x2. + points (:obj:`Numpy array` or `list`): + A `np.ndarray` of shape 4x2 for four corner coordinates + or a list of length 8 for in the format of + `[p[0,0], p[0,1], p[1,0], p[1,1], ...]`. height (:obj:`numeric`, `optional`, defaults to `None`): The height of the quadrilateral. This is to better support the perspective transformation from the OpenCV library. @@ -936,15 +975,25 @@ class Quadrilateral(BaseCoordElement): transformation from the OpenCV library. """ - name = "_quadrilateral" - feature_names = ["p11", "p12", "p21", "p22", - "p31", "p32", "p41", "p42", - "height", "width"] + _name = "quadrilateral" + _features = ["points", "height", "width"] def __init__(self, points, height=None, width=None): - assert isinstance( - points, np.ndarray), f" Invalid input: points must be a numpy array" + if isinstance(points, np.ndarray): + if points.shape != (4, 2): + raise ValueError(f"Invalid points shape: {points.shape}.") + elif isinstance(points, list): + if len(points) != 8: + raise ValueError( + f"Invalid number of points element {len(points)}. Should be 8." + ) + points = np.array(points).reshape(4, 2) + else: + raise ValueError( + f"Invalid input type for points {type(points)}." + "Please make sure it is a list of np.ndarray." + ) self._points = points self._width = width @@ -979,11 +1028,11 @@ def width(self): @property def coordinates(self): """ - Return the coordinates of the upper left and lower right corners points that + Return the coordinates of the upper left and lower right corners points that define the circumscribed rectangle. Returns - :obj:`Tuple(numeric)`: Output the numeric values of the coordinates in a Tuple of size four. + :obj:`Tuple(numeric)`: Output the numeric values of the coordinates in a Tuple of size four. """ return _cvt_points_to_coordinates(self.points) @@ -991,8 +1040,8 @@ def coordinates(self): @property def points(self): """ - Return the coordinates of all four corners of the quadrilateral in a clockwise fashion - starting from the upper left. + Return the coordinates of all four corners of the quadrilateral in a clockwise fashion + starting from the upper left. Returns: :obj:`Numpy array`: A Numpy array of shape 4x2 containing the coordinates. @@ -1028,25 +1077,29 @@ def mapped_rectangle_points(self): @property def perspective_matrix(self): - return _getPerspectiveTransform(self.points.astype('float32'), - self.mapped_rectangle_points.astype('float32')) + return _getPerspectiveTransform( + self.points.astype("float32"), + self.mapped_rectangle_points.astype("float32"), + ) def map_to_points_ordering(self, x_map, y_map): points_ordering = self.points.argsort(axis=0).argsort(axis=0) # Ref: https://github.com/numpy/numpy/issues/8757#issuecomment-355126992 - return np.vstack([ - np.vectorize(x_map.get)(points_ordering[:, 0]), - np.vectorize(y_map.get)(points_ordering[:, 1]) - ]).T + return np.vstack( + [ + np.vectorize(x_map.get)(points_ordering[:, 0]), + np.vectorize(y_map.get)(points_ordering[:, 1]), + ] + ).T @support_textblock def condition_on(self, other): if isinstance(other, Interval): - if other.axis == 'x': + if other.axis == "x": return self.shift([other.start, 0]) else: return self.shift([0, other.start]) @@ -1057,8 +1110,9 @@ def condition_on(self, other): elif isinstance(other, Quadrilateral): - transformed_points = _perspective_transformation(other.perspective_matrix, - self.points, is_inv=True) + transformed_points = _perspective_transformation( + other.perspective_matrix, self.points, is_inv=True + ) return self.__class__(transformed_points, self.height, self.width) else: @@ -1069,7 +1123,7 @@ def relative_to(self, other): if isinstance(other, Interval): - if other.axis == 'x': + if other.axis == "x": return self.shift([-other.start, 0]) else: return self.shift([0, -other.start]) @@ -1080,8 +1134,9 @@ def relative_to(self, other): elif isinstance(other, Quadrilateral): - transformed_points = _perspective_transformation(other.perspective_matrix, - self.points, is_inv=False) + transformed_points = _perspective_transformation( + other.perspective_matrix, self.points, is_inv=False + ) return self.__class__(transformed_points, self.height, self.width) else: @@ -1094,28 +1149,31 @@ def is_in(self, other, soft_margin={}, center=False): if isinstance(other, Interval): if not center: - if other.axis == 'x': + if other.axis == "x": start, end = self.coordinates[0], self.coordinates[2] else: start, end = self.coordinates[1], self.coordinates[3] return other.start <= start <= end <= other.end else: - c = self.center[0] if other.axis == 'x' else self.center[1] + c = self.center[0] if other.axis == "x" else self.center[1] return other.start <= c <= other.end elif isinstance(other, Rectangle): - x_interval = other.to_interval(axis='x') - y_interval = other.to_interval(axis='y') - return self.is_in(x_interval, center=center) and \ - self.is_in(y_interval, center=center) + x_interval = other.to_interval(axis="x") + y_interval = other.to_interval(axis="y") + return self.is_in(x_interval, center=center) and self.is_in( + y_interval, center=center + ) elif isinstance(other, Quadrilateral): if not center: # This is equivalent to determine all the points of the # rectangle is in the quadrilateral. - is_vertice_in = [_vertice_in_polygon( - vertice, other.points) for vertice in self.points] + is_vertice_in = [ + _vertice_in_polygon(vertice, other.points) + for vertice in self.points + ] return all(is_vertice_in) else: center = np.array(self.center) @@ -1124,11 +1182,10 @@ def is_in(self, other, soft_margin={}, center=False): else: raise Exception(f"Invalid input type {other.__class__} for other") - def pad(self, left=0, right=0, top=0, bottom=0, - safe_mode=True): + def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True): - x_map = {0: -left, 1: -left, 2: right, 3: right} - y_map = {0: -top, 1: -top, 2: bottom, 3: bottom} + x_map = {0: -left, 1: -left, 2: right, 3: right} + y_map = {0: -top, 1: -top, 2: bottom, 3: bottom} padding_mat = self.map_to_points_ordering(x_map, y_map) @@ -1143,8 +1200,9 @@ def shift(self, shift_distance=0): if not isinstance(shift_distance, Iterable): shift_mat = [shift_distance, shift_distance] else: - assert len( - shift_distance) == 2, "shift_distance should have 2 elements, one for x dimension and one for y dimension" + assert ( + len(shift_distance) == 2 + ), "shift_distance should have 2 elements, one for x dimension and one for y dimension" shift_mat = shift_distance points = self.points + np.array(shift_mat) @@ -1156,8 +1214,9 @@ def scale(self, scale_factor=1): if not isinstance(scale_factor, Iterable): scale_mat = [scale_factor, scale_factor] else: - assert len( - scale_factor) == 2, "scale_factor should have 2 elements, one for x dimension and one for y dimension" + assert ( + len(scale_factor) == 2 + ), "scale_factor should have 2 elements, one for x dimension and one for y dimension" scale_mat = scale_factor points = self.points * np.array(scale_mat) @@ -1175,12 +1234,14 @@ def crop_image(self, image): :obj:`Numpy array`: The array of the cropped image. """ - return _warpPerspective(image, self.perspective_matrix, (int(self.width), int(self.height))) + return _warpPerspective( + image, self.perspective_matrix, (int(self.width), int(self.height)) + ) - def to_interval(self, axis='x', **kwargs): + def to_interval(self, axis="x", **kwargs): x_1, y_1, x_2, y_2 = self.coordinates - if axis == 'x': + if axis == "x": start, end = x_1, x_2 else: start, end = y_1, y_2 @@ -1194,12 +1255,11 @@ def to_rectangle(self): def from_series(cls, series): series = series.dropna() - points = pd.to_numeric( - series[cls.feature_names[:8]]).values.reshape(4, -2) + points = pd.to_numeric(series[cls.feature_names[:8]]).values.reshape(4, -2) - return cls(points=points, - height=series.get("height"), - width=series.get("width")) + return cls( + points=points, height=series.get("height"), width=series.get("width") + ) def __eq__(self, other): if other.__class__ is not self.__class__: @@ -1207,21 +1267,50 @@ def __eq__(self, other): return np.isclose(self.points, other.points).all() def __repr__(self): - keys = ['points', 'width', 'height'] - info_str = ', '.join([f'{key}={getattr(self, key)}' for key in keys]) + keys = ["points", "width", "height"] + info_str = ", ".join([f"{key}={getattr(self, key)}" for key in keys]) return f"{self.__class__.__name__}({info_str})" + def to_dict(self) -> Dict[str, Any]: + + """ + Generate a dictionary representation of the current object:: + + { + "block_type": "quadrilateral", + "points": [ + p[0,0], p[0,1], + p[1,0], p[1,1], + p[2,0], p[2,1], + p[3,0], p[3,1] + ], + "height": value, + "width": value + } + """ + data = super().to_dict() + data["points"] = data["points"].reshape(-1).tolist() + return data + + +ALL_BASECOORD_ELEMENTS = [Interval, Rectangle, Quadrilateral] + +BASECOORD_ELEMENT_NAMEMAP = {ele._name: ele for ele in ALL_BASECOORD_ELEMENTS} +BASECOORD_ELEMENT_INDEXMAP = { + ele._name: idx for idx, ele in enumerate(ALL_BASECOORD_ELEMENTS) +} + @inherit_docstrings(base_class=BaseCoordElement) class TextBlock(BaseLayoutElement): """ - This class constructs content-related information of a layout element in addition to its coordinate definitions + This class constructs content-related information of a layout element in addition to its coordinate definitions (i.e. Interval, Rectangle or Quadrilateral). Args: - block (:obj:`BaseCoordElement`): + block (:obj:`BaseCoordElement`): The shape-specific coordinate systems that the text block belongs to. - text (:obj:`str`, `optional`, defaults to ""): + text (:obj:`str`, `optional`, defaults to None): The ocr'ed text results within the boundaries of the text block. id (:obj:`int`, `optional`, defaults to `None`): The id of the text block. @@ -1235,12 +1324,12 @@ class TextBlock(BaseLayoutElement): The prediction confidence of the block """ - name = "_textblock" - feature_names = ["text", "id", "type", "parent", "next", "score"] + _name = "textblock" + _features = ["text", "id", "type", "parent", "next", "score"] - def __init__(self, block, text="", - id=None, type=None, parent=None, - next=None, score=None): + def __init__( + self, block, text=None, id=None, type=None, parent=None, next=None, score=None + ): assert isinstance(block, BaseCoordElement) self.block = block @@ -1280,7 +1369,7 @@ def coordinates(self): Return the coordinates of the two corner points that define the shape-specific block. Returns: - :obj:`Tuple(numeric)`: Output the numeric values of the coordinates in a Tuple of size four. + :obj:`Tuple(numeric)`: Output the numeric values of the coordinates in a Tuple of size four. """ return self.block.coordinates @@ -1288,8 +1377,8 @@ def coordinates(self): @property def points(self): """ - Return the coordinates of all four corners of the shape-specific block in a clockwise fashion - starting from the upper left. + Return the coordinates of all four corners of the shape-specific block in a clockwise fashion + starting from the upper left. Returns: :obj:`Numpy array`: A Numpy array of shape 4x2 containing the coordinates. @@ -1312,16 +1401,16 @@ def condition_on(self, other): def relative_to(self, other): return self.block.relative_to(other) - def is_in(self, other, **kwargs): - return self.block.is_in(other, **kwargs) + def is_in(self, other, soft_margin={}, center=False): + return self.block.is_in(other, soft_margin, center) @mixin_textblock_meta def shift(self, shift_distance): return self.block.shift(shift_distance) @mixin_textblock_meta - def pad(self, **kwargs): - return self.block.pad(**kwargs) + def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True): + return self.block.pad(left, right, top, bottom, safe_mode) @mixin_textblock_meta def scale(self, scale_factor): @@ -1344,54 +1433,199 @@ def from_series(cls, series): else: target_type = Interval - return cls( - block=target_type.from_series(series), - **features) + return cls(block=target_type.from_series(series), **features) + def to_dict(self) -> Dict[str, Any]: + """ + Generate a dictionary representation of the current textblock of the format:: + + { + "block_type": , + + } + """ + base_dict = self.block.to_dict() + for f in self._features: + val = getattr(self, f) + if val is not None: + base_dict[f] = getattr(self, f) + return base_dict + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TextBlock": + """Initialize the textblock based on the dictionary representation. + It generate the block based on the `block_type` and `block_attr`, + and loads the textblock specific features from the dict. -class Layout(list): - """ A handy class for handling a list of text blocks. All the class functions will be broadcasted to - each element block in the list. + Args: + data (:obj:`dict`): The dictionary representation of the object + """ + assert ( + data["block_type"] in BASECOORD_ELEMENT_NAMEMAP + ), f"Invalid block_type {data['block_type']}" + + block = BASECOORD_ELEMENT_NAMEMAP[data["block_type"]].from_dict(data) + + return cls(block, **{f: data.get(f, None) for f in cls._features}) + + +class Layout(MutableSequence): """ + The :obj:`Layout` class id designed for processing a list of layout elements + on a page. It stores the layout elements in a list and the related `page_data`, + and provides handy APIs for processing all the layout elements in batch. ` + + Args: + blocks (:obj:`list`): + A list of layout element blocks + page_data (Dict, optional): + A dictionary storing the page (canvas) related information + like `height`, `width`, etc. + Defaults to None. + """ + + def __init__(self, blocks: List = [], page_data: Dict = None): + self._blocks = blocks + self.page_data = page_data or {} + + def __getitem__(self, key): + blocks = self._blocks[key] + if isinstance(key, slice): + return self.__class__(self._blocks[key], self.page_data) + else: + return blocks + + def __setitem__(self, key, newvalue): + self._blocks[key] = newvalue + + def __delitem__(self, key): + del self._blocks[key] + + def __len__(self): + return len(self._blocks) + + def __iter__(self): + for ele in self._blocks: + yield ele + + def __repr__(self): + info_str = ", ".join([f"{key}={val}" for key, val in vars(self).items()]) + return f"{self.__class__.__name__}({info_str})" + + def __eq__(self, other): + if isinstance(other, Layout): + return ( + all((a, b) for a, b in zip(self, other)) + and self.page_data == other.page_data + ) + else: + return False + + def __add__(self, other): + if isinstance(other, Layout): + if self.page_data == other.page_data: + return self.__class__(self._blocks + other._blocks, self.page_data) + elif self.page_data == {} or other.page_data == {}: + return self.__class__( + self._blocks + other._blocks, self.page_data or other.page_data + ) + else: + raise ValueError( + f"Incompatible page_data for two innputs: {self.page_data} vs {other.page_data}." + ) + elif isinstance(other, list): + return self.__class__(self._blocks + other, self.page_data) + else: + raise ValueError( + f"Invalid input type for other {other.__class__.__name__}." + ) - identifier_map = { - Interval.name: Interval, - Rectangle.name: Rectangle, - Quadrilateral.name: Quadrilateral, - TextBlock.name: TextBlock} + def insert(self, key, value): + self._blocks.insert(key, value) + + def copy(self): + return self.__class__(copy(self._blocks), self.page_data) def relative_to(self, other): - return self.__class__([ele.relative_to(other) for ele in self]) + return self.__class__([ele.relative_to(other) for ele in self], self.page_data) def condition_on(self, other): - return self.__class__([ele.condition_on(other) for ele in self]) + return self.__class__([ele.condition_on(other) for ele in self], self.page_data) - def is_in(self, other, **kwargs): - return self.__class__([ele.is_in(other, **kwargs) for ele in self]) + def is_in(self, other, soft_margin={}, center=False): + return self.__class__( + [ele.is_in(other, soft_margin, center) for ele in self], self.page_data + ) - def filter_by(self, other, **kwargs): + def filter_by(self, other, soft_margin={}, center=False): """ Return a `Layout` object containing the elements that are in the `other` object. Args: - other (:obj:`BaseCoordElement`) + other (:obj:`BaseCoordElement`): + The block to filter the current elements. Returns: - :obj:`Layout` + :obj:`Layout`: + A new layout object after filtering. """ - return self.__class__([ele for ele in self if ele.is_in(other, **kwargs)]) + return self.__class__( + [ele for ele in self if ele.is_in(other, soft_margin, center)], + self.page_data, + ) - @functools.wraps(BaseCoordElement.shift) def shift(self, shift_distance): - return self.__class__([ele.shift(shift_distance) for ele in self]) + """ + Shift all layout elements by user specified amounts on x and y axis respectively. If shift_distance is one + numeric value, the element will by shifted by the same specified amount on both x and y axis. + + Args: + shift_distance (:obj:`numeric` or :obj:`Tuple(numeric)` or :obj:`List[numeric]`): + The number of pixels used to shift the element. - @functools.wraps(BaseCoordElement.pad) - def pad(self, **kwargs): - return self.__class__([ele.pad(**kwargs) for ele in self]) + Returns: + :obj:`Layout`: + A new layout object with all the elements shifted in the specified values. + """ + return self.__class__( + [ele.shift(shift_distance) for ele in self], self.page_data + ) + + def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True): + """Pad all layout elements on the four sides of the polygon with the user-defined pixels. If + safe_mode is set to True, the function will cut off the excess padding that falls on the negative + side of the coordinates. + + Args: + left (:obj:`int`, `optional`, defaults to 0): The number of pixels to pad on the upper side of the polygon. + right (:obj:`int`, `optional`, defaults to 0): The number of pixels to pad on the lower side of the polygon. + top (:obj:`int`, `optional`, defaults to 0): The number of pixels to pad on the left side of the polygon. + bottom (:obj:`int`, `optional`, defaults to 0): The number of pixels to pad on the right side of the polygon. + safe_mode (:obj:`bool`, `optional`, defaults to True): A bool value to toggle the safe_mode. + + Returns: + :obj:`Layout`: + A new layout object with all the elements padded in the specified values. + """ + return self.__class__( + [ele.pad(left, right, top, bottom, safe_mode) for ele in self], + self.page_data, + ) - @functools.wraps(BaseCoordElement.scale) def scale(self, scale_factor): - return self.__class__([ele.scale(scale_factor) for ele in self]) + """ + Scale all layout element by a user specified amount on x and y axis respectively. If scale_factor is one + numeric value, the element will by scaled by the same specified amount on both x and y axis. + + Args: + scale_factor (:obj:`numeric` or :obj:`Tuple(numeric)` or :obj:`List[numeric]`): The amount for downscaling or upscaling the element. + + Returns: + :obj:`Layout`: + A new layout object with all the elements scaled in the specified values. + """ + return self.__class__([ele.scale(scale_factor) for ele in self], self.page_data) def crop_image(self, image): return [ele.crop_image(image) for ele in self] @@ -1404,7 +1638,7 @@ def get_texts(self): :obj:`List[str]`: A list of text strings of the text blocks in the list of layout elements. """ - return [ele.text for ele in self if hasattr(ele, 'text')] + return [ele.text for ele in self if hasattr(ele, "text")] def get_info(self, attr_name): """Given user-provided attribute name, check all the elements in the list and return the corresponding @@ -1414,31 +1648,90 @@ def get_info(self, attr_name): attr_name (:obj:`str`): The text string of certain attribute name. Returns: - :obj:`List`: - The list of the corresponding attribute value (if exist) of each element in the list. + :obj:`List`: + The list of the corresponding attribute value (if exist) of each element in the list. """ return [getattr(ele, attr_name) for ele in self if hasattr(ele, attr_name)] - @classmethod - def from_dataframe(cls, df): + def to_dict(self) -> Dict[str, Any]: + """Generate a dict representation of the layout object with + the page_data and all the blocks in its dict representation. - if "_identifier" in df.columns: - return cls( - [cls.identifier_map[series["_identifier"]].from_series(series.drop(columns=["_identifier"])) - for (_, series) in df.iterrows()] - ) + Returns: + :obj:`Dict`: + The dictionary representation of the layout object. + """ + return {"page_data": self.page_data, "blocks": [ele.to_dict() for ele in self]} + + def get_homogeneous_blocks(self) -> List[BaseLayoutElement]: + """Convert all elements into blocks of the same type based + on the type casting rule:: + + Interval < Rectangle < Quadrilateral < TextBlock + + Returns: + List[BaseLayoutElement]: + A list of base layout elements of the maximal compatible + type + """ - elif any(col in TextBlock.feature_names for col in df.columns): + # Detect the maximal compatible type + has_textblock = False + max_coord_level = -1 + for ele in self: - return cls( - [TextBlock.from_series(series) - for (_, series) in df.iterrows()] + if isinstance(ele, TextBlock): + has_textblock = True + block = ele.block + else: + block = ele + + max_coord_level = max( + max_coord_level, BASECOORD_ELEMENT_INDEXMAP[block._name] ) + target_coord_name = ALL_BASECOORD_ELEMENTS[max_coord_level]._name + + if has_textblock: + new_blocks = [] + for ele in self: + if isinstance(ele, TextBlock): + ele = copy(ele) + if ele.block._name != target_coord_name: + ele.block = getattr(ele.block, f"to_{target_coord_name}")() + else: + if ele._name != target_coord_name: + ele = getattr(ele, f"to_{target_coord_name}")() + ele = TextBlock(block) + new_blocks.append(ele) + else: + new_blocks = [ + getattr(ele, f"to_{target_coord_name}")() + if ele._name != target_coord_name + else ele + for ele in self + ] + + return new_blocks + + def to_dataframe(self, enforce_same_type=False) -> pd.DataFrame: + """Convert the layout object into the dataframe. + Warning: the page data won't be exported. + + Args: + enforce_same_type (:obj:`bool`, optional): + If true, it will convert all the contained blocks to + the maximal compatible data type. + Defaults to False. + Returns: + pd.DataFrame: + The dataframe representation of layout object + """ + if enforce_same_type: + blocks = self.get_homogeneous_blocks() else: - target_type = _parse_datatype_from_feature_names(df.columns) + blocks = self - return cls( - [target_type.from_series(series) - for (_, series) in df.iterrows()] - ) + df = pd.DataFrame([ele.to_dict() for ele in blocks]) + + return df diff --git a/src/layoutparser/io.py b/src/layoutparser/io.py new file mode 100644 index 0000000..fad13b7 --- /dev/null +++ b/src/layoutparser/io.py @@ -0,0 +1,133 @@ +import ast +import json +from typing import List, Union, Dict, Dict, Any + +import pandas as pd + +from .elements import ( + BaseCoordElement, + BaseLayoutElement, + Interval, + Rectangle, + Quadrilateral, + TextBlock, + Layout, + BASECOORD_ELEMENT_NAMEMAP, +) + + +def load_json(filename: str) -> Union[BaseLayoutElement, Layout]: + """Load a JSON file and save it as a layout object with appropriate data types. + + Args: + filename (str): + The name of the JSON file. + + Returns: + Union[BaseLayoutElement, Layout]: + Based on the JSON file format, it will automatically parse + the type of the data and load it accordingly. + """ + with open(filename, "r") as fp: + res = json.load(fp) + + return load_dict(res) + + +def load_dict(data: Union[Dict, List[Dict]]) -> Union[BaseLayoutElement, Layout]: + """Load a dict of list of dict representations of some layout data, + automatically parse its type, and save it as any of BaseLayoutElement + or Layout datatype. + + Args: + data (Union[Dict, List]): + A dict of list of dict representations of the layout data + + Raises: + ValueError: + If the data format is incompatible with the layout-data-JSON format, + raise a `ValueError`. + ValueError: + If any `block_type` name is not in the available list of layout element + names defined in `BASECOORD_ELEMENT_NAMEMAP`, raise a `ValueError`. + + Returns: + Union[BaseLayoutElement, Layout]: + Based on the dict format, it will automatically parse the type of + the data and load it accordingly. + """ + if isinstance(data, dict): + if "page_data" in data: + # It is a layout instance + return Layout(load_dict(data["blocks"]), page_data=data["page_data"]) + else: + + if data["block_type"] not in BASECOORD_ELEMENT_NAMEMAP: + raise ValueError(f"Invalid block_type {data['block_type']}") + + # Check if it is a textblock + is_textblock = any(ele in data for ele in TextBlock._features) + if is_textblock: + return TextBlock.from_dict(data) + else: + return BASECOORD_ELEMENT_NAMEMAP[data["block_type"]].from_dict(data) + + elif isinstance(data, list): + return Layout([load_dict(ele) for ele in data]) + + else: + raise ValueError(f"Invalid input JSON structure.") + + +def load_csv(filename: str, block_type: str = None) -> Layout: + """Load the Layout object from the given CSV file. + + Args: + filename (str): + The name of the CSV file. A row of the table represents + an individual layout element. + + block_type (str): + If there's no block_type column in the CSV file, + you must pass in a block_type variable such that layout parser + can appropriately detect the type of the layout elements. + + Returns: + Layout: + The parsed Layout object from the CSV file. + """ + + return load_dataframe(pd.read_csv(filename), block_type=block_type) + + +def load_dataframe(df: pd.DataFrame, block_type: str = None) -> Layout: + """Load the Layout object from the given dataframe. + + Args: + df (pd.DataFrame): + + block_type (str): + If there's no block_type column in the CSV file, + you must pass in a block_type variable such that layout parser + can appropriately detect the type of the layout elements. + + Returns: + Layout: + The parsed Layout object from the CSV file. + """ + df = df.copy() + if "points" in df.columns: + if df["points"].dtype == object: + df["points"] = df["points"].map( + lambda x: ast.literal_eval(x) if not pd.isna(x) else x + ) + + if block_type is None: + if "block_type" not in df.columns: + raise ValueError( + "`block_type` not specified both in dataframe and arguments" + ) + else: + df["block_type"] = block_type + + return load_dict(df.apply(lambda x: x.dropna().to_dict(), axis=1).to_list()) diff --git a/src/layoutparser/models/layoutmodel.py b/src/layoutparser/models/layoutmodel.py index 6110c74..78b6973 100644 --- a/src/layoutparser/models/layoutmodel.py +++ b/src/layoutparser/models/layoutmodel.py @@ -6,7 +6,8 @@ import numpy as np import torch from fvcore.common.file_io import PathManager -#TODO: Update to iopath in the next major release + +# TODO: Update to iopath in the next major release from ..elements import * diff --git a/src/layoutparser/ocr.py b/src/layoutparser/ocr.py index b7835cd..65f5ccc 100644 --- a/src/layoutparser/ocr.py +++ b/src/layoutparser/ocr.py @@ -5,15 +5,17 @@ import os import json import csv +import warnings +import pickle + import numpy as np import pandas as pd from cv2 import imencode + from .elements import * -import warnings -import pickle +from .io import load_dataframe -__all__ = ['GCVFeatureType', 'GCVAgent', - 'TesseractFeatureType', 'TesseractAgent'] +__all__ = ["GCVFeatureType", "GCVAgent", "TesseractFeatureType", "TesseractAgent"] def _cvt_GCV_vertices_to_points(vertices): @@ -21,34 +23,32 @@ def _cvt_GCV_vertices_to_points(vertices): class BaseOCRElementType(IntEnum): - @property @abstractmethod - def attr_name(self): pass + def attr_name(self): + pass class BaseOCRAgent(ABC): - @property @abstractmethod def DEPENDENCIES(self): - """DEPENDENCIES lists all necessary dependencies for the class. - """ + """DEPENDENCIES lists all necessary dependencies for the class.""" pass @property @abstractmethod def MODULES(self): - """MODULES instructs how to import these necessary libraries. + """MODULES instructs how to import these necessary libraries. - Note: - Sometimes a python module have different installation name and module name (e.g., + Note: + Sometimes a python module have different installation name and module name (e.g., `pip install tensorflow-gpu` when installing and `import tensorflow` when using - ). And sometimes we only need to import a submodule but not whole module. MODULES - is designed for this purpose. + ). And sometimes we only need to import a submodule but not whole module. MODULES + is designed for this purpose. Returns: - :obj: list(dict): A list of dict indicate how the model is imported. + :obj: list(dict): A list of dict indicate how the model is imported. Example:: @@ -65,8 +65,9 @@ def MODULES(self): def _import_module(cls): for m in cls.MODULES: if importlib.util.find_spec(m["module_path"]): - setattr(cls, m["import_name"], - importlib.import_module(m["module_path"])) + setattr( + cls, m["import_name"], importlib.import_module(m["module_path"]) + ) else: raise ModuleNotFoundError( f"\n " @@ -81,7 +82,8 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls) @abstractmethod - def detect(self, image): pass + def detect(self, image): + pass class GCVFeatureType(BaseOCRElementType): @@ -98,11 +100,11 @@ class GCVFeatureType(BaseOCRElementType): @property def attr_name(self): name_cvt = { - GCVFeatureType.PAGE: 'pages', - GCVFeatureType.BLOCK: 'blocks', - GCVFeatureType.PARA: 'paragraphs', - GCVFeatureType.WORD: 'words', - GCVFeatureType.SYMBOL: 'symbols' + GCVFeatureType.PAGE: "pages", + GCVFeatureType.BLOCK: "blocks", + GCVFeatureType.PARA: "paragraphs", + GCVFeatureType.WORD: "words", + GCVFeatureType.SYMBOL: "symbols", } return name_cvt[self] @@ -113,72 +115,64 @@ def child_level(self): GCVFeatureType.BLOCK: GCVFeatureType.PARA, GCVFeatureType.PARA: GCVFeatureType.WORD, GCVFeatureType.WORD: GCVFeatureType.SYMBOL, - GCVFeatureType.SYMBOL: None + GCVFeatureType.SYMBOL: None, } return child_cvt[self] class GCVAgent(BaseOCRAgent): - """A wrapper for `Google Cloud Vision (GCV) `_ Text - Detection APIs. + """A wrapper for `Google Cloud Vision (GCV) `_ Text + Detection APIs. Note: - Google Cloud Vision API returns the output text in two types: + Google Cloud Vision API returns the output text in two types: - * `text_annotations`: + * `text_annotations`: - In this format, GCV automatically find the best aggregation - level for the text, and return the results in a list. We use - :obj:`~gather_text_annotations` to reterive this type of + In this format, GCV automatically find the best aggregation + level for the text, and return the results in a list. We use + :obj:`~gather_text_annotations` to reterive this type of information. * `full_text_annotation`: - To support better user control, GCV also provides the - `full_text_annotation` output, where it returns the hierarchical - structure of the output text. To process this output, we provide - the :obj:`~gather_full_text_annotation` function to aggregate the - texts of the given aggregation level. + To support better user control, GCV also provides the + `full_text_annotation` output, where it returns the hierarchical + structure of the output text. To process this output, we provide + the :obj:`~gather_full_text_annotation` function to aggregate the + texts of the given aggregation level. """ - DEPENDENCIES = ['google-cloud-vision'] + DEPENDENCIES = ["google-cloud-vision"] MODULES = [ - { - "import_name": "_vision", - "module_path": "google.cloud.vision" - }, - { - "import_name": "_json_format", - "module_path": "google.protobuf.json_format" - }, + {"import_name": "_vision", "module_path": "google.cloud.vision"}, + {"import_name": "_json_format", "module_path": "google.protobuf.json_format"}, ] - def __init__(self, - languages=None, - ocr_image_decode_type='.png'): - """Create a Google Cloud Vision OCR Agent. + def __init__(self, languages=None, ocr_image_decode_type=".png"): + """Create a Google Cloud Vision OCR Agent. Args: - languages (:obj:`list`, optional): - You can specify the language code of the documents to detect to improve + languages (:obj:`list`, optional): + You can specify the language code of the documents to detect to improve accuracy. The supported language and their code can be found on `this page - `_. + `_. Defaults to None. - ocr_image_decode_type (:obj:`str`, optional): - The format to convert the input image to before sending for GCV OCR. + ocr_image_decode_type (:obj:`str`, optional): + The format to convert the input image to before sending for GCV OCR. Defaults to `".png"`. - * `".png"` is suggested as it does not compress the image. - * But `".jpg"` could also be a good choice if the input image is very large. + * `".png"` is suggested as it does not compress the image. + * But `".jpg"` could also be a good choice if the input image is very large. """ try: self._client = self._vision.ImageAnnotatorClient() except: warnings.warn( - "The GCV credential has not been set. You could not run the detect command.") - self._context = self._vision.types.ImageContext( - language_hints=languages) + "The GCV credential has not been set. You could not run the detect command." + ) + self._context = self._vision.types.ImageContext(language_hints=languages) self.ocr_image_decode_type = ocr_image_decode_type @classmethod @@ -188,41 +182,43 @@ def with_credential(cls, credential_path, **kwargs): Args: credential_path (:obj:`str`): The path to the credential file """ - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credential_path + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credential_path return cls(**kwargs) def _detect(self, img_content): img_content = self._vision.types.Image(content=img_content) response = self._client.document_text_detection( - image=img_content, - image_context=self._context) + image=img_content, image_context=self._context + ) return response - def detect(self, image, - return_response=False, - return_only_text=False, - agg_output_level=None): + def detect( + self, + image, + return_response=False, + return_only_text=False, + agg_output_level=None, + ): """Send the input image for OCR. Args: image (:obj:`np.ndarray` or :obj:`str`): The input image array or the name of the image file - return_response (:obj:`bool`, optional): - Whether directly return the google cloud response. + return_response (:obj:`bool`, optional): + Whether directly return the google cloud response. Defaults to `False`. - return_only_text (:obj:`bool`, optional): - Whether return only the texts in the OCR results. + return_only_text (:obj:`bool`, optional): + Whether return only the texts in the OCR results. Defaults to `False`. - agg_output_level (:obj:`~GCVFeatureType`, optional): - When set, aggregate the GCV output with respect to the + agg_output_level (:obj:`~GCVFeatureType`, optional): + When set, aggregate the GCV output with respect to the specified aggregation level. Defaults to `None`. """ if isinstance(image, np.ndarray): - img_content = imencode(self.ocr_image_decode_type, - image)[1].tostring() + img_content = imencode(self.ocr_image_decode_type, image)[1].tostring() elif isinstance(image, str): - with io.open(image, 'rb') as image_file: + with io.open(image, "rb") as image_file: img_content = image_file.read() res = self._detect(img_content) @@ -243,11 +239,11 @@ def gather_text_annotations(response): """Convert the text_annotations from GCV output to an :obj:`Layout` object. Args: - response (:obj:`AnnotateImageResponse`): + response (:obj:`AnnotateImageResponse`): The returned Google Cloud Vision AnnotateImageResponse object. Returns: - :obj:`Layout`: The reterived layout from the response. + :obj:`Layout`: The reterived layout from the response. """ # The 0th element contains all texts @@ -255,14 +251,9 @@ def gather_text_annotations(response): gathered_text = Layout() for i, text_comp in enumerate(doc): - points = _cvt_GCV_vertices_to_points( - text_comp.bounding_poly.vertices) + points = _cvt_GCV_vertices_to_points(text_comp.bounding_poly.vertices) gathered_text.append( - TextBlock( - block=Quadrilateral(points), - text=text_comp.description, - id=i - ) + TextBlock(block=Quadrilateral(points), text=text_comp.description, id=i) ) return gathered_text @@ -272,21 +263,23 @@ def gather_full_text_annotation(response, agg_level): """Convert the full_text_annotation from GCV output to an :obj:`Layout` object. Args: - response (:obj:`AnnotateImageResponse`): + response (:obj:`AnnotateImageResponse`): The returned Google Cloud Vision AnnotateImageResponse object. agg_level (:obj:`~GCVFeatureType`): The layout level to aggregate the text in full_text_annotation. Returns: - :obj:`Layout`: The reterived layout from the response. + :obj:`Layout`: The reterived layout from the response. """ - def iter_level(iter, - agg_level=None, - text_blocks=None, - texts=None, - cur_level=GCVFeatureType.PAGE): + def iter_level( + iter, + agg_level=None, + text_blocks=None, + texts=None, + cur_level=GCVFeatureType.PAGE, + ): for item in getattr(iter, cur_level.attr_name): if cur_level == agg_level: @@ -295,24 +288,28 @@ def iter_level(iter, # Go down levels to fetch the texts if cur_level == GCVFeatureType.SYMBOL: texts.append(item.text) - elif cur_level == GCVFeatureType.WORD and agg_level != GCVFeatureType.SYMBOL: + elif ( + cur_level == GCVFeatureType.WORD + and agg_level != GCVFeatureType.SYMBOL + ): chars = [] - iter_level(item, agg_level, text_blocks, - chars, cur_level.child_level) - texts.append(''.join(chars)) + iter_level( + item, agg_level, text_blocks, chars, cur_level.child_level + ) + texts.append("".join(chars)) else: - iter_level(item, agg_level, text_blocks, - texts, cur_level.child_level) + iter_level( + item, agg_level, text_blocks, texts, cur_level.child_level + ) if cur_level == agg_level: nonlocal element_id - points = _cvt_GCV_vertices_to_points( - item.bounding_box.vertices) + points = _cvt_GCV_vertices_to_points(item.bounding_box.vertices) text_block = TextBlock( block=Quadrilateral(points), - text=' '.join(texts), + text=" ".join(texts), score=item.confidence, - id=element_id + id=element_id, ) text_blocks.append(text_block) @@ -322,12 +319,7 @@ def iter_level(iter, doc = response.text_annotations[0] points = _cvt_GCV_vertices_to_points(doc.bounding_poly.vertices) - text_blocks = [ - TextBlock( - block=Quadrilateral(points), - text=doc.description - ) - ] + text_blocks = [TextBlock(block=Quadrilateral(points), text=doc.description)] else: doc = response.full_text_annotation @@ -338,17 +330,16 @@ def iter_level(iter, return Layout(text_blocks) def load_response(self, filename): - with open(filename, 'r') as f: + with open(filename, "r") as f: data = f.read() return self._json_format.Parse( - data, - self._vision.types.AnnotateImageResponse(), - ignore_unknown_fields=True) + data, self._vision.types.AnnotateImageResponse(), ignore_unknown_fields=True + ) def save_response(self, res, file_name): res = self._json_format.MessageToJson(res) - with open(file_name, 'w') as f: + with open(file_name, "w") as f: json_file = json.loads(res) json.dump(json_file, f) @@ -367,18 +358,18 @@ class TesseractFeatureType(BaseOCRElementType): @property def attr_name(self): name_cvt = { - TesseractFeatureType.PAGE: 'page_num', - TesseractFeatureType.BLOCK: 'block_num', - TesseractFeatureType.PARA: 'par_num', - TesseractFeatureType.LINE: 'line_num', - TesseractFeatureType.WORD: 'word_num' + TesseractFeatureType.PAGE: "page_num", + TesseractFeatureType.BLOCK: "block_num", + TesseractFeatureType.PARA: "par_num", + TesseractFeatureType.LINE: "line_num", + TesseractFeatureType.WORD: "word_num", } return name_cvt[self] @property def group_levels(self): - levels = ['page_num', 'block_num', 'par_num', 'line_num', 'word_num'] - return levels[:self+1] + levels = ["page_num", "block_num", "par_num", "line_num", "word_num"] + return levels[: self + 1] class TesseractAgent(BaseOCRAgent): @@ -387,29 +378,23 @@ class TesseractAgent(BaseOCRAgent): Detection APIs based on `PyTesseract `_. """ - DEPENDENCIES = ['pytesseract'] - MODULES = [ - { - "import_name": "_pytesseract", - "module_path": "pytesseract" - } - ] + DEPENDENCIES = ["pytesseract"] + MODULES = [{"import_name": "_pytesseract", "module_path": "pytesseract"}] - def __init__(self, languages='eng', **kwargs): - """Create a Tesseract OCR Agent. + def __init__(self, languages="eng", **kwargs): + """Create a Tesseract OCR Agent. Args: - languages (:obj:`list` or :obj:`str`, optional): - You can specify the language code(s) of the documents to detect to improve - accuracy. The supported language and their code can be found on - `its github repo `_. - It supports two formats: 1) you can pass in the languages code as a string + languages (:obj:`list` or :obj:`str`, optional): + You can specify the language code(s) of the documents to detect to improve + accuracy. The supported language and their code can be found on + `its github repo `_. + It supports two formats: 1) you can pass in the languages code as a string of format like `"eng+fra"`, or 2) you can pack them as a list of strings - `["eng", "fra"]`. + `["eng", "fra"]`. Defaults to 'eng'. """ - self.lang = languages if isinstance( - languages, str) else '+'.join(languages) + self.lang = languages if isinstance(languages, str) else "+".join(languages) self.configs = kwargs @classmethod @@ -420,33 +405,34 @@ def with_tesseract_executable(cls, tesseract_cmd_path, **kwargs): def _detect(self, img_content): res = {} - res['text'] = self._pytesseract.image_to_string( - img_content, lang=self.lang, **self.configs) + res["text"] = self._pytesseract.image_to_string( + img_content, lang=self.lang, **self.configs + ) _data = self._pytesseract.image_to_data( - img_content, lang=self.lang, **self.configs) - res['data'] = pd.read_csv(io.StringIO(_data), - quoting=csv.QUOTE_NONE, encoding='utf-8', - sep='\t') + img_content, lang=self.lang, **self.configs + ) + res["data"] = pd.read_csv( + io.StringIO(_data), quoting=csv.QUOTE_NONE, encoding="utf-8", sep="\t" + ) return res - def detect(self, image, - return_response=False, - return_only_text=True, - agg_output_level=None): + def detect( + self, image, return_response=False, return_only_text=True, agg_output_level=None + ): """Send the input image for OCR. Args: image (:obj:`np.ndarray` or :obj:`str`): The input image array or the name of the image file - return_response (:obj:`bool`, optional): - Whether directly return all output (string and boxes + return_response (:obj:`bool`, optional): + Whether directly return all output (string and boxes info) from Tesseract. Defaults to `False`. - return_only_text (:obj:`bool`, optional): - Whether return only the texts in the OCR results. + return_only_text (:obj:`bool`, optional): + Whether return only the texts in the OCR results. Defaults to `False`. - agg_output_level (:obj:`~TesseractFeatureType`, optional): - When set, aggregate the GCV output with respect to the + agg_output_level (:obj:`~TesseractFeatureType`, optional): + When set, aggregate the GCV output with respect to the specified aggregation level. Defaults to `None`. """ @@ -456,12 +442,12 @@ def detect(self, image, return res if return_only_text: - return res['text'] + return res["text"] if agg_output_level is not None: return self.gather_data(res, agg_output_level) - return res['text'] + return res["text"] @staticmethod def gather_data(response, agg_level): @@ -470,34 +456,51 @@ def gather_data(response, agg_level): in a given aggeragation level. """ assert isinstance( - agg_level, TesseractFeatureType), f"Invalid agg_level {agg_level}" - res = response['data'] - df = res[~res.text.isna()].\ - groupby(agg_level.group_levels).\ - apply(lambda gp: pd.Series([ - gp['left'].min(), - gp['top'].min(), - gp['width'].max(), - gp['height'].max(), - gp['conf'].mean(), - gp['text'].str.cat(sep=' ') - ])).\ - reset_index(drop=True).\ - reset_index().\ - rename(columns={0: 'x_1', 1: 'y_1', 2: 'w', 3: 'h', 4: 'score', 5: 'text', 'index': 'id'}).\ - assign(x_2=lambda x: x.x_1 + x.w, y_2=lambda x: x.y_1 + x.h, type=None).\ - drop(columns=['w', 'h']) - - return Layout.from_dataframe(df) + agg_level, TesseractFeatureType + ), f"Invalid agg_level {agg_level}" + res = response["data"] + df = ( + res[~res.text.isna()] + .groupby(agg_level.group_levels) + .apply( + lambda gp: pd.Series( + [ + gp["left"].min(), + gp["top"].min(), + gp["width"].max(), + gp["height"].max(), + gp["conf"].mean(), + gp["text"].str.cat(sep=" "), + ] + ) + ) + .reset_index(drop=True) + .reset_index() + .rename( + columns={ + 0: "x_1", + 1: "y_1", + 2: "w", + 3: "h", + 4: "score", + 5: "text", + "index": "id", + } + ) + .assign(x_2=lambda x: x.x_1 + x.w, y_2=lambda x: x.y_1 + x.h, block_type="rectangle") + .drop(columns=["w", "h"]) + ) + + return load_dataframe(df) @staticmethod def load_response(filename): - with open(filename, 'rb') as fp: + with open(filename, "rb") as fp: res = pickle.load(fp) return res @staticmethod def save_response(res, file_name): - with open(file_name, 'wb') as fp: + with open(file_name, "wb") as fp: pickle.dump(res, fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/src/layoutparser/visualization.py b/src/layoutparser/visualization.py index 111af4d..9a79f2d 100644 --- a/src/layoutparser/visualization.py +++ b/src/layoutparser/visualization.py @@ -10,26 +10,31 @@ # We need to fix this ugly hack some time in the future _lib_path = os.path.dirname(sys.modules[layoutparser.__package__].__file__) -_font_path = os.path.join(_lib_path, 'misc', 'NotoSerifCJKjp-Regular.otf') +_font_path = os.path.join(_lib_path, "misc", "NotoSerifCJKjp-Regular.otf") DEFAULT_BOX_WIDTH_RATIO = 0.005 -DEFAULT_OUTLINE_COLOR = 'red' +DEFAULT_OUTLINE_COLOR = "red" DEAFULT_COLOR_PALETTE = "#f6bd60-#f7ede2-#f5cac3-#84a59d-#f28482" # From https://coolors.co/f6bd60-f7ede2-f5cac3-84a59d-f28482 DEFAULT_FONT_PATH = _font_path DEFAULT_FONT_SIZE = 15 DEFAULT_FONT_OBJECT = ImageFont.truetype(DEFAULT_FONT_PATH, DEFAULT_FONT_SIZE) -DEFAULT_TEXT_COLOR = 'black' -DEFAULT_TEXT_BACKGROUND = 'white' +DEFAULT_TEXT_COLOR = "black" +DEFAULT_TEXT_BACKGROUND = "white" -__all__ = ['draw_box', 'draw_text'] +__all__ = ["draw_box", "draw_text"] -def _draw_vertical_text(text, image_font, - text_color, text_background_color, - character_spacing=2, space_width=1): - """Helper function to draw text vertically. +def _draw_vertical_text( + text, + image_font, + text_color, + text_background_color, + character_spacing=2, + space_width=1, +): + """Helper function to draw text vertically. Ref: https://github.com/Belval/TextRecognitionDataGenerator/blob/7f4c782c33993d2b6f712d01e86a2f342025f2df/trdg/computer_text_generator.py """ @@ -41,10 +46,8 @@ def _draw_vertical_text(text, image_font, text_width = max([image_font.getsize(c)[0] for c in text]) text_height = sum(char_heights) + character_spacing * len(text) - txt_img = Image.new("RGB", (text_width, text_height), - color=text_background_color) - txt_mask = Image.new("RGB", (text_width, text_height), - color=text_background_color) + txt_img = Image.new("RGB", (text_width, text_height), color=text_background_color) + txt_mask = Image.new("RGB", (text_width, text_height), color=text_background_color) txt_img_draw = ImageDraw.Draw(txt_img) txt_mask_draw = ImageDraw.Draw(txt_mask) @@ -70,21 +73,26 @@ def _create_font_object(font_size=None, font_path=None): return DEFAULT_FONT_OBJECT else: return ImageFont.truetype( - font_path or DEFAULT_FONT_PATH, - font_size or DEFAULT_FONT_SIZE + font_path or DEFAULT_FONT_PATH, font_size or DEFAULT_FONT_SIZE ) def _create_new_canvas(canvas, arrangement, text_background_color): - if arrangement == 'lr': - new_canvas = Image.new('RGB', (canvas.width*2, canvas.height), - color=text_background_color or DEFAULT_TEXT_BACKGROUND) + if arrangement == "lr": + new_canvas = Image.new( + "RGB", + (canvas.width * 2, canvas.height), + color=text_background_color or DEFAULT_TEXT_BACKGROUND, + ) new_canvas.paste(canvas, (canvas.width, 0)) - elif arrangement == 'ud': - new_canvas = Image.new('RGB', (canvas.width, canvas.height*2), - color=text_background_color or DEFAULT_TEXT_BACKGROUND) + elif arrangement == "ud": + new_canvas = Image.new( + "RGB", + (canvas.width, canvas.height * 2), + color=text_background_color or DEFAULT_TEXT_BACKGROUND, + ) new_canvas.paste(canvas, (0, canvas.height)) else: @@ -94,7 +102,10 @@ def _create_new_canvas(canvas, arrangement, text_background_color): def _create_color_palette(types): - return {type: color for type, color in zip(types, cycle(DEAFULT_COLOR_PALETTE.split('-')))} + return { + type: color + for type, color in zip(types, cycle(DEAFULT_COLOR_PALETTE.split("-"))) + } def image_loader(func): @@ -102,8 +113,8 @@ def image_loader(func): def wrap(canvas, layout, *args, **kwargs): if isinstance(canvas, Image.Image): - if canvas.mode != 'RGB': - canvas = canvas.convert('RGB') + if canvas.mode != "RGB": + canvas = canvas.convert("RGB") canvas = canvas.copy() elif isinstance(canvas, np.ndarray): canvas = Image.fromarray(canvas) @@ -114,54 +125,57 @@ def wrap(canvas, layout, *args, **kwargs): @image_loader -def draw_box(canvas, layout, - box_width=None, - color_map=None, - show_element_id=False, - id_font_size=None, - id_font_path=None, - id_text_color=None, - id_text_background_color=None): - """Draw the layout region on the input canvas(image). +def draw_box( + canvas, + layout, + box_width=None, + color_map=None, + show_element_id=False, + id_font_size=None, + id_font_path=None, + id_text_color=None, + id_text_background_color=None, +): + """Draw the layout region on the input canvas(image). Args: - canvas (:obj:`~np.ndarray` or :obj:`~PIL.Image.Image`): - The canvas to draw the layout boxes. - layout (:obj:`Layout` or :obj:`list`): - The layout of the canvas to show. - box_width (:obj:`int`, optional): + canvas (:obj:`~np.ndarray` or :obj:`~PIL.Image.Image`): + The canvas to draw the layout boxes. + layout (:obj:`Layout` or :obj:`list`): + The layout of the canvas to show. + box_width (:obj:`int`, optional): Set to change the width of the drawn layout box boundary. - Defaults to None, when the boundary is automatically - calculated as the the :const:`DEFAULT_BOX_WIDTH_RATIO` - * the maximum of (height, width) of the canvas. - color_map (dict, optional): - A map from `block.type` to the colors, e.g., `{1: 'red'}`. - You can set it to `{}` to use only the + Defaults to None, when the boundary is automatically + calculated as the the :const:`DEFAULT_BOX_WIDTH_RATIO` + * the maximum of (height, width) of the canvas. + color_map (dict, optional): + A map from `block.type` to the colors, e.g., `{1: 'red'}`. + You can set it to `{}` to use only the :const:`DEFAULT_OUTLINE_COLOR` for the outlines. - Defaults to None, when a color palette is is automatically + Defaults to None, when a color palette is is automatically created based on the input layout. - show_element_id (bool, optional): - Whether to display `block.id` on the top-left corner of - the block. + show_element_id (bool, optional): + Whether to display `block.id` on the top-left corner of + the block. Defaults to False. - id_font_size (int, optional): + id_font_size (int, optional): Set to change the font size used for drawing `block.id`. - Defaults to None, when the size is set to - :const:`DEFAULT_FONT_SIZE`. - id_font_path (:obj:`str`, optional): + Defaults to None, when the size is set to + :const:`DEFAULT_FONT_SIZE`. + id_font_path (:obj:`str`, optional): Set to change the font used for drawing `block.id`. - Defaults to None, when the :const:`DEFAULT_FONT_OBJECT` is used. - id_text_color (:obj:`str`, optional): + Defaults to None, when the :const:`DEFAULT_FONT_OBJECT` is used. + id_text_color (:obj:`str`, optional): Set to change the text color used for drawing `block.id`. - Defaults to None, when the color is set to + Defaults to None, when the color is set to :const:`DEFAULT_TEXT_COLOR`. - id_text_background_color (:obj:`str`, optional): + id_text_background_color (:obj:`str`, optional): Set to change the text region background used for drawing `block.id`. - Defaults to None, when the color is set to + Defaults to None, when the color is set to :const:`DEFAULT_TEXT_BACKGROUND`. Returns: - :obj:`PIL.Image.Image`: - A Image object containing the `layout` draw upon the input `canvas`. + :obj:`PIL.Image.Image`: + A Image object containing the `layout` draw upon the input `canvas`. """ draw = ImageDraw.Draw(canvas) @@ -173,7 +187,7 @@ def draw_box(canvas, layout, font_obj = _create_font_object(id_font_size, id_font_path) if color_map is None: - all_types = set([b.type for b in layout if hasattr(b, 'type')]) + all_types = set([b.type for b in layout if hasattr(b, "type")]) color_map = _create_color_palette(all_types) for idx, ele in enumerate(layout): @@ -181,106 +195,114 @@ def draw_box(canvas, layout, if isinstance(ele, Interval): ele = ele.put_on_canvas(canvas) - outline_color = DEFAULT_OUTLINE_COLOR if not isinstance(ele, TextBlock) \ + outline_color = ( + DEFAULT_OUTLINE_COLOR + if not isinstance(ele, TextBlock) else color_map.get(ele.type, DEFAULT_OUTLINE_COLOR) + ) if not isinstance(ele, Quadrilateral): - draw.rectangle(ele.coordinates, width=box_width, - outline=outline_color) + draw.rectangle(ele.coordinates, width=box_width, outline=outline_color) else: p = ele.points.ravel().tolist() - draw.line(p+p[:2], width=box_width, - fill=outline_color) + draw.line(p + p[:2], width=box_width, fill=outline_color) if show_element_id: ele_id = ele.id or idx start_x, start_y = ele.coordinates[:2] - text_w, text_h = font_obj.getsize(f'{ele_id}') + text_w, text_h = font_obj.getsize(f"{ele_id}") # Add a small background for the text - draw.rectangle((start_x, start_y, start_x + text_w, start_y + text_h), - fill=id_text_background_color or DEFAULT_TEXT_BACKGROUND) + draw.rectangle( + (start_x, start_y, start_x + text_w, start_y + text_h), + fill=id_text_background_color or DEFAULT_TEXT_BACKGROUND, + ) # Draw the ids - draw.text((start_x, start_y), f'{ele_id}', - fill=id_text_color or DEFAULT_TEXT_COLOR, - font=font_obj) + draw.text( + (start_x, start_y), + f"{ele_id}", + fill=id_text_color or DEFAULT_TEXT_COLOR, + font=font_obj, + ) return canvas @image_loader -def draw_text(canvas, layout, - arrangement='lr', - font_size=None, - font_path=None, - text_color=None, - text_background_color=None, - vertical_text=False, - with_box_on_text=False, - text_box_width=None, - text_box_color=None, - with_layout=False, - **kwargs - ): +def draw_text( + canvas, + layout, + arrangement="lr", + font_size=None, + font_path=None, + text_color=None, + text_background_color=None, + vertical_text=False, + with_box_on_text=False, + text_box_width=None, + text_box_color=None, + with_layout=False, + **kwargs, +): """Draw the (detected) text in the `layout` according to - their coordinates next to the input `canvas` (image) for better comparison. + their coordinates next to the input `canvas` (image) for better comparison. Args: - canvas (:obj:`~np.ndarray` or :obj:`~PIL.Image.Image`): - The canvas to draw the layout boxes. - layout (:obj:`Layout` or :obj:`list`): - The layout of the canvas to show. - arrangement (`{'lr', 'ud'}`, optional): + canvas (:obj:`~np.ndarray` or :obj:`~PIL.Image.Image`): + The canvas to draw the layout boxes. + layout (:obj:`Layout` or :obj:`list`): + The layout of the canvas to show. + arrangement (`{'lr', 'ud'}`, optional): The arrangement of the drawn text canvas and the original - image canvas: - * `lr` - left and right + image canvas: + * `lr` - left and right * `ud` - up and down Defaults to 'lr'. font_size (:obj:`str`, optional): - Set to change the size of the font used for + Set to change the size of the font used for drawing `block.text`. - Defaults to None, when the size is set to - :const:`DEFAULT_FONT_SIZE`. - font_path (:obj:`str`, optional): + Defaults to None, when the size is set to + :const:`DEFAULT_FONT_SIZE`. + font_path (:obj:`str`, optional): Set to change the font used for drawing `block.text`. - Defaults to None, when the :const:`DEFAULT_FONT_OBJECT` is used. - text_color ([type], optional): + Defaults to None, when the :const:`DEFAULT_FONT_OBJECT` is used. + text_color ([type], optional): Set to change the text color used for drawing `block.text`. - Defaults to None, when the color is set to + Defaults to None, when the color is set to :const:`DEFAULT_TEXT_COLOR`. - text_background_color ([type], optional): - Set to change the text region background used for drawing + text_background_color ([type], optional): + Set to change the text region background used for drawing `block.text`. - Defaults to None, when the color is set to + Defaults to None, when the color is set to :const:`DEFAULT_TEXT_BACKGROUND`. - vertical_text (bool, optional): - Whether the text in a block should be drawn vertically. + vertical_text (bool, optional): + Whether the text in a block should be drawn vertically. Defaults to False. - with_box_on_text (bool, optional): - Whether to draw the layout box boundary of a text region + with_box_on_text (bool, optional): + Whether to draw the layout box boundary of a text region on the text canvas. Defaults to False. text_box_width (:obj:`int`, optional): Set to change the width of the drawn layout box boundary. - Defaults to None, when the boundary is automatically - calculated as the the :const:`DEFAULT_BOX_WIDTH_RATIO` - * the maximum of (height, width) of the canvas. + Defaults to None, when the boundary is automatically + calculated as the the :const:`DEFAULT_BOX_WIDTH_RATIO` + * the maximum of (height, width) of the canvas. text_box_color (:obj:`int`, optional): Set to change the color of the drawn layout box boundary. - Defaults to None, when the color is set to + Defaults to None, when the color is set to :const:`DEFAULT_OUTLINE_COLOR`. - with_layout (bool, optional): - Whether to draw the layout boxes on the input (image) canvas. + with_layout (bool, optional): + Whether to draw the layout boxes on the input (image) canvas. Defaults to False. - When set to true, you can pass in the arguments in + When set to true, you can pass in the arguments in :obj:`draw_box` to change the style of the drawn layout boxes. Returns: - :obj:`PIL.Image.Image`: + :obj:`PIL.Image.Image`: A Image object containing the drawn text from `layout`. """ if with_box_on_text: @@ -301,29 +323,30 @@ def draw_text(canvas, layout, for idx, ele in enumerate(layout): if with_box_on_text: - p = ele.pad(right=text_box_width, - bottom=text_box_width).points.ravel().tolist() + p = ( + ele.pad(right=text_box_width, bottom=text_box_width) + .points.ravel() + .tolist() + ) - draw.line(p+p[:2], - width=text_box_width, - fill=text_box_color) + draw.line(p + p[:2], width=text_box_width, fill=text_box_color) - if not hasattr(ele, 'text') or ele.text == '': + if not hasattr(ele, "text") or ele.text == "": continue (start_x, start_y) = ele.coordinates[:2] if not vertical_text: - draw.text((start_x, start_y), ele.text, - font=font_obj, - fill=text_color) + draw.text((start_x, start_y), ele.text, font=font_obj, fill=text_color) else: text_segment = _draw_vertical_text( - ele.text, font_obj, text_color, text_background_color) + ele.text, font_obj, text_color, text_background_color + ) if with_box_on_text: # Avoid cover the box regions canvas.paste( - text_segment, (start_x+text_box_width, start_y+text_box_width)) + text_segment, (start_x + text_box_width, start_y + text_box_width) + ) else: canvas.paste(text_segment, (start_x, start_y)) diff --git a/tests/fixtures/io/generate_test_jsons.py b/tests/fixtures/io/generate_test_jsons.py new file mode 100644 index 0000000..e0c5183 --- /dev/null +++ b/tests/fixtures/io/generate_test_jsons.py @@ -0,0 +1,35 @@ +import json +import numpy as np +from layoutparser.elements import Interval, Rectangle, Quadrilateral, TextBlock, Layout + +if __name__ == "__main__": + + i = Interval(1, 2, "y", canvas_height=5) + r = Rectangle(1, 2, 3, 4) + q = Quadrilateral(np.arange(8).reshape(4, 2), 200, 400) + l = Layout([i, r, q], page_data={"width": 200, "height": 200}) + + with open("interval.json", "w") as fp: + json.dump(i.to_dict(), fp) + with open("rectangle.json", "w") as fp: + json.dump(r.to_dict(), fp) + with open("quadrilateral.json", "w") as fp: + json.dump(q.to_dict(), fp) + with open("layout.json", "w") as fp: + json.dump(l.to_dict(), fp) + l.to_dataframe().to_csv("layout.csv", index=None) + + i2 = TextBlock(i, "") + r2 = TextBlock(r, id=24) + q2 = TextBlock(q, text="test", parent=45) + l2 = Layout([i2, r2, q2]) + + with open("interval_textblock.json", "w") as fp: + json.dump(i2.to_dict(), fp) + with open("rectangle_textblock.json", "w") as fp: + json.dump(r2.to_dict(), fp) + with open("quadrilateral_textblock.json", "w") as fp: + json.dump(q2.to_dict(), fp) + with open("layout_textblock.json", "w") as fp: + json.dump(l2.to_dict(), fp) + l2.to_dataframe().to_csv("layout_textblock.csv", index=None) \ No newline at end of file diff --git a/tests/fixtures/io/interval.json b/tests/fixtures/io/interval.json new file mode 100644 index 0000000..ec9b27e --- /dev/null +++ b/tests/fixtures/io/interval.json @@ -0,0 +1 @@ +{"start": 1, "end": 2, "axis": "y", "canvas_height": 5, "canvas_width": 0, "block_type": "interval"} \ No newline at end of file diff --git a/tests/fixtures/io/interval_textblock.json b/tests/fixtures/io/interval_textblock.json new file mode 100644 index 0000000..92bc529 --- /dev/null +++ b/tests/fixtures/io/interval_textblock.json @@ -0,0 +1 @@ +{"start": 1, "end": 2, "axis": "y", "canvas_height": 5, "canvas_width": 0, "block_type": "interval", "text": ""} \ No newline at end of file diff --git a/tests/fixtures/io/layout.csv b/tests/fixtures/io/layout.csv new file mode 100644 index 0000000..d6992b9 --- /dev/null +++ b/tests/fixtures/io/layout.csv @@ -0,0 +1,4 @@ +start,end,axis,canvas_height,canvas_width,block_type,x_1,y_1,x_2,y_2,points,height,width +1.0,2.0,y,5.0,0.0,interval,,,,,,, +,,,,,rectangle,1.0,2.0,3.0,4.0,,, +,,,,,quadrilateral,,,,,"[0, 1, 2, 3, 4, 5, 6, 7]",200.0,400.0 diff --git a/tests/fixtures/io/layout.json b/tests/fixtures/io/layout.json new file mode 100644 index 0000000..bd04ba7 --- /dev/null +++ b/tests/fixtures/io/layout.json @@ -0,0 +1 @@ +{"page_data": {"width": 200, "height": 200}, "blocks": [{"start": 1, "end": 2, "axis": "y", "canvas_height": 5, "canvas_width": 0, "block_type": "interval"}, {"x_1": 1, "y_1": 2, "x_2": 3, "y_2": 4, "block_type": "rectangle"}, {"points": [0, 1, 2, 3, 4, 5, 6, 7], "height": 200, "width": 400, "block_type": "quadrilateral"}]} \ No newline at end of file diff --git a/tests/fixtures/io/layout_textblock.csv b/tests/fixtures/io/layout_textblock.csv new file mode 100644 index 0000000..708a6af --- /dev/null +++ b/tests/fixtures/io/layout_textblock.csv @@ -0,0 +1,4 @@ +start,end,axis,canvas_height,canvas_width,block_type,text,x_1,y_1,x_2,y_2,id,points,height,width,parent +1.0,2.0,y,5.0,0.0,interval,,,,,,,,,, +,,,,,rectangle,,1.0,2.0,3.0,4.0,24.0,,,, +,,,,,quadrilateral,test,,,,,,"[0, 1, 2, 3, 4, 5, 6, 7]",200.0,400.0,45.0 diff --git a/tests/fixtures/io/layout_textblock.json b/tests/fixtures/io/layout_textblock.json new file mode 100644 index 0000000..9b23079 --- /dev/null +++ b/tests/fixtures/io/layout_textblock.json @@ -0,0 +1 @@ +{"page_data": {}, "blocks": [{"start": 1, "end": 2, "axis": "y", "canvas_height": 5, "canvas_width": 0, "block_type": "interval", "text": ""}, {"x_1": 1, "y_1": 2, "x_2": 3, "y_2": 4, "block_type": "rectangle", "id": 24}, {"points": [0, 1, 2, 3, 4, 5, 6, 7], "height": 200, "width": 400, "block_type": "quadrilateral", "text": "test", "parent": 45}]} \ No newline at end of file diff --git a/tests/fixtures/io/quadrilateral.json b/tests/fixtures/io/quadrilateral.json new file mode 100644 index 0000000..4490df5 --- /dev/null +++ b/tests/fixtures/io/quadrilateral.json @@ -0,0 +1 @@ +{"points": [0, 1, 2, 3, 4, 5, 6, 7], "height": 200, "width": 400, "block_type": "quadrilateral"} \ No newline at end of file diff --git a/tests/fixtures/io/quadrilateral_textblock.json b/tests/fixtures/io/quadrilateral_textblock.json new file mode 100644 index 0000000..148b71a --- /dev/null +++ b/tests/fixtures/io/quadrilateral_textblock.json @@ -0,0 +1 @@ +{"points": [0, 1, 2, 3, 4, 5, 6, 7], "height": 200, "width": 400, "block_type": "quadrilateral", "text": "test", "parent": 45} \ No newline at end of file diff --git a/tests/fixtures/io/rectangle.json b/tests/fixtures/io/rectangle.json new file mode 100644 index 0000000..ec57aeb --- /dev/null +++ b/tests/fixtures/io/rectangle.json @@ -0,0 +1 @@ +{"x_1": 1, "y_1": 2, "x_2": 3, "y_2": 4, "block_type": "rectangle"} \ No newline at end of file diff --git a/tests/fixtures/io/rectangle_textblock.json b/tests/fixtures/io/rectangle_textblock.json new file mode 100644 index 0000000..bf39ad3 --- /dev/null +++ b/tests/fixtures/io/rectangle_textblock.json @@ -0,0 +1 @@ +{"x_1": 1, "y_1": 2, "x_2": 3, "y_2": 4, "block_type": "rectangle", "id": 24} \ No newline at end of file diff --git a/tests/source/config.yml b/tests/fixtures/model/config.yml similarity index 100% rename from tests/source/config.yml rename to tests/fixtures/model/config.yml diff --git a/tests/source/test_gcv_image.jpg b/tests/fixtures/model/test_model_image.jpg similarity index 100% rename from tests/source/test_gcv_image.jpg rename to tests/fixtures/model/test_model_image.jpg diff --git a/tests/fixtures/ocr/test_gcv_image.jpg b/tests/fixtures/ocr/test_gcv_image.jpg new file mode 100644 index 0000000..942775d Binary files /dev/null and b/tests/fixtures/ocr/test_gcv_image.jpg differ diff --git a/tests/source/test_gcv_response.json b/tests/fixtures/ocr/test_gcv_response.json similarity index 100% rename from tests/source/test_gcv_response.json rename to tests/fixtures/ocr/test_gcv_response.json diff --git a/tests/source/test_tesseract_response.pickle b/tests/fixtures/ocr/test_tesseract_response.pickle similarity index 100% rename from tests/source/test_tesseract_response.pickle rename to tests/fixtures/ocr/test_tesseract_response.pickle diff --git a/tests/test_elements.py b/tests/test_elements.py index 4e4e2dd..27faca7 100644 --- a/tests/test_elements.py +++ b/tests/test_elements.py @@ -1,134 +1,158 @@ -from layoutparser.elements import Interval, Rectangle, Quadrilateral, TextBlock, Layout +import pytest import numpy as np import pandas as pd +from layoutparser.elements import Interval, Rectangle, Quadrilateral, TextBlock, Layout + + def test_interval(): - - i = Interval(1, 2, axis='y', canvas_height=30, canvas_width=400) + + i = Interval(1, 2, axis="y", canvas_height=30, canvas_width=400) i.to_rectangle() i.to_quadrilateral() - assert i.shift(1) == Interval(2, 3, axis='y', canvas_height=30, canvas_width=400) - assert i.area == 1*400 - - i = Interval(1, 2, axis='x') - assert i.shift([1,2]) == Interval(2, 3, axis='x') - assert i.scale([2,1]) == Interval(2, 4, axis='x') - assert i.pad(left=10, right=20) == Interval(0, 22) # Test the safe_mode - assert i.pad(left=10, right=20, safe_mode=False) == Interval(-9, 22) + assert i.shift(1) == Interval(2, 3, axis="y", canvas_height=30, canvas_width=400) + assert i.area == 1 * 400 + + i = Interval(1, 2, axis="x") + assert i.shift([1, 2]) == Interval(2, 3, axis="x") + assert i.scale([2, 1]) == Interval(2, 4, axis="x") + assert i.pad(left=10, right=20) == Interval(0, 22, axis="x") # Test the safe_mode + assert i.pad(left=10, right=20, safe_mode=False) == Interval(-9, 22, axis="x") assert i.area == 0 - - img = np.random.randint(12, 24, (40,40)) + + img = np.random.randint(12, 24, (40, 40)) img[:, 10:20] = 0 - i = Interval(5, 11, axis='x') + i = Interval(5, 11, axis="x") assert np.unique(i.crop_image(img)[:, -1]) == np.array([0]) - + + def test_rectangle(): - + r = Rectangle(1, 2, 3, 4) - r.to_interval() + r.to_interval(axis="x") r.to_quadrilateral() assert r.pad(left=1, right=5, top=2, bottom=4) == Rectangle(0, 0, 8, 8) - assert r.shift([1,2]) == Rectangle(2, 4, 4, 6) + assert r.shift([1, 2]) == Rectangle(2, 4, 4, 6) assert r.shift(1) == Rectangle(2, 3, 4, 5) - assert r.scale([3,2]) == Rectangle(3, 4, 9, 8) + assert r.scale([3, 2]) == Rectangle(3, 4, 9, 8) assert r.scale(2) == Rectangle(2, 4, 6, 8) assert r.area == 4 - - img = np.random.randint(12, 24, (40,40)) + + img = np.random.randint(12, 24, (40, 40)) r.crop_image(img).shape == (2, 2) - + + def test_quadrilateral(): - - points = np.array([[2, 2], [6, 2], [6,7], [2,6]]) + + points = np.array([[2, 2], [6, 2], [6, 7], [2, 6]]) q = Quadrilateral(points) q.to_interval() q.to_rectangle() assert q.shift(1) == Quadrilateral(points + 1) - assert q.shift([1,2]) == Quadrilateral(points + np.array([1,2])) + assert q.shift([1, 2]) == Quadrilateral(points + np.array([1, 2])) assert q.scale(2) == Quadrilateral(points * 2) - assert q.scale([3,2]) == Quadrilateral(points * np.array([3,2])) - assert q.pad(left=1, top=2, bottom=4) == Quadrilateral(np.array([[1, 0], [6, 0], [6, 11], [1, 10]])) - assert (q.mapped_rectangle_points == np.array([[0,0],[4,0],[4,5],[0,5]])).all() + assert q.scale([3, 2]) == Quadrilateral(points * np.array([3, 2])) + assert q.pad(left=1, top=2, bottom=4) == Quadrilateral( + np.array([[1, 0], [6, 0], [6, 11], [1, 10]]) + ) + assert ( + q.mapped_rectangle_points == np.array([[0, 0], [4, 0], [4, 5], [0, 5]]) + ).all() - points = np.array([[2, 2], [6, 2], [6,5], [2,5]]) + points = np.array([[2, 2], [6, 2], [6, 5], [2, 5]]) q = Quadrilateral(points) - img = np.random.randint(2, 24, (30, 20)).astype('uint8') + img = np.random.randint(2, 24, (30, 20)).astype("uint8") img[2:5, 2:6] = 0 assert np.unique(q.crop_image(img)) == np.array([0]) - + q = Quadrilateral(np.array([[-2, 0], [0, 2], [2, 0], [0, -2]])) - assert q.area == 8. - + assert q.area == 8.0 + + q = Quadrilateral([1, 2, 3, 4, 5, 6, 7, 8]) + assert (q.points == np.array([[1, 2], [3, 4], [5, 6], [7, 8]])).all() + + with pytest.raises(ValueError): + Quadrilateral([1, 2, 3, 4, 5, 6, 7]) # Incompatible list length + + with pytest.raises(ValueError): + Quadrilateral(np.array([[2, 2], [6, 2], [6, 5]])) # Incompatible ndarray shape + + def test_interval_relations(): - - i = Interval(4, 5, axis='y') + + i = Interval(4, 5, axis="y") r = Rectangle(3, 3, 5, 6) - q = Quadrilateral(np.array([[2,2],[6,2],[6,7],[2,5]])) - + q = Quadrilateral(np.array([[2, 2], [6, 2], [6, 7], [2, 5]])) + assert i.is_in(i) assert i.is_in(r) assert i.is_in(q) - + # convert to absolute then convert back to relative assert i.condition_on(i).relative_to(i) == i assert i.condition_on(r).relative_to(r) == i.put_on_canvas(r).to_rectangle() assert i.condition_on(q).relative_to(q) == i.put_on_canvas(q).to_quadrilateral() - + # convert to relative then convert back to absolute assert i.relative_to(i).condition_on(i) == i assert i.relative_to(r).condition_on(r) == i.put_on_canvas(r).to_rectangle() assert i.relative_to(q).condition_on(q) == i.put_on_canvas(q).to_quadrilateral() - + + def test_rectangle_relations(): - - i = Interval(4, 5, axis='y') - q = Quadrilateral(np.array([[2,2],[6,2],[6,7],[2,5]])) + + i = Interval(4, 5, axis="y") + q = Quadrilateral(np.array([[2, 2], [6, 2], [6, 7], [2, 5]])) r = Rectangle(3, 3, 5, 6) - + assert not r.is_in(q) - assert r.is_in(q, soft_margin={"bottom":1}) + assert r.is_in(q, soft_margin={"bottom": 1}) assert r.is_in(q.to_rectangle()) assert r.is_in(q.to_interval()) - + # convert to absolute then convert back to relative assert r.condition_on(i).relative_to(i) == r assert r.condition_on(r).relative_to(r) == r assert r.condition_on(q).relative_to(q) == r.to_quadrilateral() - + # convert to relative then convert back to absolute assert r.relative_to(i).condition_on(i) == r - assert r.relative_to(r).condition_on(r) == r + assert r.relative_to(r).condition_on(r) == r assert r.relative_to(q).condition_on(q) == r.to_quadrilateral() - + + def test_quadrilateral_relations(): - - i = Interval(4, 5, axis='y') - q = Quadrilateral(np.array([[2,2],[6,2],[6,7],[2,5]])) + + i = Interval(4, 5, axis="y") + q = Quadrilateral(np.array([[2, 2], [6, 2], [6, 7], [2, 5]])) r = Rectangle(3, 3, 5, 6) - + assert not q.is_in(r) - assert q.is_in(i, soft_margin={"top":2, "bottom":2}) - assert q.is_in(r, soft_margin={"left":1, "top":1, "right":1,"bottom":1}) + assert q.is_in(i, soft_margin={"top": 2, "bottom": 2}) + assert q.is_in(r, soft_margin={"left": 1, "top": 1, "right": 1, "bottom": 1}) assert q.is_in(q) - + # convert to absolute then convert back to relative assert q.condition_on(i).relative_to(i) == q assert q.condition_on(r).relative_to(r) == q assert q.condition_on(q).relative_to(q) == q - + # convert to relative then convert back to absolute assert q.relative_to(i).condition_on(i) == q assert q.relative_to(r).condition_on(r) == q assert q.relative_to(q).condition_on(q) == q + def test_textblock(): - - i = Interval(4, 5, axis='y') - q = Quadrilateral(np.array([[2,2],[6,2],[6,7],[2,5]])) + + i = Interval(4, 5, axis="y") + q = Quadrilateral(np.array([[2, 2], [6, 2], [6, 7], [2, 5]])) r = Rectangle(3, 3, 5, 6) - + t = TextBlock(i, id=1, type=2, text="12") - assert t.relative_to(q).condition_on(q).block == i.put_on_canvas(q).to_quadrilateral() + assert ( + t.relative_to(q).condition_on(q).block == i.put_on_canvas(q).to_quadrilateral() + ) t.area t = TextBlock(r, id=1, type=2, parent="a") assert t.relative_to(i).condition_on(i).block == r @@ -136,160 +160,119 @@ def test_textblock(): t = TextBlock(q, id=1, type=2, parent="a") assert t.relative_to(r).condition_on(r).block == q t.area - + # Ensure the operations did not change the object itself assert t == TextBlock(q, id=1, type=2, parent="a") t1 = TextBlock(q, id=1, type=2, parent="a") t2 = TextBlock(i, id=1, type=2, text="12") t1.relative_to(t2) assert t2.is_in(t1) - + t = TextBlock(q, score=0.2) + def test_layout(): - i = Interval(4, 5, axis='y') - q = Quadrilateral(np.array([[2,2],[6,2],[6,7],[2,5]])) + i = Interval(4, 5, axis="y") + q = Quadrilateral(np.array([[2, 2], [6, 2], [6, 7], [2, 5]])) r = Rectangle(3, 3, 5, 6) t = TextBlock(i, id=1, type=2, text="12") - + l = Layout([i, q, r]) l.get_texts() l.condition_on(i) l.relative_to(q) l.filter_by(t) l.is_in(r) - - l = Layout([ - TextBlock(i, id=1, type=2, text="12"), - TextBlock(r, id=1, type=2, parent="a"), - TextBlock(q, id=1, type=2, next="a") - ]) - l.get_texts() - l.get_info('next') - l.condition_on(i) - l.relative_to(q) - l.filter_by(t) - l.is_in(r) - - l.scale(4) - l.shift(4) - l.pad(left=2) - -def test_df(): - - df = pd.DataFrame( - columns=\ - ["_identifier", "x_1", "y_1", "x_2", "y_2", "height", "width", "p11", "p12", "p21", "p22", "p31", "p32", "p41", "p42"], - data=[ - ['_interval', None, 10, None, 12, 240, None, None, None, None, None, None, None, None, None ], - ['_interval', 12, None, 24, None, 120, 50, None, None, None, None, None, None, None, None ], - ['_interval', 0, 10, 0, 12, 120, 50, None, None, None, None, None, None, None, None ], # for fillna with 0 - ['_rectangle', 12, 32, 24, 55, None, None, None, None, None, None, None, None, None, None ], - ['_rectangle', 12, 32, 24, 55, 0, 0, None, None, None, None, None, None, None, None ], - ['_quadrilateral',None,None, None, None, None, None, 1, 2, 3, 2, 3, 6, 1, 4 ], - ['_quadrilateral',None,None, None, None, 0, 0, 1, 2, 3, 2, 3, 6, 1, 4 ], - ['_quadrilateral',0, 0, 0, 0, 0, 0, 1, 2, 3, 2, 3, 6, 1, 4 ], - ] - ) - - layout = Layout.from_dataframe(df) - assert layout[0] == Interval(10, 12, 'y', canvas_height=240) - assert layout[2] == Interval(10, 12, 'y', canvas_height=120, canvas_width=50) - - assert layout[3] == Rectangle(x_1=12, y_1=32, x_2=24, y_2=55) - assert layout[3] == layout[4] - - assert not layout[5] == Quadrilateral(np.arange(8).reshape(4,-1)) - assert layout[6] == Quadrilateral(np.array([[1,2], [3,2], [3,6], [1,4]])) - - df = pd.DataFrame( - columns=\ - ["_identifier", "x_1", "y_1", "x_2", "y_2", "height", "width", "p11", "p12", "p21", "p22", "p31", "p32", "p41", "p42", 'next', 'parent'], - data=[ - ['_interval', None, 10, None, 12, 240, None, None, None, None, None, None, None, None, None, None, None ], - ['_interval', 12, None, 24, None, 120, 50, None, None, None, None, None, None, None, None, None, None ], - ['_interval', 0, 10, 0, 12, 120, 50, None, None, None, None, None, None, None, None, None, 24 ], - # for fillna with 0 - ['_rectangle', 12, 32, 24, 55, None, None, None, None, None, None, None, None, None, None, None, None ], - ['_rectangle', 12, 32, 24, 55, 0, 0, None, None, None, None, None, None, None, None, 12, None ], - ['_quadrilateral',None,None, None, None, None, None, 1, 2, 3, 2, 3, 6, 1, 4, None, None ], - ['_quadrilateral',None,None, None, None, 0, 0, 1, 2, 3, 2, 3, 6, 1, 4, None, None ], - ['_textblock', None,None, None, None, 0, 0, 1, 2, 3, 2, 3, 6, 1, 4, None, 28 ], - ] - ) - - layout = Layout.from_dataframe(df) - assert layout[0] == Interval(10, 12, 'y', canvas_height=240) - assert layout[2] == Interval(10, 12, 'y', canvas_height=120, canvas_width=50) - - assert layout[3] == Rectangle(x_1=12, y_1=32, x_2=24, y_2=55) - assert layout[3] == layout[4] - assert layout[6] == Quadrilateral(np.array([[1,2], [3,2], [3,6], [1,4]])) - - assert layout[-1].block == Quadrilateral(np.array([[1,2], [3,2], [3,6], [1,4]])) - assert layout[-1].parent == 28 - - - df = pd.DataFrame( - columns=\ - ["x_1", "y_1", "x_2", "y_2", "height", "width", "p11", "p12", "p21", "p22", "p31", "p32", "p41", "p42", 'next', 'parent'], - data=[ - [None, 10, None, 12, 240, None, None, None, None, None, None, None, None, None, None, None ], - [12, None, 24, None, 120, 50, None, None, None, None, None, None, None, None, None, None ], - [0, 10, 0, 12, 120, 50, None, None, None, None, None, None, None, None, None, 24 ], - # for fillna with 0 - [12, 32, 24, 55, None, None, None, None, None, None, None, None, None, None, None, None ], - [12, 32, 24, 55, None, None, None, None, None, None, None, None, None, None, 12, None ], - [None, None, None, None, None, None, 1, 2, 3, 2, 3, 6, 1, 4, None, None ], - [None, None, None, None, 0, 0, 1, 2, 3, 2, 3, 6, 1, 4, None, None ], - [None, None, None, None, 0, 0, 1, 2, 3, 2, 3, 6, 1, 4, None, 28 ], - ] - ) - - layout = Layout.from_dataframe(df) - assert layout[0].block == Interval(10, 12, 'y', canvas_height=240) - assert layout[2].block == Interval(10, 12, 'y', canvas_height=120, canvas_width=50) - - assert layout[3].block == Rectangle(x_1=12, y_1=32, x_2=24, y_2=55) - assert not layout[3] == layout[4] - assert layout[6].block == Quadrilateral(np.array([[1,2], [3,2], [3,6], [1,4]])) - - assert layout[-1].block == Quadrilateral(np.array([[1,2], [3,2], [3,6], [1,4]])) - assert layout[-1].parent == 28 - - df = pd.DataFrame( - columns=\ - ["x_1", "y_1", "x_2", "y_2"], - data=[ - [0, 10, 0, 12, ], - [12, 32, 24, 55, ], - ]) - - layout = Layout.from_dataframe(df) - assert layout[1] == Rectangle(x_1=12, y_1=32, x_2=24, y_2=55) - - - df = pd.DataFrame( - columns=\ - ["x_1", "y_1", "x_2", "y_2", "height", "width"], - data=[ - [0, 10, 0, 12, 240, 520 ], - [12, None, 24, None, 240, None ], - ]) - - layout = Layout.from_dataframe(df) - assert layout[1] == Interval(12, 24, 'x', canvas_height=240) - - - df = pd.DataFrame( - columns=\ - ["p11", "p12", "p21", "p22", "p31", "p32", "p41", "p42", 'width', 'height'], - data=[ - [1, 2, 3, 2, 3, 6, 1, 4, None, None ], - [1, 2, 3, 2, 3, 6, 1, 4, None, None ], - [1, 2, 3, 2, 3, 6, 1, 4, None, 28 ], - ]) - - layout = Layout.from_dataframe(df) - assert layout[1] == Quadrilateral(np.array([[1,2], [3,2], [3,6], [1,4]])) - assert layout[2] == Quadrilateral(np.array([[1,2], [3,2], [3,6], [1,4]]), height=28) \ No newline at end of file + assert l.get_homogeneous_blocks() == [i.to_quadrilateral(), q, r.to_quadrilateral()] + + i2 = TextBlock(i, id=1, type=2, text="12") + r2 = TextBlock(r, id=1, type=2, parent="a") + q2 = TextBlock(q, id=1, type=2, next="a") + l2 = Layout([i2, r2, q2], page_data={"width": 200, "height": 200}) + + l2.get_texts() + l2.get_info("next") + l2.condition_on(i) + l2.relative_to(q) + l2.filter_by(t) + l2.is_in(r) + + l2.scale(4) + l2.shift(4) + l2.pad(left=2) + + # Test slicing function + homogeneous_blocks = l2[:2].get_homogeneous_blocks() + assert homogeneous_blocks[0].block == i.to_rectangle() + assert homogeneous_blocks[1].block == r + + # Test appending and extending + assert l + [i2] == Layout([i, q, r, i2]) + assert l + l == Layout([i, q, r] * 2) + l.append(i) + assert l == Layout([i, q, r, i]) + l2.extend([q]) + assert l2 == Layout([i2, r2, q2, q], page_data={"width": 200, "height": 200}) + + # Test addition + l + l2 + with pytest.raises(ValueError): + l.page_data = {"width": 200, "height": 400} + l + l2 + + +def test_dict(): + + i = Interval(1, 2, "y", canvas_height=5) + i_dict = { + "block_type": "interval", + "start": 1, + "end": 2, + "axis": "y", + "canvas_height": 5, + "canvas_width": 0, + } + assert i.to_dict() == i_dict + assert i == Interval.from_dict(i_dict) + + r = Rectangle(1, 2, 3, 4) + r_dict = {"block_type": "rectangle", "x_1": 1, "y_1": 2, "x_2": 3, "y_2": 4} + assert r.to_dict() == r_dict + assert r == Rectangle.from_dict(r_dict) + + q = Quadrilateral(np.arange(8).reshape(4, 2), 200, 400) + q_dict = { + "block_type": "quadrilateral", + "points": [0, 1, 2, 3, 4, 5, 6, 7], + "height": 200, + "width": 400, + } + assert q.to_dict() == q_dict + assert q == Quadrilateral.from_dict(q_dict) + + l = Layout([i, r, q], page_data={"width": 200, "height": 200}) + l_dict = { + "page_data": {"width": 200, "height": 200}, + "blocks": [i_dict, r_dict, q_dict], + } + assert l.to_dict() == l_dict + + i2 = TextBlock(i, "") + i_dict["text"] = "" + assert i2.to_dict() == i_dict + assert i2 == TextBlock.from_dict(i_dict) + + r2 = TextBlock(r, id=24) + r_dict["id"] = 24 + assert r2.to_dict() == r_dict + assert r2 == TextBlock.from_dict(r_dict) + + q2 = TextBlock(q, text="test", parent=45) + q_dict["text"] = "test" + q_dict["parent"] = 45 + assert q2.to_dict() == q_dict + assert q2 == TextBlock.from_dict(q_dict) + + l2 = Layout([i2, r2, q2]) + l2_dict = {"page_data": {}, "blocks": [i_dict, r_dict, q_dict]} + assert l2.to_dict() == l2_dict \ No newline at end of file diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 0000000..305094e --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,56 @@ +import numpy as np +from layoutparser.elements import Interval, Rectangle, Quadrilateral, TextBlock, Layout +from layoutparser.io import load_json, load_dict, load_csv + + +def test_json(): + + i = Interval(1, 2, "y", canvas_height=5) + r = Rectangle(1, 2, 3, 4) + q = Quadrilateral(np.arange(8).reshape(4, 2), 200, 400) + l = Layout([i, r, q], page_data={"width": 200, "height": 200}) + + i2 = TextBlock(i, "") + r2 = TextBlock(r, id=24) + q2 = TextBlock(q, text="test", parent=45) + l2 = Layout([i2, r2, q2]) + + i3 = TextBlock(i, None) + r3 = TextBlock(r, id=None) + q3 = TextBlock(q, text=None, parent=None) + l3 = Layout([i3, r3, q3], page_data={"width": 200, "height": 200}) + + # fmt: off + assert i == load_dict(i.to_dict()) == load_json("tests/fixtures/io/interval.json") + assert r == load_dict(r.to_dict()) == load_json("tests/fixtures/io/rectangle.json") + assert q == load_dict(q.to_dict()) == load_json("tests/fixtures/io/quadrilateral.json") + assert l == load_dict(l.to_dict()) == load_json("tests/fixtures/io/layout.json") + + assert i2 == load_dict(i2.to_dict()) == load_json("tests/fixtures/io/interval_textblock.json") + assert r2 == load_dict(r2.to_dict()) == load_json("tests/fixtures/io/rectangle_textblock.json") + assert q2 == load_dict(q2.to_dict()) == load_json("tests/fixtures/io/quadrilateral_textblock.json") + assert l2 == load_dict(l2.to_dict()) == load_json("tests/fixtures/io/layout_textblock.json") + + # Test if LP can ignore the unused None features + assert l == load_dict(l3.to_dict()) + # fmt: on + + +def test_csv(): + i = Interval(1, 2, "y", canvas_height=5) + r = Rectangle(1, 2, 3, 4) + q = Quadrilateral(np.arange(8).reshape(4, 2), 200, 400) + l = Layout([i, r, q], page_data={"width": 200, "height": 200}) + + _l = load_csv("tests/fixtures/io/layout.csv") + assert _l != l + _l.page_data = {"width": 200, "height": 200} + assert _l == l + + i2 = TextBlock(i, "") + r2 = TextBlock(r, id=24) + q2 = TextBlock(q, text="test", parent=45) + l2 = Layout([i2, r2, q2]) + + _l2 = load_csv("tests/fixtures/io/layout_textblock.csv") + assert _l2 == l2 \ No newline at end of file diff --git a/tests/test_model.py b/tests/test_model.py index c77b0a7..618eba3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,28 +1,28 @@ -from layoutparser.models import * +from layoutparser.models import * import cv2 ALL_CONFIGS = [ - 'lp://PrimaLayout/mask_rcnn_R_50_FPN_3x/config', - 'lp://HJDataset/faster_rcnn_R_50_FPN_3x/config', - 'lp://HJDataset/mask_rcnn_R_50_FPN_3x/config', - 'lp://HJDataset/retinanet_R_50_FPN_3x/config', - 'lp://PubLayNet/faster_rcnn_R_50_FPN_3x/config', - 'lp://PubLayNet/mask_rcnn_R_50_FPN_3x/config', - 'lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config', - 'lp://NewspaperNavigator/faster_rcnn_R_50_FPN_3x/config', - ] + "lp://PrimaLayout/mask_rcnn_R_50_FPN_3x/config", + "lp://HJDataset/faster_rcnn_R_50_FPN_3x/config", + "lp://HJDataset/mask_rcnn_R_50_FPN_3x/config", + "lp://HJDataset/retinanet_R_50_FPN_3x/config", + "lp://PubLayNet/faster_rcnn_R_50_FPN_3x/config", + "lp://PubLayNet/mask_rcnn_R_50_FPN_3x/config", + "lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config", + "lp://NewspaperNavigator/faster_rcnn_R_50_FPN_3x/config", +] def test_Detectron2Model(is_large_scale=False): - + if is_large_scale: - - for config in ALL_CONFIGS: + + for config in ALL_CONFIGS: model = Detectron2LayoutModel(config) - - image = cv2.imread("tests/source/test_gcv_image.jpg") + + image = cv2.imread("tests/fixtures/model/test_model_image.jpg") layout = model.detect(image) else: - model = Detectron2LayoutModel('tests/source/config.yml') - image = cv2.imread("tests/source/test_gcv_image.jpg") + model = Detectron2LayoutModel("tests/fixtures/model/config.yml") + image = cv2.imread("tests/fixtures/model/test_model_image.jpg") layout = model.detect(image) \ No newline at end of file diff --git a/tests/test_ocr.py b/tests/test_ocr.py index 8b47cfe..844c0a7 100644 --- a/tests/test_ocr.py +++ b/tests/test_ocr.py @@ -1,28 +1,34 @@ -from layoutparser.ocr import GCVAgent, GCVFeatureType, TesseractAgent, TesseractFeatureType +from layoutparser.ocr import ( + GCVAgent, + GCVFeatureType, + TesseractAgent, + TesseractFeatureType, +) import json, cv2, os -image = cv2.imread("tests/source/test_gcv_image.jpg") +image = cv2.imread("tests/fixtures/ocr/test_gcv_image.jpg") + def test_gcv_agent(test_detect=False): - + # Test loading the agent with designated credential ocr_agent = GCVAgent() - - # Test loading the saved response and parse the data - res = ocr_agent.load_response("tests/source/test_gcv_response.json") + + # Test loading the saved response and parse the data + res = ocr_agent.load_response("tests/fixtures/ocr/test_gcv_response.json") r0 = ocr_agent.gather_text_annotations(res) r1 = ocr_agent.gather_full_text_annotation(res, GCVFeatureType.SYMBOL) r2 = ocr_agent.gather_full_text_annotation(res, GCVFeatureType.WORD) r3 = ocr_agent.gather_full_text_annotation(res, GCVFeatureType.PARA) r4 = ocr_agent.gather_full_text_annotation(res, GCVFeatureType.BLOCK) r5 = ocr_agent.gather_full_text_annotation(res, GCVFeatureType.PAGE) - + # Test with a online image detection and compare the results with the stored one - # Warning: there could be updates on the GCV side. So it would be good to not - # frequently test this part. + # Warning: there could be updates on the GCV side. So it would be good to not + # frequently test this part. if test_detect: res2 = ocr_agent.detect(image, return_response=True) - + assert res == res2 assert r0 == ocr_agent.gather_text_annotations(res2) assert r1 == ocr_agent.gather_full_text_annotation(res2, GCVFeatureType.SYMBOL) @@ -30,27 +36,28 @@ def test_gcv_agent(test_detect=False): assert r3 == ocr_agent.gather_full_text_annotation(res2, GCVFeatureType.PARA) assert r4 == ocr_agent.gather_full_text_annotation(res2, GCVFeatureType.BLOCK) assert r5 == ocr_agent.gather_full_text_annotation(res2, GCVFeatureType.PAGE) - + # Finally, test the response storage and remove the file - ocr_agent.save_response(res, "tests/source/.test_gcv_response.json") - os.remove("tests/source/.test_gcv_response.json") - + ocr_agent.save_response(res, "tests/fixtures/ocr/.test_gcv_response.json") + os.remove("tests/fixtures/ocr/.test_gcv_response.json") + + def test_tesseract(test_detect=False): - ocr_agent = TesseractAgent(languages='eng') - res = ocr_agent.load_response("tests/source/test_tesseract_response.pickle") - r0 = res['text'] - r1 = ocr_agent.gather_data(res, agg_level=TesseractFeatureType.PAGE) + ocr_agent = TesseractAgent(languages="eng") + res = ocr_agent.load_response("tests/fixtures/ocr/test_tesseract_response.pickle") + r0 = res["text"] + r1 = ocr_agent.gather_data(res, agg_level=TesseractFeatureType.PAGE) r2 = ocr_agent.gather_data(res, agg_level=TesseractFeatureType.BLOCK) - r3 = ocr_agent.gather_data(res, agg_level=TesseractFeatureType.PARA) - r4 = ocr_agent.gather_data(res, agg_level=TesseractFeatureType.LINE) - r5 = ocr_agent.gather_data(res, agg_level=TesseractFeatureType.WORD) - - # The results could be different is using another version of Tesseract Engine. - # tesseract 4.1.1 is used for generating the pickle test file. + r3 = ocr_agent.gather_data(res, agg_level=TesseractFeatureType.PARA) + r4 = ocr_agent.gather_data(res, agg_level=TesseractFeatureType.LINE) + r5 = ocr_agent.gather_data(res, agg_level=TesseractFeatureType.WORD) + + # The results could be different is using another version of Tesseract Engine. + # tesseract 4.1.1 is used for generating the pickle test file. if test_detect: res = ocr_agent.detect(image, return_response=True) - assert r0 == res['text'] + assert r0 == res["text"] assert r1 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.PAGE) assert r2 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.BLOCK) assert r3 == ocr_agent.gather_data(res, agg_level=TesseractFeatureType.PARA) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 643022d..a670ab2 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,51 +1,71 @@ -from layoutparser.elements import * +from layoutparser.elements import * from layoutparser.ocr import * from layoutparser.visualization import * import cv2 import numpy as np + def test_viz(): - image = cv2.imread("tests/source/test_gcv_image.jpg") - ocr_agent = GCVAgent.with_credential("tests/source/test_gcv_credential.json", languages = ['en']) - res = ocr_agent.load_response("tests/source/test_gcv_response.json") - + image = cv2.imread("tests/fixtures/ocr/test_gcv_image.jpg") + ocr_agent = GCVAgent.with_credential( + "tests/fixtures/ocr/test_gcv_credential.json", languages=["en"] + ) + res = ocr_agent.load_response("tests/fixtures/ocr/test_gcv_response.json") + draw_box(image, Layout([])) draw_text(image, Layout([])) - - draw_box(image, Layout([ - Interval(0,10,axis='x'), + + draw_box( + image, + Layout( + [ + Interval(0, 10, axis="x"), Rectangle(0, 50, 100, 80), - Quadrilateral(np.array([[10, 10], [30, 40], [90, 40], [10, 20]])) - ])) - - draw_text(image, Layout([ - Interval(0,10,axis='x'), + Quadrilateral(np.array([[10, 10], [30, 40], [90, 40], [10, 20]])), + ] + ), + ) + + draw_text( + image, + Layout( + [ + Interval(0, 10, axis="x"), Rectangle(0, 50, 100, 80), - Quadrilateral(np.array([[10, 10], [30, 40], [90, 40], [10, 20]])) - ])) - - for idx, level in enumerate([GCVFeatureType.SYMBOL, - GCVFeatureType.WORD, - GCVFeatureType.PARA, - GCVFeatureType.BLOCK, - GCVFeatureType.PAGE]): - + Quadrilateral(np.array([[10, 10], [30, 40], [90, 40], [10, 20]])), + ] + ), + ) + + for idx, level in enumerate( + [ + GCVFeatureType.SYMBOL, + GCVFeatureType.WORD, + GCVFeatureType.PARA, + GCVFeatureType.BLOCK, + GCVFeatureType.PAGE, + ] + ): + layout = ocr_agent.gather_full_text_annotation(res, level) - - draw_text(image, layout, - arrangement='ud' if idx%2 else 'ud', - font_size=15, - text_color='pink', - text_background_color='grey', - with_box_on_text = True, - text_box_width= 2, - text_box_color='yellow', - with_layout=True, - box_width=1, - color_map={None: 'blue'}, - show_element_id=True, - id_font_size=8) - + + draw_text( + image, + layout, + arrangement="ud" if idx % 2 else "ud", + font_size=15, + text_color="pink", + text_background_color="grey", + with_box_on_text=True, + text_box_width=2, + text_box_color="yellow", + with_layout=True, + box_width=1, + color_map={None: "blue"}, + show_element_id=True, + id_font_size=8, + ) + draw_box(image, layout) draw_text(image, layout) \ No newline at end of file