import difflib
from typing import Any, Dict, List
import numpy as np
import pandas as pd
from unstructured_inference.models.eval import compare_contents_as_df
class TableAlignment:
    """Align predicted tables with ground-truth tables and score cell placement.

    A table is a list of cell dictionaries. The methods here read the
    ``"row_index"``, ``"col_index"`` and ``"content"`` keys of each cell.
    """

    def __init__(self, cutoff: float = 0.8):
        """Store the similarity cutoff on the instance.

        Args:
            cutoff: Similarity threshold kept as ``self.cutoff``. The static
                methods of this class do not read it; they take their own
                ``cutoff`` argument where one is needed.
        """
        self.cutoff = cutoff

    @staticmethod
    def get_content_in_tables(table_data: List[List[Dict[str, Any]]]) -> List[str]:
        """Extracts and concatenates the content of cells from each table in a list of tables.

        Cells without a ``"content"`` key are skipped.

        Args:
            table_data: A list of tables, each table being a list of cell data dictionaries.

        Returns:
            List of strings where each string represents the concatenated content of one table.
        """
        return [
            " ".join(cell["content"] for cell in table if "content" in cell)
            for table in table_data
        ]

    @staticmethod
    def get_table_level_alignment(
        predicted_table_data: List[List[Dict[str, Any]]],
        ground_truth_table_data: List[List[Dict[str, Any]]],
    ) -> List[int]:
        """Compares predicted table data with ground truth data to find the best
        matching table index for each predicted table.

        Args:
            predicted_table_data: A list of predicted tables.
            ground_truth_table_data: A list of ground truth tables.

        Returns:
            A list of indices indicating the best match in the ground truth for
            each predicted table, or -1 where no ground truth table is similar
            enough. When several ground truth tables have identical text, the
            index of the first one is reported.
        """
        ground_truth_texts = TableAlignment.get_content_in_tables(ground_truth_table_data)
        predicted_texts = TableAlignment.get_content_in_tables(predicted_table_data)

        matched_indices = []
        for reference in predicted_texts:
            # A deliberately low cutoff: any remotely similar table is accepted
            # and only the single best candidate is kept.
            matches = difflib.get_close_matches(reference, ground_truth_texts, cutoff=0.1, n=1)
            matched_indices.append(ground_truth_texts.index(matches[0]) if matches else -1)
        return matched_indices

    @staticmethod
    def _zip_to_dataframe(table_data: List[Dict[str, Any]]) -> pd.DataFrame:
        """Builds a DataFrame from the cells of one table.

        Args:
            table_data: The cell dictionaries of a single table.

        Returns:
            A DataFrame indexed by ``row_index`` with a string-typed
            ``col_index`` column and a ``content`` column.
        """
        df = pd.DataFrame(table_data, columns=["row_index", "col_index", "content"])
        df = df.set_index("row_index")
        df["col_index"] = df["col_index"].astype(str)
        return df

    @staticmethod
    def _get_index_accuracy(
        predicted_cells: List[Dict[str, Any]],
        ground_truth_cells: List[Dict[str, Any]],
        cutoff: float,
    ) -> tuple:
        """Scores how often matched cells share a row / column index.

        Each predicted cell is fuzzy-matched (case-insensitively) to a ground
        truth cell by content. Predicted cells with no match are ignored.

        Args:
            predicted_cells: The cells of one predicted table.
            ground_truth_cells: The cells of the matched ground truth table.
            cutoff: The cutoff value for the close matches.

        Returns:
            A ``(col_index_acc, row_index_acc)`` pair, each rounded to two
            decimals; both are 0.0 when no predicted cell could be matched.
        """
        # A cell without "content" is treated as empty text, in line with
        # get_content_in_tables, instead of raising KeyError.
        ground_truth_contents = [
            cell.get("content", "").lower() for cell in ground_truth_cells
        ]

        # Map each distinct content string to every position where it occurs,
        # so duplicates can be handed out one by one without rescanning.
        positions_by_content: Dict[str, List[int]] = {}
        for position, text in enumerate(ground_truth_contents):
            positions_by_content.setdefault(text, []).append(position)

        used_positions = set()
        aligned_cols = 0
        aligned_rows = 0
        total = 0

        for cell in predicted_cells:
            matches = difflib.get_close_matches(
                cell.get("content", "").lower(),
                ground_truth_contents,
                cutoff=cutoff,
                n=1,
            )
            if not matches:
                continue

            all_positions = positions_by_content[matches[0]]
            free_positions = [p for p in all_positions if p not in used_positions]
            if not free_positions:
                # Every occurrence of this content has been consumed: forget
                # all used positions and start over from the first occurrence.
                used_positions.clear()
                free_positions = all_positions
            matched_position = free_positions[0]
            used_positions.add(matched_position)

            ground_truth_cell = ground_truth_cells[matched_position]
            if cell["row_index"] == ground_truth_cell["row_index"]:
                aligned_rows += 1
            if cell["col_index"] == ground_truth_cell["col_index"]:
                aligned_cols += 1
            total += 1

        if total == 0:
            return 0.0, 0.0
        return round(aligned_cols / total, 2), round(aligned_rows / total, 2)

    @staticmethod
    def get_element_level_alignment(
        predicted_table_data: List[List[Dict[str, Any]]],
        ground_truth_table_data: List[List[Dict[str, Any]]],
        matched_indices: List[int],
        cutoff: float = 0.8,
    ) -> Dict[str, float]:
        """Aligns elements of the predicted tables with the ground truth tables at the cell level.

        Predicted tables without a match (index -1) and ground truth tables
        that no predicted table matched each contribute a score of 0.

        Args:
            predicted_table_data: A list of predicted tables.
            ground_truth_table_data: A list of ground truth tables.
            matched_indices: Indices of the best matching ground truth table for each
                predicted table.
            cutoff: The cutoff value for the close matches.

        Returns:
            A dictionary with column and row alignment accuracies
            (``col_index_acc``, ``row_index_acc``) and content accuracies
            (``col_content_acc``, ``row_content_acc``), each rounded to two
            decimals. If there are no predicted and no ground truth tables the
            values are ``nan`` (mean of an empty list).
        """
        content_diff_cols = []
        content_diff_rows = []
        col_index_acc = []
        row_index_acc = []

        for idx, td in zip(matched_indices, predicted_table_data):
            if idx == -1:
                content_diff_cols.append(0)
                content_diff_rows.append(0)
                col_index_acc.append(0)
                row_index_acc.append(0)
                continue

            ground_truth_td = ground_truth_table_data[idx]

            # Row and column content accuracy.
            predict_table_df = TableAlignment._zip_to_dataframe(td)
            ground_truth_table_df = TableAlignment._zip_to_dataframe(ground_truth_td)
            table_content_diff = compare_contents_as_df(
                ground_truth_table_df.fillna(""),
                predict_table_df.fillna(""),
            )
            content_diff_cols.append(table_content_diff["by_col_token_ratio"])
            content_diff_rows.append(table_content_diff["by_row_token_ratio"])

            # Row and column index accuracy.
            table_col_index_acc, table_row_index_acc = TableAlignment._get_index_accuracy(
                td, ground_truth_td, cutoff
            )
            col_index_acc.append(table_col_index_acc)
            row_index_acc.append(table_row_index_acc)

        # Ground truth tables that no prediction was matched to count as misses.
        matched_lookup = set(matched_indices)
        missed_table_count = sum(
            1
            for table_index in range(len(ground_truth_table_data))
            if table_index not in matched_lookup
        )
        zero_scores = [0] * missed_table_count
        content_diff_cols.extend(zero_scores)
        content_diff_rows.extend(zero_scores)
        col_index_acc.extend(zero_scores)
        row_index_acc.extend(zero_scores)

        return {
            "col_index_acc": round(np.mean(col_index_acc), 2),
            "row_index_acc": round(np.mean(row_index_acc), 2),
            # NOTE(review): the token ratios appear to be on a 0-100 scale,
            # hence the division by 100 - confirm against compare_contents_as_df.
            "col_content_acc": round(np.mean(content_diff_cols) / 100.0, 2),
            "row_content_acc": round(np.mean(content_diff_rows) / 100.0, 2),
        }