import difflib
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
from unstructured_inference.models.eval import compare_contents_as_df


class TableAlignment:
    """Aligns predicted tables with ground-truth tables and scores the alignment.

    Tables are lists of cell dictionaries; the methods below read the keys
    ``"content"``, ``"row_index"`` and ``"col_index"`` from each cell.
    """

    def __init__(self, cutoff: float = 0.8):
        # Stored on the instance only; the static methods below take their own
        # ``cutoff`` arguments and do not read this attribute.
        self.cutoff = cutoff

    @staticmethod
    def get_content_in_tables(table_data: List[List[Dict[str, Any]]]) -> List[str]:
        """Extracts and concatenates the content of cells from each table in a list of tables.

        Cells without a ``"content"`` key are skipped.

        Args:
            table_data: A list of tables, each table being a list of cell data dictionaries.

        Returns:
            List of strings where each string is the space-joined content of one table.
        """
        return [" ".join(d["content"] for d in td if "content" in d) for td in table_data]

    @staticmethod
    def get_table_level_alignment(
        predicted_table_data: List[List[Dict[str, Any]]],
        ground_truth_table_data: List[List[Dict[str, Any]]],
        cutoff: float = 0.1,
    ) -> List[int]:
        """Finds the best-matching ground-truth table for each predicted table.

        Each table is flattened to its concatenated cell text, and the texts are
        compared with ``difflib.get_close_matches``.

        Args:
            predicted_table_data: A list of predicted tables.
            ground_truth_table_data: A list of ground truth tables.
            cutoff: Minimum ``difflib`` similarity ratio (0-1) a ground-truth table must
                reach to count as a match. Defaults to 0.1, the previously fixed value.

        Returns:
            For each predicted table, the index of its best match in
            ``ground_truth_table_data``, or -1 if no table reaches ``cutoff``. When
            several ground-truth tables have identical text, the first index is used.
        """
        ground_truth_texts = TableAlignment.get_content_in_tables(ground_truth_table_data)
        predicted_texts = TableAlignment.get_content_in_tables(predicted_table_data)
        matched_indices = []
        for reference in predicted_texts:
            matches = difflib.get_close_matches(reference, ground_truth_texts, cutoff=cutoff, n=1)
            matched_indices.append(ground_truth_texts.index(matches[0]) if matches else -1)
        return matched_indices

    @staticmethod
    def _zip_to_dataframe(table_data: List[Dict[str, Any]]) -> pd.DataFrame:
        """Builds a DataFrame of one table's cells, indexed by ``row_index``.

        Only the ``row_index``, ``col_index`` and ``content`` keys are kept (missing
        keys become NaN), and ``col_index`` is cast to ``str``.
        """
        df = pd.DataFrame(table_data, columns=["row_index", "col_index", "content"])
        df = df.set_index("row_index")
        df["col_index"] = df["col_index"].astype(str)
        return df

    @staticmethod
    def _get_index_alignment_accuracy(
        predicted_table: List[Dict[str, Any]],
        ground_truth_table: List[Dict[str, Any]],
        cutoff: float,
    ) -> Tuple[float, float]:
        """Returns ``(col_index_acc, row_index_acc)`` for one matched pair of tables.

        Each predicted cell is fuzzily matched, case-insensitively, to a ground-truth
        cell by content. The accuracies are the fractions of matched cells whose
        column / row index equals that of their ground-truth counterpart, rounded to
        two decimals. Both are 0 when no predicted cell finds a match.
        """
        ground_truth_texts = [cell["content"].lower() for cell in ground_truth_table]

        # All positions of each distinct ground-truth text, computed once, so that
        # duplicated texts can be handed out to successive predicted cells in turn.
        positions: Dict[str, List[int]] = {}
        for i, text in enumerate(ground_truth_texts):
            positions.setdefault(text, []).append(i)

        used_indices = set()
        aligned_col_count = 0
        aligned_row_count = 0
        matched_count = 0
        for cell in predicted_table:
            matches = difflib.get_close_matches(
                cell["content"].lower(),
                ground_truth_texts,
                cutoff=cutoff,
                n=1,
            )
            if not matches:
                continue

            candidates = positions[matches[0]]
            available = [i for i in candidates if i not in used_indices]
            if not available:
                # Every ground-truth cell with this text is already taken: start
                # cycling through them again. Release only this text's indices so
                # other texts keep their usage state (clearing the whole set would
                # send later duplicates of other texts back to already-used cells).
                used_indices.difference_update(candidates)
                available = candidates
            matched_index = available[0]
            used_indices.add(matched_index)

            ground_truth_cell = ground_truth_table[matched_index]
            matched_count += 1
            if cell["row_index"] == ground_truth_cell["row_index"]:
                aligned_row_count += 1
            if cell["col_index"] == ground_truth_cell["col_index"]:
                aligned_col_count += 1

        if matched_count == 0:
            return 0, 0
        return (
            round(aligned_col_count / matched_count, 2),
            round(aligned_row_count / matched_count, 2),
        )

    @staticmethod
    def get_element_level_alignment(
        predicted_table_data: List[List[Dict[str, Any]]],
        ground_truth_table_data: List[List[Dict[str, Any]]],
        matched_indices: List[int],
        cutoff: float = 0.8,
    ) -> Dict[str, float]:
        """Aligns elements of the predicted tables with the ground truth tables at the cell level.

        Every predicted table with a match (index != -1) is scored on row- and
        column-wise content similarity and on row/column index agreement of its
        cells. Unmatched predicted tables, and ground-truth tables that no
        predicted table matched, contribute 0 to every average.

        Args:
            predicted_table_data: A list of predicted tables.
            ground_truth_table_data: A list of ground truth tables.
            matched_indices: Indices of the best matching ground truth table for each
                predicted table (-1 for no match), e.g. from ``get_table_level_alignment``.
            cutoff: The ``difflib`` cutoff used when matching individual cells.

        Returns:
            A dictionary with the mean column/row index accuracies and the mean
            column/row content accuracies, each rounded to two decimals. If no table
            is scored at all (both inputs empty), the means are taken over empty
            lists and come out as NaN.
        """
        content_diff_cols = []
        content_diff_rows = []
        col_index_acc = []
        row_index_acc = []

        for idx, td in zip(matched_indices, predicted_table_data):
            if idx == -1:
                # Predicted table without a ground-truth match scores zero everywhere.
                content_diff_cols.append(0)
                content_diff_rows.append(0)
                col_index_acc.append(0)
                row_index_acc.append(0)
                continue

            ground_truth_td = ground_truth_table_data[idx]

            # Row- and column-wise content similarity of the two tables.
            predict_table_df = TableAlignment._zip_to_dataframe(td)
            ground_truth_table_df = TableAlignment._zip_to_dataframe(ground_truth_td)
            table_content_diff = compare_contents_as_df(
                ground_truth_table_df.fillna(""),
                predict_table_df.fillna(""),
            )
            content_diff_cols.append(table_content_diff["by_col_token_ratio"])
            content_diff_rows.append(table_content_diff["by_row_token_ratio"])

            # Row- and column-index agreement of fuzzily matched cells.
            table_col_index_acc, table_row_index_acc = (
                TableAlignment._get_index_alignment_accuracy(td, ground_truth_td, cutoff)
            )
            col_index_acc.append(table_col_index_acc)
            row_index_acc.append(table_row_index_acc)

        # Ground-truth tables that no predicted table matched are scored as zero.
        matched_set = set(matched_indices)
        unmatched_count = sum(
            1 for gt_index in range(len(ground_truth_table_data)) if gt_index not in matched_set
        )
        for scores in (content_diff_cols, content_diff_rows, col_index_acc, row_index_acc):
            scores.extend([0] * unmatched_count)

        # The content ratios are divided by 100 here, so they are presumably on a
        # 0-100 scale as returned by compare_contents_as_df — confirm upstream.
        return {
            "col_index_acc": round(np.mean(col_index_acc), 2),
            "row_index_acc": round(np.mean(row_index_acc), 2),
            "col_content_acc": round(np.mean(content_diff_cols) / 100.0, 2),
            "row_content_acc": round(np.mean(content_diff_rows) / 100.0, 2),
        }
Memory