"""LangSmith Pytest hooks.""" import importlib.util import json import logging import os import time from collections import defaultdict from threading import Lock from typing import Any import pytest from langsmith import utils as ls_utils from langsmith.testing._internal import test as ls_test logger = logging.getLogger(__name__) def pytest_addoption(parser): """Set a boolean flag for LangSmith output. Skip if --langsmith-output is already defined. """ try: # Try to add the option, will raise if it already exists group = parser.getgroup("langsmith", "LangSmith") group.addoption( "--langsmith-output", action="store_true", default=False, help="Use LangSmith output (requires 'rich').", ) except ValueError: # Option already exists logger.warning( "LangSmith output flag cannot be added because it's already defined." ) def _handle_output_args(args): """Handle output arguments.""" if any(opt in args for opt in ["--langsmith-output"]): # Only add --quiet if it's not already there if not any(a in args for a in ["-qq"]): args.insert(0, "-qq") # Disable built-in output capturing if not any(a in args for a in ["-s", "--capture=no"]): args.insert(0, "-s") if pytest.__version__.startswith("7."): def pytest_cmdline_preparse(config, args): """Call immediately after command line options are parsed (pytest v7).""" _handle_output_args(args) else: def pytest_load_initial_conftests(args): """Handle args in pytest v8+.""" _handle_output_args(args) @pytest.hookimpl(hookwrapper=True) def pytest_runtest_call(item): """Apply LangSmith tracking to tests marked with @pytest.mark.langsmith.""" marker = item.get_closest_marker("langsmith") if marker: # Get marker kwargs if any (e.g., # @pytest.mark.langsmith(output_keys=["expected"])) kwargs = marker.kwargs if marker else {} # Wrap the test function with our test decorator original_func = item.obj item.obj = ls_test(**kwargs)(original_func) request_obj = getattr(item, "_request", None) if request_obj is not None and "request" not in item.funcargs: item.funcargs["request"] = request_obj if request_obj is not None and "request" not in item._fixtureinfo.argnames: # Create a new FuncFixtureInfo instance with updated argnames item._fixtureinfo = type(item._fixtureinfo)( argnames=item._fixtureinfo.argnames + ("request",), initialnames=item._fixtureinfo.initialnames, names_closure=item._fixtureinfo.names_closure, name2fixturedefs=item._fixtureinfo.name2fixturedefs, ) yield @pytest.hookimpl def pytest_report_teststatus(report, config): """Remove the short test-status character outputs ("./F").""" # The hook normally returns a 3-tuple: (short_letter, verbose_word, color) # By returning empty strings, the progress characters won't show. 
if config.getoption("--langsmith-output"): return "", "", "" class LangSmithPlugin: """Plugin for rendering LangSmith results.""" def __init__(self): """Initialize.""" from rich.console import Console # type: ignore[import-not-found] from rich.live import Live # type: ignore[import-not-found] self.test_suites = defaultdict(list) self.test_suite_urls = {} self.process_status = {} # Track process status self.status_lock = Lock() # Thread-safe updates self.console = Console() self.live = Live( self.generate_tables(), console=self.console, refresh_per_second=10 ) self.live.start() self.live.console.print("Collecting tests...") def pytest_collection_finish(self, session): """Call after collection phase is completed and session.items is populated.""" self.collected_nodeids = set() for item in session.items: self.collected_nodeids.add(item.nodeid) def add_process_to_test_suite(self, test_suite, process_id): """Group a test case with its test suite.""" self.test_suites[test_suite].append(process_id) def update_process_status(self, process_id, status): """Update test results.""" # First update if not self.process_status: self.live.console.print("Running tests...") with self.status_lock: current_status = self.process_status.get(process_id, {}) self.process_status[process_id] = _merge_statuses( status, current_status, unpack=["feedback", "inputs", "reference_outputs", "outputs"], ) self.live.update(self.generate_tables()) def pytest_runtest_logstart(self, nodeid): """Initialize live display when first test starts.""" self.update_process_status(nodeid, {"status": "running"}) def generate_tables(self): """Generate a collection of tables—one per suite. Returns a 'Group' object so it can be rendered simultaneously by Rich Live. """ from rich.console import Group tables = [] for suite_name in self.test_suites: table = self._generate_table(suite_name) tables.append(table) group = Group(*tables) return group def _generate_table(self, suite_name: str): """Generate results table.""" from rich.table import Table # type: ignore[import-not-found] process_ids = self.test_suites[suite_name] title = f"""Test Suite: [bold]{suite_name}[/bold] LangSmith URL: [bright_cyan]{self.test_suite_urls[suite_name]}[/bright_cyan]""" # noqa: E501 table = Table(title=title, title_justify="left") table.add_column("Test") table.add_column("Inputs") table.add_column("Ref outputs") table.add_column("Outputs") table.add_column("Status") table.add_column("Feedback") table.add_column("Duration") # Test, inputs, ref outputs, outputs col width max_status = len("status") max_duration = len("duration") now = time.time() durations = [] numeric_feedbacks = defaultdict(list) # Gather data only for this suite suite_statuses = {pid: self.process_status[pid] for pid in process_ids} for pid, status in suite_statuses.items(): duration = status.get("end_time", now) - status.get("start_time", now) durations.append(duration) for k, v in status.get("feedback", {}).items(): if isinstance(v, (float, int, bool)): numeric_feedbacks[k].append(v) max_duration = max(len(f"{duration:.2f}s"), max_duration) max_status = max(len(status.get("status", "queued")), max_status) passed_count = sum(s.get("status") == "passed" for s in suite_statuses.values()) failed_count = sum(s.get("status") == "failed" for s in suite_statuses.values()) # You could arrange a row to show the aggregated data—here, in the last column: if passed_count + failed_count: rate = passed_count / (passed_count + failed_count) color = "green" if rate == 1 else "red" aggregate_status = 
f"[{color}]{rate:.0%}[/{color}]" else: aggregate_status = "Passed: --" if durations: aggregate_duration = f"{sum(durations) / len(durations):.2f}s" else: aggregate_duration = "--s" if numeric_feedbacks: aggregate_feedback = "\n".join( f"{k}: {sum(v) / len(v)}" for k, v in numeric_feedbacks.items() ) else: aggregate_feedback = "--" max_duration = max(max_duration, len(aggregate_duration)) max_dynamic_col_width = (self.console.width - (max_status + max_duration)) // 5 max_dynamic_col_width = max(max_dynamic_col_width, 8) for pid, status in suite_statuses.items(): status_color = { "running": "yellow", "passed": "green", "failed": "red", "skipped": "cyan", }.get(status.get("status", "queued"), "white") duration = status.get("end_time", now) - status.get("start_time", now) feedback = "\n".join( f"{_abbreviate(k, max_len=max_dynamic_col_width)}: {int(v) if isinstance(v, bool) else v}" # noqa: E501 for k, v in status.get("feedback", {}).items() ) inputs = _dumps_with_fallback(status.get("inputs", {})) reference_outputs = _dumps_with_fallback( status.get("reference_outputs", {}) ) outputs = _dumps_with_fallback(status.get("outputs", {})) table.add_row( _abbreviate_test_name(str(pid), max_len=max_dynamic_col_width), _abbreviate(inputs, max_len=max_dynamic_col_width), _abbreviate(reference_outputs, max_len=max_dynamic_col_width), _abbreviate(outputs, max_len=max_dynamic_col_width)[ -max_dynamic_col_width: ], f"[{status_color}]{status.get('status', 'queued')}[/{status_color}]", feedback, f"{duration:.2f}s", ) # Add a blank row or a section separator if you like: table.add_row("", "", "", "", "", "", "") # Finally, our “footer” row: table.add_row( "[bold]Averages[/bold]", "", "", "", aggregate_status, aggregate_feedback, aggregate_duration, ) return table def pytest_configure(self, config): """Disable warning reporting and show no warnings in output.""" # Disable general warning reporting config.option.showwarnings = False # Disable warning summary reporter = config.pluginmanager.get_plugin("warnings-plugin") if reporter: reporter.warning_summary = lambda *args, **kwargs: None def pytest_sessionfinish(self, session): """Stop Rich Live rendering at the end of the session.""" self.live.stop() self.live.console.print("\nFinishing up...") def pytest_configure(config): """Register the 'langsmith' marker.""" config.addinivalue_line( "markers", "langsmith: mark test to be tracked in LangSmith" ) if config.getoption("--langsmith-output"): if not importlib.util.find_spec("rich"): msg = ( "Must have 'rich' installed to use --langsmith-output. " "Please install with: `pip install -U 'langsmith[pytest]'`" ) raise ValueError(msg) if os.environ.get("PYTEST_XDIST_TESTRUNUID"): msg = ( "--langsmith-output not supported with pytest-xdist. " "Please remove the '--langsmith-output' option or '-n' option." ) raise ValueError(msg) if ls_utils.test_tracking_is_disabled(): msg = ( "--langsmith-output not supported when env var" "LANGSMITH_TEST_TRACKING='false'. Please remove the" "'--langsmith-output' option " "or enable test tracking." ) raise ValueError(msg) config.pluginmanager.register(LangSmithPlugin(), "langsmith_output_plugin") # Suppress warnings summary config.option.showwarnings = False def _abbreviate(x: str, max_len: int) -> str: if len(x) > max_len: return x[: max_len - 3] + "..." else: return x def _abbreviate_test_name(test_name: str, max_len: int) -> str: if len(test_name) > max_len: file, test = test_name.split("::") if len(".py::" + test) > max_len: return "..." 
def _abbreviate(x: str, max_len: int) -> str:
    """Truncate a string to ``max_len`` characters, ending with '...'."""
    if len(x) > max_len:
        return x[: max_len - 3] + "..."
    else:
        return x


def _abbreviate_test_name(test_name: str, max_len: int) -> str:
    """Shorten a 'file::test' node id, preferring to keep the test name."""
    if len(test_name) > max_len:
        file, test = test_name.split("::")
        if len(".py::" + test) > max_len:
            # Even the test name alone is too long; keep its tail.
            return "..." + test[-(max_len - 3) :]
        # Otherwise trim the file path and keep the full test name.
        file_len = max_len - len("...::" + test)
        return "..." + file[-file_len:] + "::" + test
    else:
        return test_name


def _merge_statuses(update: dict, current: dict, *, unpack: list[str]) -> dict:
    """Merge a status update into the current status.

    Keys listed in ``unpack`` hold nested dicts (e.g. feedback) that are merged
    key-by-key instead of being overwritten wholesale.
    """
    for path in unpack:
        if path_update := update.pop(path, None):
            path_current = current.get(path, {})
            if isinstance(path_update, dict) and isinstance(path_current, dict):
                current[path] = {**path_current, **path_update}
            else:
                current[path] = path_update
    return {**current, **update}


def _dumps_with_fallback(obj: Any) -> str:
    """JSON-serialize ``obj``, falling back to a placeholder if it fails."""
    try:
        return json.dumps(obj)
    except Exception:
        return "unserializable"
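
# Worked example of the helpers above (inputs are hypothetical; the results were
# traced by hand from the code in this module):
#
#     _abbreviate_test_name("tests/test_math.py::test_addition", max_len=20)
#     # -> "...py::test_addition"
#
#     _merge_statuses(
#         {"status": "passed", "feedback": {"correct": 1}},
#         {"status": "running", "feedback": {"speed": 0.5}},
#         unpack=["feedback", "inputs", "reference_outputs", "outputs"],
#     )
#     # -> {"status": "passed", "feedback": {"speed": 0.5, "correct": 1}}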