from __future__ import annotations
import copy
from io import BytesIO
from typing import IO, Any, Iterator, cast
from lxml import etree
from unstructured.chunking import add_chunking_strategy
from unstructured.documents.elements import Element, ElementMetadata, Text
from unstructured.file_utils.encoding import read_txt_file
from unstructured.file_utils.model import FileType
from unstructured.partition.common.common import (
exactly_one,
spooled_to_bytes_io_if_needed,
)
from unstructured.partition.common.metadata import apply_metadata, get_last_modified_date
from unstructured.partition.text import element_from_text
DETECTION_ORIGIN: str = "xml"
@apply_metadata(FileType.XML)
@add_chunking_strategy
def partition_xml(
filename: str | None = None,
*,
file: IO[bytes] | None = None,
text: str | None = None,
encoding: str | None = None,
xml_keep_tags: bool = False,
xml_path: str | None = None,
**kwargs: Any,
) -> list[Element]:
"""Partitions an XML document into its document elements.
Parameters
----------
filename
A string defining the target filename path.
file
A file-like object using "rb" mode --> open(filename, "rb").
text
The text of the XML file.
encoding
The encoding method used to decode the text input. If None, utf-8 will be used.
xml_keep_tags
If True, will retain the XML tags in the output. Otherwise it will simply extract
the text from within the tags.
xml_path
The xml_path to use for extracting the text. Only used if xml_keep_tags=False.
"""
exactly_one(filename=filename, file=file, text=text)
elements: list[Element] = []
metadata = ElementMetadata(
filename=filename, last_modified=get_last_modified_date(filename) if filename else None
)
metadata.detection_origin = DETECTION_ORIGIN
if xml_keep_tags:
if filename:
raw_text = read_txt_file(filename=filename, encoding=encoding)[1]
elif file:
raw_text = read_txt_file(file=spooled_to_bytes_io_if_needed(file), encoding=encoding)[1]
else:
assert text is not None
raw_text = text
elements = [Text(text=raw_text, metadata=metadata)]
else:
leaf_elements = get_leaf_elements(
filename=filename,
file=file,
text=text,
xml_path=xml_path,
)
for leaf_element in leaf_elements:
if leaf_element:
element = element_from_text(leaf_element)
element.metadata = copy.deepcopy(metadata)
elements.append(element)
return elements
def get_leaf_elements(
filename: str | None, file: IO[bytes] | None, text: str | None, xml_path: str | None
) -> Iterator[str | None]:
"""Get leaf elements from the XML tree defined in filename, file, or text."""
exactly_one(filename=filename, file=file, text=text)
if filename:
return _get_leaf_elements(filename, xml_path=xml_path)
elif file:
return _get_leaf_elements(file=spooled_to_bytes_io_if_needed(file), xml_path=xml_path)
else:
b = BytesIO(bytes(cast(str, text), encoding="utf-8"))
return _get_leaf_elements(b, xml_path=xml_path)
def _get_leaf_elements(
file: str | IO[bytes],
xml_path: str | None,
) -> Iterator[str | None]:
"""Parse the XML tree in a memory efficient manner if possible."""
element_stack: list[etree._Element] = [] # pyright: ignore[reportPrivateUsage]
element_iterator = etree.iterparse(file, events=("start", "end"), resolve_entities=False)
# NOTE(alan) If xml_path is used for filtering, I've yet to find a good way to stream
# elements through in a memory efficient way, so we bite the bullet and load it all into
# memory.
if xml_path is not None:
_, element = next(element_iterator)
compiled_path = etree.XPath(xml_path)
element_iterator = (("end", el) for el in compiled_path(element))
for event, element in element_iterator:
if event == "start":
element_stack.append(element)
if event == "end":
if element.text is not None and element.text.strip():
yield element.text
element.clear()
while element_stack and element_stack[-1].getparent() is None:
element_stack.pop()