#! /usr/bin/env python3 from __future__ import annotations import concurrent.futures import json import logging import os import sys from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Union import numpy as np import pandas as pd from tqdm import tqdm from unstructured.metrics.element_type import ( calculate_element_type_percent_match, get_element_type_frequency, ) from unstructured.metrics.object_detection import ( ObjectDetectionEvalProcessor, ) from unstructured.metrics.table.table_eval import TableEvalProcessor from unstructured.metrics.text_extraction import calculate_accuracy, calculate_percent_missing_text from unstructured.metrics.utils import ( _count, _display, _format_grouping_output, _mean, _prepare_output_cct, _pstdev, _read_text_file, _rename_aggregated_columns, _stdev, _write_to_file, ) logger = logging.getLogger("unstructured.eval") handler = logging.StreamHandler() handler.name = "eval_log_handler" formatter = logging.Formatter("%(asctime)s %(processName)-10s %(levelname)-8s %(message)s") handler.setFormatter(formatter) # Only want to add the handler once if "eval_log_handler" not in [h.name for h in logger.handlers]: logger.addHandler(handler) logger.setLevel(logging.DEBUG) AGG_HEADERS = ["metric", "average", "sample_sd", "population_sd", "count"] AGG_HEADERS_MAPPING = { "index": "metric", "_mean": "average", "_stdev": "sample_sd", "_pstdev": "population_sd", "_count": "count", } OUTPUT_TYPE_OPTIONS = ["json", "txt"] @dataclass class BaseMetricsCalculator(ABC): """Foundation class for specialized metrics calculators. It provides a common interface for calculating metrics based on outputs and ground truths. Those can be provided as either directories or lists of files. """ documents_dir: str | Path ground_truths_dir: str | Path def __post_init__(self): """Discover all files in the provided directories.""" self.documents_dir = Path(self.documents_dir).resolve() self.ground_truths_dir = Path(self.ground_truths_dir).resolve() # -- auto-discover all files in the directories -- self._document_paths = [ path.relative_to(self.documents_dir) for path in self.documents_dir.glob("*") if path.is_file() ] self._ground_truth_paths = [ path.relative_to(self.ground_truths_dir) for path in self.ground_truths_dir.glob("*") if path.is_file() ] @property @abstractmethod def default_tsv_name(self): """Default name for the per-document metrics TSV file.""" @property @abstractmethod def default_agg_tsv_name(self): """Default name for the aggregated metrics TSV file.""" @abstractmethod def _generate_dataframes(self, rows: list) -> tuple[pd.DataFrame, pd.DataFrame]: """Generates pandas DataFrames from the list of rows. The first DF (index 0) is a dataframe containing metrics per file. The second DF (index 1) is a dataframe containing the aggregated metrics. """ def on_files( self, document_paths: Optional[list[str | Path]] = None, ground_truth_paths: Optional[list[str | Path]] = None, ) -> BaseMetricsCalculator: """Overrides the default list of files to process.""" if document_paths: self._document_paths = [Path(p) for p in document_paths] if ground_truth_paths: self._ground_truth_paths = [Path(p) for p in ground_truth_paths] return self def calculate( self, executor: Optional[concurrent.futures.Executor] = None, export_dir: Optional[str | Path] = None, visualize_progress: bool = True, display_agg_df: bool = True, ) -> pd.DataFrame: """Calculates metrics for each document using the provided executor. * Optionally, the results can be exported and displayed. * It loops through the list of structured output from all of `documents_dir` or selected files from `document_paths`, and compares them with gold-standard of the same file name under `ground_truths_dir` or selected files from `ground_truth_paths`. Args: executor: concurrent.futures.Executor instance export_dir: directory to export the results visualize_progress: whether to display progress bar display_agg_df: whether to display the aggregated results Returns: Metrics for each document as a pandas DataFrame """ if executor is None: executor = self._default_executor() rows = self._process_all_documents(executor, visualize_progress) df, agg_df = self._generate_dataframes(rows) if export_dir is not None: _write_to_file(export_dir, self.default_tsv_name, df) _write_to_file(export_dir, self.default_agg_tsv_name, agg_df) if display_agg_df is True: _display(agg_df) return df @classmethod def _default_executor(cls): max_processors = int(os.environ.get("MAX_PROCESSES", os.cpu_count())) logger.info(f"Configuring a pool of {max_processors} processors for parallel processing.") return cls._get_executor_class()(max_workers=max_processors) @classmethod def _get_executor_class( cls, ) -> type[concurrent.futures.ThreadPoolExecutor] | type[concurrent.futures.ProcessPoolExecutor]: return concurrent.futures.ProcessPoolExecutor def _process_all_documents( self, executor: concurrent.futures.Executor, visualize_progress: bool ) -> list: """Triggers processing of all documents using the provided executor. Failures are omitted from the returned result. """ with executor: return [ row for row in tqdm( executor.map(self._try_process_document, self._document_paths), total=len(self._document_paths), leave=False, disable=not visualize_progress, ) if row is not None ] def _try_process_document(self, doc: Path) -> Optional[list]: """Safe wrapper around the document processing method.""" logger.info(f"Processing {doc}") try: return self._process_document(doc) except Exception as e: logger.error(f"Failed to process document {doc}: {e}") return None @abstractmethod def _process_document(self, doc: Path) -> Optional[list]: """Should return all metadata and metrics for a single document.""" @dataclass class TableStructureMetricsCalculator(BaseMetricsCalculator): """Calculates the following metrics for tables: - tables found accuracy - table-level accuracy - element in column index accuracy - element in row index accuracy - element's column content accuracy - element's row content accuracy It also calculates the aggregated accuracy. """ cutoff: Optional[float] = None weighted_average: bool = True include_false_positives: bool = True def __post_init__(self): super().__post_init__() @property def supported_metric_names(self): return [ "total_tables", "table_level_acc", "table_detection_recall", "table_detection_precision", "table_detection_f1", "composite_structure_acc", "element_col_level_index_acc", "element_row_level_index_acc", "element_col_level_content_acc", "element_row_level_content_acc", ] @property def default_tsv_name(self): return "all-docs-table-structure-accuracy.tsv" @property def default_agg_tsv_name(self): return "aggregate-table-structure-accuracy.tsv" def _process_document(self, doc: Path) -> Optional[list]: doc_path = Path(doc) out_filename = doc_path.stem doctype = Path(out_filename).suffix src_gt_filename = out_filename + ".json" connector = doc_path.parts[-2] if len(doc_path.parts) > 1 else None if src_gt_filename in self._ground_truth_paths: # type: ignore return None prediction_file = self.documents_dir / doc if not prediction_file.exists(): logger.warning(f"Prediction file {prediction_file} does not exist, skipping") return None ground_truth_file = self.ground_truths_dir / src_gt_filename if not ground_truth_file.exists(): logger.warning(f"Ground truth file {ground_truth_file} does not exist, skipping") return None processor_from_text_as_html = TableEvalProcessor.from_json_files( prediction_file=prediction_file, ground_truth_file=ground_truth_file, cutoff=self.cutoff, source_type="html", ) report_from_html = processor_from_text_as_html.process_file() return [ out_filename, doctype, connector, report_from_html.total_predicted_tables, ] + [getattr(report_from_html, metric) for metric in self.supported_metric_names] def _generate_dataframes(self, rows): headers = [ "filename", "doctype", "connector", "total_predicted_tables", ] + self.supported_metric_names df = pd.DataFrame(rows, columns=headers) df["_table_weights"] = df["total_tables"] if self.include_false_positives: # we give false positive tables a 1 table worth of weight in computing table level acc df["_table_weights"][df.total_tables.eq(0) & df.total_predicted_tables.gt(0)] = 1 # filter down to only those with actual and/or predicted tables has_tables_df = df[df["_table_weights"] > 0] if not self.weighted_average: # for all non zero elements assign them value 1 df["_table_weights"] = df["_table_weights"].apply( lambda table_weight: 1 if table_weight != 0 else 0 ) if has_tables_df.empty: agg_df = pd.DataFrame( [[metric, None, None, None, 0] for metric in self.supported_metric_names] ).reset_index() else: element_metrics_results = {} for metric in self.supported_metric_names: metric_df = has_tables_df[has_tables_df[metric].notnull()] agg_metric = metric_df[metric].agg([_stdev, _pstdev, _count]).transpose() if metric.startswith("total_tables"): agg_metric["_mean"] = metric_df[metric].mean() elif metric.startswith("table_level_acc"): agg_metric["_mean"] = np.round( np.average(metric_df[metric], weights=metric_df["_table_weights"]), 3, ) else: # false positive tables do not contribute to table structure and content # extraction metrics agg_metric["_mean"] = np.round( np.average(metric_df[metric], weights=metric_df["total_tables"]), 3, ) if agg_metric.empty: element_metrics_results[metric] = pd.Series( data=[None, None, None, 0], index=["_mean", "_stdev", "_pstdev", "_count"] ) else: element_metrics_results[metric] = agg_metric agg_df = pd.DataFrame(element_metrics_results).transpose().reset_index() agg_df = agg_df.rename(columns=AGG_HEADERS_MAPPING) return df, agg_df @dataclass class TextExtractionMetricsCalculator(BaseMetricsCalculator): """Calculates text accuracy and percent missing between document and ground truth texts. It also calculates the aggregated accuracy and percent missing. """ group_by: Optional[str] = None weights: tuple[int, int, int] = (1, 1, 1) document_type: str = "json" def __post_init__(self): super().__post_init__() self._validate_inputs() @property def default_tsv_name(self) -> str: return "all-docs-cct.tsv" @property def default_agg_tsv_name(self) -> str: return "aggregate-scores-cct.tsv" def calculate( self, executor: Optional[concurrent.futures.Executor] = None, export_dir: Optional[str | Path] = None, visualize_progress: bool = True, display_agg_df: bool = True, ) -> pd.DataFrame: """See the parent class for the method's docstring.""" df = super().calculate( executor=executor, export_dir=export_dir, visualize_progress=visualize_progress, display_agg_df=display_agg_df, ) if export_dir is not None and self.group_by: get_mean_grouping(self.group_by, df, export_dir, "text_extraction") return df def _validate_inputs(self): if not self._document_paths: logger.info("No output files to calculate to edit distances for, exiting") sys.exit(0) if self.document_type not in OUTPUT_TYPE_OPTIONS: raise ValueError( "Specified file type under `documents_dir` or `output_list` should be one of " f"`json` or `txt`. The given file type is {self.document_type}, exiting." ) for path in self._document_paths: try: path.suffixes[-1] except IndexError: logger.error(f"File {path} does not have a suffix, skipping") continue if path.suffixes[-1] != f".{self.document_type}": logger.warning( "The directory contains file type inconsistent with the given input. " "Please note that some files will be skipped." ) if not all(path.suffixes[-1] == f".{self.document_type}" for path in self._document_paths): logger.warning( "The directory contains file type inconsistent with the given input. " "Please note that some files will be skipped." ) def _process_document(self, doc: Path) -> Optional[list]: filename = doc.stem doctype = doc.suffixes[-2] connector = doc.parts[0] if len(doc.parts) > 1 else None output_cct, source_cct = self._get_ccts(doc) # NOTE(amadeusz): Levenshtein distance calculation takes too long # skip it if file sizes differ wildly if 0.5 < len(output_cct.encode()) / len(source_cct.encode()) < 2.0: accuracy = round(calculate_accuracy(output_cct, source_cct, self.weights), 3) else: # 0.01 to distinguish it was set manually accuracy = 0.01 percent_missing = round(calculate_percent_missing_text(output_cct, source_cct), 3) return [filename, doctype, connector, accuracy, percent_missing] def _get_ccts(self, doc: Path) -> tuple[str, str]: output_cct = _prepare_output_cct( docpath=self.documents_dir / doc, output_type=self.document_type ) source_cct = _read_text_file(self.ground_truths_dir / doc.with_suffix(".txt")) return output_cct, source_cct def _generate_dataframes(self, rows): headers = ["filename", "doctype", "connector", "cct-accuracy", "cct-%missing"] df = pd.DataFrame(rows, columns=headers) acc = df[["cct-accuracy"]].agg([_mean, _stdev, _pstdev, _count]).transpose() miss = df[["cct-%missing"]].agg([_mean, _stdev, _pstdev, _count]).transpose() if acc.shape[1] == 0 and miss.shape[1] == 0: agg_df = pd.DataFrame(columns=AGG_HEADERS) else: agg_df = pd.concat((acc, miss)).reset_index() agg_df.columns = AGG_HEADERS return df, agg_df @dataclass class ElementTypeMetricsCalculator(BaseMetricsCalculator): """ Calculates element type frequency accuracy, percent missing and aggregated accuracy between document and ground truth. """ group_by: Optional[str] = None def calculate( self, executor: Optional[concurrent.futures.Executor] = None, export_dir: Optional[str | Path] = None, visualize_progress: bool = True, display_agg_df: bool = False, ) -> pd.DataFrame: """See the parent class for the method's docstring.""" df = super().calculate( executor=executor, export_dir=export_dir, visualize_progress=visualize_progress, display_agg_df=display_agg_df, ) if export_dir is not None and self.group_by: get_mean_grouping(self.group_by, df, export_dir, "element_type") return df @property def default_tsv_name(self) -> str: return "all-docs-element-type-frequency.tsv" @property def default_agg_tsv_name(self) -> str: return "aggregate-scores-element-type.tsv" def _process_document(self, doc: Path) -> Optional[list]: filename = doc.stem doctype = doc.suffixes[-2] connector = doc.parts[0] if len(doc.parts) > 1 else None output = get_element_type_frequency(_read_text_file(self.documents_dir / doc)) source = get_element_type_frequency( _read_text_file(self.ground_truths_dir / doc.with_suffix(".json")) ) accuracy = round(calculate_element_type_percent_match(output, source), 3) return [filename, doctype, connector, accuracy] def _generate_dataframes(self, rows): headers = ["filename", "doctype", "connector", "element-type-accuracy"] df = pd.DataFrame(rows, columns=headers) if df.empty: agg_df = pd.DataFrame(["element-type-accuracy", None, None, None, 0]).transpose() else: agg_df = df.agg({"element-type-accuracy": [_mean, _stdev, _pstdev, _count]}).transpose() agg_df = agg_df.reset_index() agg_df.columns = AGG_HEADERS return df, agg_df def get_mean_grouping( group_by: str, data_input: Union[pd.DataFrame, str], export_dir: str, eval_name: str, agg_name: Optional[str] = None, export_filename: Optional[str] = None, ) -> None: """Aggregates accuracy and missing metrics by column name 'doctype' or 'connector', or 'all' for all rows. Export to TSV. If `all`, passing export_name is recommended. Args: group_by (str): Grouping category ('doctype' or 'connector' or 'all'). data_input (Union[pd.DataFrame, str]): DataFrame or path to a CSV/TSV file. export_dir (str): Directory for the exported TSV file. eval_name (str): Evaluated metric ('text_extraction' or 'element_type'). agg_name (str, optional): String to use with export filename. Default is `cct` for group_by `text_extraction` and `element-type` for `element_type` export_name (str, optional): Export filename. """ if group_by not in ("doctype", "connector") and group_by != "all": raise ValueError("Invalid grouping category. Returning a non-group evaluation.") if eval_name == "text_extraction": agg_fields = ["cct-accuracy", "cct-%missing"] agg_name = "cct" elif eval_name == "element_type": agg_fields = ["element-type-accuracy"] agg_name = "element-type" elif eval_name == "object_detection": agg_fields = ["f1_score", "m_ap"] agg_name = "object-detection" else: raise ValueError( f"Unknown metric for eval {eval_name}. " f"Expected `text_extraction` or `element_type` or `table_extraction`." ) if isinstance(data_input, str): if not os.path.exists(data_input): raise FileNotFoundError(f"File {data_input} not found.") if data_input.endswith(".csv"): df = pd.read_csv(data_input, header=None) elif data_input.endswith(".tsv"): df = pd.read_csv(data_input, sep="\t") elif data_input.endswith(".txt"): df = pd.read_csv(data_input, sep="\t", header=None) else: raise ValueError("Please provide a .csv or .tsv file.") else: df = data_input if df.empty: raise SystemExit("Data is empty. Exiting.") elif group_by != "all" and (group_by not in df.columns or df[group_by].isnull().all()): raise SystemExit( f"Data cannot be aggregated by `{group_by}`." f" Check if it's empty or the column is missing/empty." ) grouped_df = [] if group_by and group_by != "all": for field in agg_fields: grouped_df.append( _rename_aggregated_columns( df.groupby(group_by).agg({field: [_mean, _stdev, _pstdev, _count]}) ) ) if group_by == "all": df["grouping_key"] = 0 for field in agg_fields: grouped_df.append( _rename_aggregated_columns( df.groupby("grouping_key").agg({field: [_mean, _stdev, _pstdev, _count]}) ) ) grouped_df = _format_grouping_output(*grouped_df) if "grouping_key" in grouped_df.columns.get_level_values(0): grouped_df = grouped_df.drop("grouping_key", axis=1, level=0) if export_filename: if not export_filename.endswith(".tsv"): export_filename = export_filename + ".tsv" _write_to_file(export_dir, export_filename, grouped_df) else: _write_to_file(export_dir, f"all-{group_by}-agg-{agg_name}.tsv", grouped_df) def filter_metrics( data_input: Union[str, pd.DataFrame], filter_list: Union[str, List[str]], filter_by: str = "filename", export_filename: Optional[str] = None, export_dir: str = "metrics", return_type: str = "file", ) -> Optional[pd.DataFrame]: """Reads the data_input file and filter only selected row available in filter_list. Args: data_input (str, dataframe): the source data, path to file or dataframe filter_list (str, list): the filter, path to file or list of string filter_by (str): data_input's column to filter the filter_list to export_filename (str, optional): export filename. required when return_type is "file" export_dir (str, optional): export directory. default to <current directory>/metrics return_type (str): "file" or "dataframe" """ if isinstance(data_input, str): if not os.path.exists(data_input): raise FileNotFoundError(f"File {data_input} not found.") if data_input.endswith(".csv"): df = pd.read_csv(data_input, header=None) elif data_input.endswith(".tsv"): df = pd.read_csv(data_input, sep="\t") elif data_input.endswith(".txt"): df = pd.read_csv(data_input, sep="\t", header=None) else: raise ValueError("Please provide a .csv or .tsv file.") else: df = data_input if isinstance(filter_list, str): if not os.path.exists(filter_list): raise FileNotFoundError(f"File {filter_list} not found.") if filter_list.endswith(".csv"): filter_df = pd.read_csv(filter_list, header=None) elif filter_list.endswith(".tsv"): filter_df = pd.read_csv(filter_list, sep="\t") elif filter_list.endswith(".txt"): filter_df = pd.read_csv(filter_list, sep="\t", header=None) else: raise ValueError("Please provide a .csv or .tsv file.") filter_list = filter_df.iloc[:, 0].astype(str).values.tolist() elif not isinstance(filter_list, list): raise ValueError("Please provide a List of strings or path to file.") if filter_by not in df.columns: raise ValueError("`filter_by` key does not exists in the data provided.") res = df[df[filter_by].isin(filter_list)] if res.empty: raise SystemExit("No common file names between data_input and filter_list. Exiting.") if return_type == "dataframe": return res elif return_type == "file" and export_filename: _write_to_file(export_dir, export_filename, res) elif return_type == "file" and not export_filename: raise ValueError("Please provide `export_filename`.") else: raise ValueError("Return type must be either `dataframe` or `file`.") @dataclass class ObjectDetectionMetricsCalculatorBase(BaseMetricsCalculator, ABC): """ Calculates object detection metrics for each document: - f1 score - precision - recall - average precision (mAP) It also calculates aggregated metrics. """ def __post_init__(self): super().__post_init__() self._document_paths = [ path.relative_to(self.documents_dir) for path in self.documents_dir.rglob("analysis/*/layout_dump/object_detection.json") if path.is_file() ] @property def supported_metric_names(self): return ["f1_score", "precision", "recall", "m_ap"] @property def default_tsv_name(self): return "all-docs-object-detection-metrics.tsv" @property def default_agg_tsv_name(self): return "aggregate-object-detection-metrics.tsv" def _find_file_in_ground_truth(self, file_stem: str) -> Optional[Path]: """Find the file corresponding to OD model dump file among the set of ground truth files The files in ground truth paths keep the original extension and have .json suffix added, e.g.: some_document.pdf.json poster.jpg.json To compare to `file_stem` we need to take the prefix part of the file, thus double-stem is applied. """ for path in self._ground_truth_paths: if Path(path.stem).stem == file_stem: return path return None def _get_paths(self, doc: Path) -> tuple(str, Path, Path): """Resolves ground doctype, prediction file path and ground truth path. As OD dump directory structure differes from other simple outputs, it needs a specific processing to match the output OD dump file with corresponding OD GT file. The outputs are placed in a dicrectory structure: analysis |- document_name |- layout_dump |- object_detection.json |- bboxes # not used in this evaluation and the GT file is pleced in od_gt directory for given dataset dataset_name |- od_gt |- document_name.pdf.json Args: doc (Path): path to the OD dump file Returns: tuple: doctype, prediction file path, ground truth path """ od_dump_path = Path(doc) file_stem = od_dump_path.parts[-3] # we take the `document_name` - so the filename stem src_gt_filename = self._find_file_in_ground_truth(file_stem) if src_gt_filename not in self._ground_truth_paths: raise ValueError(f"Ground truth file {src_gt_filename} not found in list of GT files") doctype = Path(src_gt_filename.stem).suffix[1:] prediction_file = self.documents_dir / doc if not prediction_file.exists(): logger.warning(f"Prediction file {prediction_file} does not exist, skipping") raise ValueError(f"Prediction file {prediction_file} does not exist") ground_truth_file = self.ground_truths_dir / src_gt_filename if not ground_truth_file.exists(): logger.warning(f"Ground truth file {ground_truth_file} does not exist, skipping") raise ValueError(f"Ground truth file {ground_truth_file} does not exist") return doctype, prediction_file, ground_truth_file def _generate_dataframes(self, rows) -> tuple[pd.DataFrame, pd.DataFrame]: headers = ["filename", "doctype", "connector"] + self.supported_metric_names df = pd.DataFrame(rows, columns=headers) if df.empty: agg_df = pd.DataFrame(columns=AGG_HEADERS) else: element_metrics_results = {} for metric in self.supported_metric_names: metric_df = df[df[metric].notnull()] agg_metric = metric_df[metric].agg([_mean, _stdev, _pstdev, _count]).transpose() if agg_metric.empty: element_metrics_results[metric] = pd.Series( data=[None, None, None, 0], index=["_mean", "_stdev", "_pstdev", "_count"] ) else: element_metrics_results[metric] = agg_metric agg_df = pd.DataFrame(element_metrics_results).transpose().reset_index() agg_df.columns = AGG_HEADERS return df, agg_df class ObjectDetectionPerClassMetricsCalculator(ObjectDetectionMetricsCalculatorBase): def __post_init__(self): super().__post_init__() self.per_class_metric_names: list[str] | None = None self._set_supported_metrics() @property def supported_metric_names(self): if self.per_class_metric_names: return self.per_class_metric_names else: raise ValueError("per_class_metrics not initialized - cannot get class names") @property def default_tsv_name(self): return "all-docs-object-detection-metrics-per-class.tsv" @property def default_agg_tsv_name(self): return "aggregate-object-detection-metrics-per-class.tsv" def _process_document(self, doc: Path) -> Optional[list]: """Calculate both class-aggregated and per-class metrics for a single document. Args: doc (Path): path to the OD dump file Returns: tuple: a tuple of aggregated and per-class metrics for a single document """ try: doctype, prediction_file, ground_truth_file = self._get_paths(doc) except ValueError as e: logger.error(f"Failed to process document {doc}: {e}") return None processor = ObjectDetectionEvalProcessor.from_json_files( prediction_file_path=prediction_file, ground_truth_file_path=ground_truth_file, ) _, per_class_metrics = processor.get_metrics() per_class_metrics_row = [ ground_truth_file.stem, doctype, None, # connector ] for combined_metric_name in self.supported_metric_names: metric = "_".join(combined_metric_name.split("_")[:-1]) class_name = combined_metric_name.split("_")[-1] class_metrics = getattr(per_class_metrics, metric) per_class_metrics_row.append(class_metrics[class_name]) return per_class_metrics_row def _set_supported_metrics(self): """Sets the supported metrics based on the classes found in the ground truth files. The difference between per class and aggregated calculator is that the list of classes (so the metrics) bases on the contents of the GT / prediction files. """ metrics = ["f1_score", "precision", "recall", "m_ap"] classes = set() for gt_file in self._ground_truth_paths: gt_file_path = self.ground_truths_dir / gt_file with open(gt_file_path) as f: gt = json.load(f) gt_classes = gt["object_detection_classes"] classes.update(gt_classes) per_class_metric_names = [] for metric in metrics: for class_name in classes: per_class_metric_names.append(f"{metric}_{class_name}") self.per_class_metric_names = sorted(per_class_metric_names) class ObjectDetectionAggregatedMetricsCalculator(ObjectDetectionMetricsCalculatorBase): """Calculates object detection metrics for each document and aggregates by all classes""" @property def supported_metric_names(self): return ["f1_score", "precision", "recall", "m_ap"] @property def default_tsv_name(self): return "all-docs-object-detection-metrics.tsv" @property def default_agg_tsv_name(self): return "aggregate-object-detection-metrics.tsv" def _process_document(self, doc: Path) -> Optional[list]: """Calculate both class-aggregated and per-class metrics for a single document. Args: doc (Path): path to the OD dump file Returns: list: a list of aggregated metrics for a single document """ try: doctype, prediction_file, ground_truth_file = self._get_paths(doc) except ValueError as e: logger.error(f"Failed to process document {doc}: {e}") return None processor = ObjectDetectionEvalProcessor.from_json_files( prediction_file_path=prediction_file, ground_truth_file_path=ground_truth_file, ) metrics, _ = processor.get_metrics() return [ ground_truth_file.stem, doctype, None, # connector ] + [getattr(metrics, metric) for metric in self.supported_metric_names]
Memory