#!/usr/bin/env python3
from __future__ import annotations
import concurrent.futures
import json
import logging
import os
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import pandas as pd
from tqdm import tqdm
from unstructured.metrics.element_type import (
calculate_element_type_percent_match,
get_element_type_frequency,
)
from unstructured.metrics.object_detection import (
ObjectDetectionEvalProcessor,
)
from unstructured.metrics.table.table_eval import TableEvalProcessor
from unstructured.metrics.text_extraction import calculate_accuracy, calculate_percent_missing_text
from unstructured.metrics.utils import (
_count,
_display,
_format_grouping_output,
_mean,
_prepare_output_cct,
_pstdev,
_read_text_file,
_rename_aggregated_columns,
_stdev,
_write_to_file,
)
logger = logging.getLogger("unstructured.eval")
handler = logging.StreamHandler()
handler.name = "eval_log_handler"
formatter = logging.Formatter("%(asctime)s %(processName)-10s %(levelname)-8s %(message)s")
handler.setFormatter(formatter)
# Only want to add the handler once
if "eval_log_handler" not in [h.name for h in logger.handlers]:
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
AGG_HEADERS = ["metric", "average", "sample_sd", "population_sd", "count"]
AGG_HEADERS_MAPPING = {
"index": "metric",
"_mean": "average",
"_stdev": "sample_sd",
"_pstdev": "population_sd",
"_count": "count",
}
OUTPUT_TYPE_OPTIONS = ["json", "txt"]
@dataclass
class BaseMetricsCalculator(ABC):
"""Foundation class for specialized metrics calculators.
It provides a common interface for calculating metrics based on outputs and ground truths.
Those can be provided as either directories or lists of files.
"""
documents_dir: str | Path
ground_truths_dir: str | Path
def __post_init__(self):
"""Discover all files in the provided directories."""
self.documents_dir = Path(self.documents_dir).resolve()
self.ground_truths_dir = Path(self.ground_truths_dir).resolve()
# -- auto-discover all files in the directories --
self._document_paths = [
path.relative_to(self.documents_dir)
for path in self.documents_dir.glob("*")
if path.is_file()
]
self._ground_truth_paths = [
path.relative_to(self.ground_truths_dir)
for path in self.ground_truths_dir.glob("*")
if path.is_file()
]
@property
@abstractmethod
def default_tsv_name(self):
"""Default name for the per-document metrics TSV file."""
@property
@abstractmethod
def default_agg_tsv_name(self):
"""Default name for the aggregated metrics TSV file."""
@abstractmethod
def _generate_dataframes(self, rows: list) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Generates pandas DataFrames from the list of rows.
The first DF (index 0) is a dataframe containing metrics per file.
The second DF (index 1) is a dataframe containing the aggregated
metrics.
"""
def on_files(
self,
document_paths: Optional[list[str | Path]] = None,
ground_truth_paths: Optional[list[str | Path]] = None,
) -> BaseMetricsCalculator:
"""Overrides the default list of files to process."""
if document_paths:
self._document_paths = [Path(p) for p in document_paths]
if ground_truth_paths:
self._ground_truth_paths = [Path(p) for p in ground_truth_paths]
return self
def calculate(
self,
executor: Optional[concurrent.futures.Executor] = None,
export_dir: Optional[str | Path] = None,
visualize_progress: bool = True,
display_agg_df: bool = True,
) -> pd.DataFrame:
"""Calculates metrics for each document using the provided executor.
* Optionally, the results can be exported and displayed.
        * It loops through the structured output for every file in `documents_dir` (or the
          selected files from `document_paths`) and compares each one with the gold standard
          of the same file name under `ground_truths_dir` (or the selected files from
          `ground_truth_paths`).
Args:
executor: concurrent.futures.Executor instance
export_dir: directory to export the results
visualize_progress: whether to display progress bar
display_agg_df: whether to display the aggregated results
Returns:
Metrics for each document as a pandas DataFrame
"""
if executor is None:
executor = self._default_executor()
rows = self._process_all_documents(executor, visualize_progress)
df, agg_df = self._generate_dataframes(rows)
if export_dir is not None:
_write_to_file(export_dir, self.default_tsv_name, df)
_write_to_file(export_dir, self.default_agg_tsv_name, agg_df)
if display_agg_df is True:
_display(agg_df)
return df
@classmethod
def _default_executor(cls):
max_processors = int(os.environ.get("MAX_PROCESSES", os.cpu_count()))
logger.info(f"Configuring a pool of {max_processors} processors for parallel processing.")
return cls._get_executor_class()(max_workers=max_processors)
@classmethod
def _get_executor_class(
cls,
) -> type[concurrent.futures.ThreadPoolExecutor] | type[concurrent.futures.ProcessPoolExecutor]:
return concurrent.futures.ProcessPoolExecutor
def _process_all_documents(
self, executor: concurrent.futures.Executor, visualize_progress: bool
) -> list:
"""Triggers processing of all documents using the provided executor.
Failures are omitted from the returned result.
"""
with executor:
return [
row
for row in tqdm(
executor.map(self._try_process_document, self._document_paths),
total=len(self._document_paths),
leave=False,
disable=not visualize_progress,
)
if row is not None
]
def _try_process_document(self, doc: Path) -> Optional[list]:
"""Safe wrapper around the document processing method."""
logger.info(f"Processing {doc}")
try:
return self._process_document(doc)
except Exception as e:
logger.error(f"Failed to process document {doc}: {e}")
return None
@abstractmethod
def _process_document(self, doc: Path) -> Optional[list]:
"""Should return all metadata and metrics for a single document."""
@dataclass
class TableStructureMetricsCalculator(BaseMetricsCalculator):
"""Calculates the following metrics for tables:
- tables found accuracy
- table-level accuracy
- element in column index accuracy
- element in row index accuracy
- element's column content accuracy
- element's row content accuracy
It also calculates the aggregated accuracy.
"""
cutoff: Optional[float] = None
weighted_average: bool = True
include_false_positives: bool = True
def __post_init__(self):
super().__post_init__()
@property
def supported_metric_names(self):
return [
"total_tables",
"table_level_acc",
"table_detection_recall",
"table_detection_precision",
"table_detection_f1",
"composite_structure_acc",
"element_col_level_index_acc",
"element_row_level_index_acc",
"element_col_level_content_acc",
"element_row_level_content_acc",
]
@property
def default_tsv_name(self):
return "all-docs-table-structure-accuracy.tsv"
@property
def default_agg_tsv_name(self):
return "aggregate-table-structure-accuracy.tsv"
def _process_document(self, doc: Path) -> Optional[list]:
doc_path = Path(doc)
out_filename = doc_path.stem
doctype = Path(out_filename).suffix
src_gt_filename = out_filename + ".json"
connector = doc_path.parts[-2] if len(doc_path.parts) > 1 else None
        # -- skip documents whose ground truth file is not among the known ground truth paths --
        if src_gt_filename not in [path.name for path in self._ground_truth_paths]:  # type: ignore
            return None
prediction_file = self.documents_dir / doc
if not prediction_file.exists():
logger.warning(f"Prediction file {prediction_file} does not exist, skipping")
return None
ground_truth_file = self.ground_truths_dir / src_gt_filename
if not ground_truth_file.exists():
logger.warning(f"Ground truth file {ground_truth_file} does not exist, skipping")
return None
processor_from_text_as_html = TableEvalProcessor.from_json_files(
prediction_file=prediction_file,
ground_truth_file=ground_truth_file,
cutoff=self.cutoff,
source_type="html",
)
report_from_html = processor_from_text_as_html.process_file()
return [
out_filename,
doctype,
connector,
report_from_html.total_predicted_tables,
] + [getattr(report_from_html, metric) for metric in self.supported_metric_names]
def _generate_dataframes(self, rows):
headers = [
"filename",
"doctype",
"connector",
"total_predicted_tables",
] + self.supported_metric_names
df = pd.DataFrame(rows, columns=headers)
df["_table_weights"] = df["total_tables"]
if self.include_false_positives:
# we give false positive tables a 1 table worth of weight in computing table level acc
df["_table_weights"][df.total_tables.eq(0) & df.total_predicted_tables.gt(0)] = 1
# filter down to only those with actual and/or predicted tables
has_tables_df = df[df["_table_weights"] > 0]
if not self.weighted_average:
# for all non zero elements assign them value 1
df["_table_weights"] = df["_table_weights"].apply(
lambda table_weight: 1 if table_weight != 0 else 0
)
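        # Worked example (hypothetical numbers): a document with total_tables=3 keeps weight 3;
        # a document with total_tables=0 but total_predicted_tables=2 (false positives only) gets
        # weight 1 when include_false_positives is True; with weighted_average=False every
        # document that still has a non-zero weight is reduced to weight 1.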
if has_tables_df.empty:
            agg_df = pd.DataFrame(
                [[metric, None, None, None, 0] for metric in self.supported_metric_names],
                columns=AGG_HEADERS,
            )
else:
element_metrics_results = {}
for metric in self.supported_metric_names:
metric_df = has_tables_df[has_tables_df[metric].notnull()]
agg_metric = metric_df[metric].agg([_stdev, _pstdev, _count]).transpose()
if metric.startswith("total_tables"):
agg_metric["_mean"] = metric_df[metric].mean()
elif metric.startswith("table_level_acc"):
agg_metric["_mean"] = np.round(
np.average(metric_df[metric], weights=metric_df["_table_weights"]),
3,
)
else:
# false positive tables do not contribute to table structure and content
# extraction metrics
agg_metric["_mean"] = np.round(
np.average(metric_df[metric], weights=metric_df["total_tables"]),
3,
)
if agg_metric.empty:
element_metrics_results[metric] = pd.Series(
data=[None, None, None, 0], index=["_mean", "_stdev", "_pstdev", "_count"]
)
else:
element_metrics_results[metric] = agg_metric
agg_df = pd.DataFrame(element_metrics_results).transpose().reset_index()
agg_df = agg_df.rename(columns=AGG_HEADERS_MAPPING)
return df, agg_df
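# Illustrative usage sketch (not executed); the paths and the cutoff value are hypothetical.
# Ground truth tables are expected as `<output filename stem>.json` under `ground_truths_dir`,
# and the comparison is run on the HTML table representation (source_type="html"):
#
#   table_df = TableStructureMetricsCalculator(
#       documents_dir="structured-output/tables/",
#       ground_truths_dir="table-gt/",
#       cutoff=0.8,
#       weighted_average=True,
#   ).calculate(export_dir="metrics/", display_agg_df=False)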
@dataclass
class TextExtractionMetricsCalculator(BaseMetricsCalculator):
"""Calculates text accuracy and percent missing between document and ground truth texts.
It also calculates the aggregated accuracy and percent missing.
"""
group_by: Optional[str] = None
weights: tuple[int, int, int] = (1, 1, 1)
document_type: str = "json"
def __post_init__(self):
super().__post_init__()
self._validate_inputs()
@property
def default_tsv_name(self) -> str:
return "all-docs-cct.tsv"
@property
def default_agg_tsv_name(self) -> str:
return "aggregate-scores-cct.tsv"
def calculate(
self,
executor: Optional[concurrent.futures.Executor] = None,
export_dir: Optional[str | Path] = None,
visualize_progress: bool = True,
display_agg_df: bool = True,
) -> pd.DataFrame:
"""See the parent class for the method's docstring."""
df = super().calculate(
executor=executor,
export_dir=export_dir,
visualize_progress=visualize_progress,
display_agg_df=display_agg_df,
)
if export_dir is not None and self.group_by:
get_mean_grouping(self.group_by, df, export_dir, "text_extraction")
return df
def _validate_inputs(self):
if not self._document_paths:
logger.info("No output files to calculate to edit distances for, exiting")
sys.exit(0)
if self.document_type not in OUTPUT_TYPE_OPTIONS:
raise ValueError(
"Specified file type under `documents_dir` or `output_list` should be one of "
f"`json` or `txt`. The given file type is {self.document_type}, exiting."
)
        for path in self._document_paths:
            if not path.suffixes:
                logger.error(f"File {path} does not have a suffix, skipping")
                continue
            if path.suffixes[-1] != f".{self.document_type}":
                logger.warning(
                    "The directory contains a file type inconsistent with the given input. "
                    "Please note that some files will be skipped."
                )
def _process_document(self, doc: Path) -> Optional[list]:
filename = doc.stem
doctype = doc.suffixes[-2]
connector = doc.parts[0] if len(doc.parts) > 1 else None
output_cct, source_cct = self._get_ccts(doc)
# NOTE(amadeusz): Levenshtein distance calculation takes too long
# skip it if file sizes differ wildly
if 0.5 < len(output_cct.encode()) / len(source_cct.encode()) < 2.0:
accuracy = round(calculate_accuracy(output_cct, source_cct, self.weights), 3)
else:
            # 0.01 signals that the score was assigned manually (sizes differ too much), not computed
accuracy = 0.01
percent_missing = round(calculate_percent_missing_text(output_cct, source_cct), 3)
return [filename, doctype, connector, accuracy, percent_missing]
def _get_ccts(self, doc: Path) -> tuple[str, str]:
output_cct = _prepare_output_cct(
docpath=self.documents_dir / doc, output_type=self.document_type
)
source_cct = _read_text_file(self.ground_truths_dir / doc.with_suffix(".txt"))
return output_cct, source_cct
def _generate_dataframes(self, rows):
headers = ["filename", "doctype", "connector", "cct-accuracy", "cct-%missing"]
df = pd.DataFrame(rows, columns=headers)
acc = df[["cct-accuracy"]].agg([_mean, _stdev, _pstdev, _count]).transpose()
miss = df[["cct-%missing"]].agg([_mean, _stdev, _pstdev, _count]).transpose()
if acc.shape[1] == 0 and miss.shape[1] == 0:
agg_df = pd.DataFrame(columns=AGG_HEADERS)
else:
agg_df = pd.concat((acc, miss)).reset_index()
agg_df.columns = AGG_HEADERS
return df, agg_df
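# Illustrative usage sketch (not executed); directory names are hypothetical. Outputs may be
# `json` element dumps or plain `txt` files (see OUTPUT_TYPE_OPTIONS); the gold standard is the
# matching `.txt` file under `ground_truths_dir`, and `weights` is the three-element weight tuple
# forwarded to `calculate_accuracy`:
#
#   cct_df = TextExtractionMetricsCalculator(
#       documents_dir="structured-output/",
#       ground_truths_dir="gold-standard-cct/",
#       document_type="json",
#       group_by="doctype",
#   ).calculate(export_dir="metrics/")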
@dataclass
class ElementTypeMetricsCalculator(BaseMetricsCalculator):
"""
Calculates element type frequency accuracy, percent missing and
aggregated accuracy between document and ground truth.
"""
group_by: Optional[str] = None
def calculate(
self,
executor: Optional[concurrent.futures.Executor] = None,
export_dir: Optional[str | Path] = None,
visualize_progress: bool = True,
display_agg_df: bool = False,
) -> pd.DataFrame:
"""See the parent class for the method's docstring."""
df = super().calculate(
executor=executor,
export_dir=export_dir,
visualize_progress=visualize_progress,
display_agg_df=display_agg_df,
)
if export_dir is not None and self.group_by:
get_mean_grouping(self.group_by, df, export_dir, "element_type")
return df
@property
def default_tsv_name(self) -> str:
return "all-docs-element-type-frequency.tsv"
@property
def default_agg_tsv_name(self) -> str:
return "aggregate-scores-element-type.tsv"
def _process_document(self, doc: Path) -> Optional[list]:
filename = doc.stem
doctype = doc.suffixes[-2]
connector = doc.parts[0] if len(doc.parts) > 1 else None
output = get_element_type_frequency(_read_text_file(self.documents_dir / doc))
source = get_element_type_frequency(
_read_text_file(self.ground_truths_dir / doc.with_suffix(".json"))
)
accuracy = round(calculate_element_type_percent_match(output, source), 3)
return [filename, doctype, connector, accuracy]
def _generate_dataframes(self, rows):
headers = ["filename", "doctype", "connector", "element-type-accuracy"]
df = pd.DataFrame(rows, columns=headers)
        if df.empty:
            agg_df = pd.DataFrame(
                [["element-type-accuracy", None, None, None, 0]], columns=AGG_HEADERS
            )
        else:
            agg_df = df.agg(
                {"element-type-accuracy": [_mean, _stdev, _pstdev, _count]}
            ).transpose()
            agg_df = agg_df.reset_index()
            agg_df.columns = AGG_HEADERS
return df, agg_df
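# Illustrative usage sketch (not executed); directory names are hypothetical. Both the outputs
# and the ground truth are JSON element dumps; the score is the percent match between their
# element-type frequency tables. With `group_by` set and an `export_dir` given, per-group means
# are also exported via `get_mean_grouping` (defined below):
#
#   et_df = ElementTypeMetricsCalculator(
#       documents_dir="structured-output/",
#       ground_truths_dir="gold-standard-element-type/",
#       group_by="connector",
#   ).calculate(export_dir="metrics/")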
def get_mean_grouping(
group_by: str,
data_input: Union[pd.DataFrame, str],
export_dir: str,
eval_name: str,
agg_name: Optional[str] = None,
export_filename: Optional[str] = None,
) -> None:
"""Aggregates accuracy and missing metrics by column name 'doctype' or 'connector',
or 'all' for all rows. Export to TSV.
    If `all`, passing `export_filename` is recommended.
Args:
group_by (str): Grouping category ('doctype' or 'connector' or 'all').
data_input (Union[pd.DataFrame, str]): DataFrame or path to a CSV/TSV file.
export_dir (str): Directory for the exported TSV file.
eval_name (str): Evaluated metric ('text_extraction' or 'element_type').
        agg_name (str, optional): String to use in the export filename. Defaults to `cct` for
            eval_name `text_extraction` and `element-type` for `element_type`.
        export_filename (str, optional): Export filename.
"""
    if group_by not in ("doctype", "connector", "all"):
        raise ValueError("Invalid grouping category. Expected `doctype`, `connector`, or `all`.")
if eval_name == "text_extraction":
agg_fields = ["cct-accuracy", "cct-%missing"]
agg_name = "cct"
elif eval_name == "element_type":
agg_fields = ["element-type-accuracy"]
agg_name = "element-type"
elif eval_name == "object_detection":
agg_fields = ["f1_score", "m_ap"]
agg_name = "object-detection"
else:
raise ValueError(
f"Unknown metric for eval {eval_name}. "
f"Expected `text_extraction` or `element_type` or `table_extraction`."
)
if isinstance(data_input, str):
if not os.path.exists(data_input):
raise FileNotFoundError(f"File {data_input} not found.")
if data_input.endswith(".csv"):
df = pd.read_csv(data_input, header=None)
elif data_input.endswith(".tsv"):
df = pd.read_csv(data_input, sep="\t")
elif data_input.endswith(".txt"):
df = pd.read_csv(data_input, sep="\t", header=None)
else:
raise ValueError("Please provide a .csv or .tsv file.")
else:
df = data_input
if df.empty:
raise SystemExit("Data is empty. Exiting.")
elif group_by != "all" and (group_by not in df.columns or df[group_by].isnull().all()):
raise SystemExit(
f"Data cannot be aggregated by `{group_by}`."
f" Check if it's empty or the column is missing/empty."
)
grouped_df = []
if group_by and group_by != "all":
for field in agg_fields:
grouped_df.append(
_rename_aggregated_columns(
df.groupby(group_by).agg({field: [_mean, _stdev, _pstdev, _count]})
)
)
if group_by == "all":
df["grouping_key"] = 0
for field in agg_fields:
grouped_df.append(
_rename_aggregated_columns(
df.groupby("grouping_key").agg({field: [_mean, _stdev, _pstdev, _count]})
)
)
grouped_df = _format_grouping_output(*grouped_df)
if "grouping_key" in grouped_df.columns.get_level_values(0):
grouped_df = grouped_df.drop("grouping_key", axis=1, level=0)
if export_filename:
if not export_filename.endswith(".tsv"):
export_filename = export_filename + ".tsv"
_write_to_file(export_dir, export_filename, grouped_df)
else:
_write_to_file(export_dir, f"all-{group_by}-agg-{agg_name}.tsv", grouped_df)
def filter_metrics(
data_input: Union[str, pd.DataFrame],
filter_list: Union[str, List[str]],
filter_by: str = "filename",
export_filename: Optional[str] = None,
export_dir: str = "metrics",
return_type: str = "file",
) -> Optional[pd.DataFrame]:
"""Reads the data_input file and filter only selected row available in filter_list.
Args:
        data_input (str, dataframe): the source data, a path to a file or a dataframe
        filter_list (str, list): the filter, a path to a file or a list of strings
        filter_by (str): the data_input column to match the filter_list against
        export_filename (str, optional): export filename. Required when return_type is "file"
        export_dir (str, optional): export directory. Defaults to <current directory>/metrics
        return_type (str): "file" or "dataframe"
"""
if isinstance(data_input, str):
if not os.path.exists(data_input):
raise FileNotFoundError(f"File {data_input} not found.")
if data_input.endswith(".csv"):
df = pd.read_csv(data_input, header=None)
elif data_input.endswith(".tsv"):
df = pd.read_csv(data_input, sep="\t")
elif data_input.endswith(".txt"):
df = pd.read_csv(data_input, sep="\t", header=None)
else:
raise ValueError("Please provide a .csv or .tsv file.")
else:
df = data_input
if isinstance(filter_list, str):
if not os.path.exists(filter_list):
raise FileNotFoundError(f"File {filter_list} not found.")
if filter_list.endswith(".csv"):
filter_df = pd.read_csv(filter_list, header=None)
elif filter_list.endswith(".tsv"):
filter_df = pd.read_csv(filter_list, sep="\t")
elif filter_list.endswith(".txt"):
filter_df = pd.read_csv(filter_list, sep="\t", header=None)
else:
raise ValueError("Please provide a .csv or .tsv file.")
filter_list = filter_df.iloc[:, 0].astype(str).values.tolist()
elif not isinstance(filter_list, list):
raise ValueError("Please provide a List of strings or path to file.")
if filter_by not in df.columns:
raise ValueError("`filter_by` key does not exists in the data provided.")
res = df[df[filter_by].isin(filter_list)]
if res.empty:
raise SystemExit("No common file names between data_input and filter_list. Exiting.")
if return_type == "dataframe":
return res
elif return_type == "file" and export_filename:
_write_to_file(export_dir, export_filename, res)
elif return_type == "file" and not export_filename:
raise ValueError("Please provide `export_filename`.")
else:
raise ValueError("Return type must be either `dataframe` or `file`.")
@dataclass
class ObjectDetectionMetricsCalculatorBase(BaseMetricsCalculator, ABC):
"""
Calculates object detection metrics for each document:
- f1 score
- precision
- recall
    - mean average precision (mAP)
It also calculates aggregated metrics.
"""
def __post_init__(self):
super().__post_init__()
self._document_paths = [
path.relative_to(self.documents_dir)
for path in self.documents_dir.rglob("analysis/*/layout_dump/object_detection.json")
if path.is_file()
]
@property
def supported_metric_names(self):
return ["f1_score", "precision", "recall", "m_ap"]
@property
def default_tsv_name(self):
return "all-docs-object-detection-metrics.tsv"
@property
def default_agg_tsv_name(self):
return "aggregate-object-detection-metrics.tsv"
def _find_file_in_ground_truth(self, file_stem: str) -> Optional[Path]:
"""Find the file corresponding to OD model dump file among the set of ground truth files
The files in ground truth paths keep the original extension and have .json suffix added,
e.g.:
some_document.pdf.json
poster.jpg.json
To compare to `file_stem` we need to take the prefix part of the file, thus double-stem
is applied.
"""
for path in self._ground_truth_paths:
if Path(path.stem).stem == file_stem:
return path
return None
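    # Double-stem example for `_find_file_in_ground_truth` (hypothetical name): for a ground
    # truth path "some_document.pdf.json", `path.stem` is "some_document.pdf" and
    # `Path(path.stem).stem` is "some_document", which is what gets compared to `file_stem`.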
    def _get_paths(self, doc: Path) -> tuple[str, Path, Path]:
"""Resolves ground doctype, prediction file path and ground truth path.
As OD dump directory structure differes from other simple outputs, it needs
a specific processing to match the output OD dump file with corresponding
OD GT file.
The outputs are placed in a dicrectory structure:
analysis
|- document_name
|- layout_dump
|- object_detection.json
|- bboxes # not used in this evaluation
and the GT file is pleced in od_gt directory for given dataset
dataset_name
|- od_gt
|- document_name.pdf.json
Args:
doc (Path): path to the OD dump file
Returns:
tuple: doctype, prediction file path, ground truth path
"""
od_dump_path = Path(doc)
file_stem = od_dump_path.parts[-3] # we take the `document_name` - so the filename stem
src_gt_filename = self._find_file_in_ground_truth(file_stem)
        if src_gt_filename is None:
            raise ValueError(f"Ground truth file for {file_stem} not found in list of GT files")
doctype = Path(src_gt_filename.stem).suffix[1:]
prediction_file = self.documents_dir / doc
if not prediction_file.exists():
logger.warning(f"Prediction file {prediction_file} does not exist, skipping")
raise ValueError(f"Prediction file {prediction_file} does not exist")
ground_truth_file = self.ground_truths_dir / src_gt_filename
if not ground_truth_file.exists():
logger.warning(f"Ground truth file {ground_truth_file} does not exist, skipping")
raise ValueError(f"Ground truth file {ground_truth_file} does not exist")
return doctype, prediction_file, ground_truth_file
def _generate_dataframes(self, rows) -> tuple[pd.DataFrame, pd.DataFrame]:
headers = ["filename", "doctype", "connector"] + self.supported_metric_names
df = pd.DataFrame(rows, columns=headers)
if df.empty:
agg_df = pd.DataFrame(columns=AGG_HEADERS)
else:
element_metrics_results = {}
for metric in self.supported_metric_names:
metric_df = df[df[metric].notnull()]
agg_metric = metric_df[metric].agg([_mean, _stdev, _pstdev, _count]).transpose()
if agg_metric.empty:
element_metrics_results[metric] = pd.Series(
data=[None, None, None, 0], index=["_mean", "_stdev", "_pstdev", "_count"]
)
else:
element_metrics_results[metric] = agg_metric
agg_df = pd.DataFrame(element_metrics_results).transpose().reset_index()
agg_df.columns = AGG_HEADERS
return df, agg_df
class ObjectDetectionPerClassMetricsCalculator(ObjectDetectionMetricsCalculatorBase):
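    """Calculates per-class object detection metrics for each document:
    - f1 score
    - precision
    - recall
    - mean average precision (m_ap)
    The set of classes (and thus the metric columns) is derived from the
    `object_detection_classes` listed in the ground truth files.
    """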
def __post_init__(self):
super().__post_init__()
self.per_class_metric_names: list[str] | None = None
self._set_supported_metrics()
@property
def supported_metric_names(self):
if self.per_class_metric_names:
return self.per_class_metric_names
else:
raise ValueError("per_class_metrics not initialized - cannot get class names")
@property
def default_tsv_name(self):
return "all-docs-object-detection-metrics-per-class.tsv"
@property
def default_agg_tsv_name(self):
return "aggregate-object-detection-metrics-per-class.tsv"
def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate both class-aggregated and per-class metrics for a single document.
Args:
doc (Path): path to the OD dump file
Returns:
tuple: a tuple of aggregated and per-class metrics for a single document
"""
try:
doctype, prediction_file, ground_truth_file = self._get_paths(doc)
except ValueError as e:
logger.error(f"Failed to process document {doc}: {e}")
return None
processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
_, per_class_metrics = processor.get_metrics()
per_class_metrics_row = [
ground_truth_file.stem,
doctype,
None, # connector
]
for combined_metric_name in self.supported_metric_names:
metric = "_".join(combined_metric_name.split("_")[:-1])
class_name = combined_metric_name.split("_")[-1]
class_metrics = getattr(per_class_metrics, metric)
per_class_metrics_row.append(class_metrics[class_name])
return per_class_metrics_row
def _set_supported_metrics(self):
"""Sets the supported metrics based on the classes found in the ground truth files.
        The difference between the per-class and the aggregated calculator is that the list of
        classes (and thus the metrics) is based on the contents of the GT / prediction files.
"""
metrics = ["f1_score", "precision", "recall", "m_ap"]
classes = set()
for gt_file in self._ground_truth_paths:
gt_file_path = self.ground_truths_dir / gt_file
with open(gt_file_path) as f:
gt = json.load(f)
gt_classes = gt["object_detection_classes"]
classes.update(gt_classes)
per_class_metric_names = []
for metric in metrics:
for class_name in classes:
per_class_metric_names.append(f"{metric}_{class_name}")
self.per_class_metric_names = sorted(per_class_metric_names)
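    # Example for `_set_supported_metrics` (hypothetical classes): if the ground truth files list
    # object_detection_classes = ["Table", "Picture"], the supported metric names become
    # ["f1_score_Picture", "f1_score_Table", "m_ap_Picture", "m_ap_Table",
    #  "precision_Picture", "precision_Table", "recall_Picture", "recall_Table"] (sorted).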
class ObjectDetectionAggregatedMetricsCalculator(ObjectDetectionMetricsCalculatorBase):
"""Calculates object detection metrics for each document and aggregates by all classes"""
@property
def supported_metric_names(self):
return ["f1_score", "precision", "recall", "m_ap"]
@property
def default_tsv_name(self):
return "all-docs-object-detection-metrics.tsv"
@property
def default_agg_tsv_name(self):
return "aggregate-object-detection-metrics.tsv"
def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate both class-aggregated and per-class metrics for a single document.
Args:
doc (Path): path to the OD dump file
Returns:
list: a list of aggregated metrics for a single document
"""
try:
doctype, prediction_file, ground_truth_file = self._get_paths(doc)
except ValueError as e:
logger.error(f"Failed to process document {doc}: {e}")
return None
processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
metrics, _ = processor.get_metrics()
return [
ground_truth_file.stem,
doctype,
None, # connector
] + [getattr(metrics, metric) for metric in self.supported_metric_names]
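# Illustrative usage sketch (not executed); directory names are hypothetical. The documents dir
# is scanned recursively for `analysis/*/layout_dump/object_detection.json` dumps, and ground
# truth files named `<document_name>.<original extension>.json` are looked up under
# `ground_truths_dir`:
#
#   od_df = ObjectDetectionAggregatedMetricsCalculator(
#       documents_dir="od-analysis-output/",
#       ground_truths_dir="dataset/od_gt/",
#   ).calculate(export_dir="metrics/", visualize_progress=False)
#
#   per_class_df = ObjectDetectionPerClassMetricsCalculator(
#       documents_dir="od-analysis-output/",
#       ground_truths_dir="dataset/od_gt/",
#   ).calculate(export_dir="metrics/", visualize_progress=False)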