from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import numpy as np
from unstructured_inference.constants import Source
from unstructured_inference.inference.elements import TextRegion, TextRegions
from unstructured_inference.inference.layoutelement import (
LayoutElement,
LayoutElements,
partition_groups_from_regions,
)
from unstructured.documents.elements import ElementType
if TYPE_CHECKING:
from unstructured_inference.inference.elements import Rectangle
def build_text_region_from_coords(
    x1: int | float,
    y1: int | float,
    x2: int | float,
    y2: int | float,
    text: Optional[str] = None,
    source: Optional[Source] = None,
) -> TextRegion:
    """Create a ``TextRegion`` from four coordinate values.

    Thin convenience wrapper around ``TextRegion.from_coords``: the four
    coordinates are forwarded positionally and unchanged, while ``text`` and
    ``source`` are forwarded as keyword arguments.

    Parameters:
    - x1, y1, x2, y2 (int | float): Coordinates of the region (presumably two
      opposite corners of its bounding box -- confirm against
      ``TextRegion.from_coords``).
    - text (Optional[str]): Text content of the region, if any.
    - source (Optional[Source]): Origin of the region, if known.

    Returns:
    - TextRegion: The region built by ``TextRegion.from_coords``.
    """
    region = TextRegion.from_coords(x1, y1, x2, y2, text=text, source=source)
    return region
def build_layout_element(
    bbox: "Rectangle",
    text: Optional[str] = None,
    source: Optional[Source] = None,
    element_type: Optional[str] = None,
) -> LayoutElement:
    """Create a ``LayoutElement`` for the given bounding box.

    Thin convenience wrapper around the ``LayoutElement`` constructor. Note
    that ``element_type`` is handed to the constructor under its ``type``
    keyword.

    Parameters:
    - bbox (Rectangle): Bounding box of the element.
    - text (Optional[str]): Text content of the element, if any.
    - source (Optional[Source]): Origin of the element, if known.
    - element_type (Optional[str]): Category label for the element, if any.

    Returns:
    - LayoutElement: The newly constructed element.
    """
    element = LayoutElement(
        bbox=bbox,
        text=text,
        source=source,
        type=element_type,
    )
    return element
def build_layout_elements_from_ocr_regions(
    ocr_regions: TextRegions,
    ocr_text: Optional[str] = None,
    group_by_ocr_text: bool = False,
) -> LayoutElements:
    """Get layout elements from OCR regions.

    The OCR regions are first arranged into groups, each group is merged into
    a single region with ``merge_text_regions``, and the merged regions are
    returned as ``LayoutElements`` whose class ids are all 0, mapped to
    ``ElementType.UNCATEGORIZED_TEXT``.

    Grouping strategy:
    - ``group_by_ocr_text`` is True: ``ocr_text`` is split into sections on
      blank lines. For each section, the regions not yet claimed by an earlier
      section are scanned in order; a region joins the section's group when
      its text equals one of the section's remaining words, and that word is
      then consumed so it can match at most one region. Sections that match
      no region produce no group.
    - otherwise: the grouping is delegated to
      ``partition_groups_from_regions``.

    Parameters:
    - ocr_regions (TextRegions): The OCR regions to group and merge.
    - ocr_text (Optional[str]): Full OCR text; required when
      ``group_by_ocr_text`` is True, unused otherwise.
    - group_by_ocr_text (bool): Whether to group regions by the sections of
      ``ocr_text`` instead of using ``partition_groups_from_regions``.

    Returns:
    - LayoutElements: One element per merged group.

    Raises:
    - ValueError: If ``group_by_ocr_text`` is True but ``ocr_text`` is None.
    """
    grouped_regions = []
    if group_by_ocr_text:
        # Fail with a clear message instead of an AttributeError on None.split.
        if ocr_text is None:
            raise ValueError("ocr_text must be provided when group_by_ocr_text is True.")

        # mask[i] is True while region i has not been assigned to a group.
        mask = np.ones(ocr_regions.texts.shape).astype(bool)
        indices = np.arange(len(mask))
        for text_section in ocr_text.split("\n\n"):
            words = text_section.replace("\n", " ").split()
            # The mask only changes between sections, so take the snapshot of
            # unclaimed regions once per section rather than once per match.
            unclaimed_indices = indices[mask]
            unclaimed_texts = ocr_regions.texts[mask]
            regions = []
            for region_index, text in zip(unclaimed_indices, unclaimed_texts):
                if not words:
                    # Every word of this section has been matched already.
                    break
                if text in words:
                    regions.append(region_index)
                    # Consume the word so a repeated word needs a repeated region.
                    words.remove(text)

            if not regions:
                continue

            mask[regions] = False
            grouped_regions.append(ocr_regions.slice(regions))
    else:
        grouped_regions = partition_groups_from_regions(ocr_regions)

    merged_regions = TextRegions.from_list([merge_text_regions(group) for group in grouped_regions])
    return LayoutElements(
        element_coords=merged_regions.element_coords,
        texts=merged_regions.texts,
        sources=merged_regions.sources,
        element_class_ids=np.zeros(merged_regions.texts.shape),
        element_class_id_map={0: ElementType.UNCATEGORIZED_TEXT},
    )
def merge_text_regions(regions: TextRegions) -> TextRegion:
    """
    Collapse a group of text regions into one enclosing TextRegion.

    The merged region spans from the smallest x1/y1 to the largest x2/y2 of
    the group, and its text is the group's non-empty texts joined by single
    spaces, in their original order.

    Parameters:
    - regions (TextRegions): The group of text regions to be merged.

    Returns:
    - TextRegion: A single merged TextRegion object.

    Raises:
    - ValueError: If ``regions`` is falsy (e.g. nothing was provided).
    """
    if not regions:
        raise ValueError("The text regions to be merged must be provided.")

    # Enclosing box: extreme coordinates over the whole group, as floats.
    left = regions.x1.min().astype(float)
    top = regions.y1.min().astype(float)
    right = regions.x2.max().astype(float)
    bottom = regions.y2.max().astype(float)

    # Empty / missing texts are skipped so they add no stray separators.
    combined_text = " ".join(piece for piece in regions.texts if piece)

    # All regions of a group are assumed to share the same source, so the
    # first one is used for the merged region.
    first_source = regions.sources[0]

    return TextRegion.from_coords(left, top, right, bottom, combined_text, first_source)