"""LangSmith Pytest hooks."""
import importlib.util
import json
import logging
import os
import time
from collections import defaultdict
from threading import Lock
from typing import Any
import pytest
from langsmith import utils as ls_utils
from langsmith.testing._internal import test as ls_test
logger = logging.getLogger(__name__)
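# Typical usage (illustrative sketch; the test body below is hypothetical):
#
#     @pytest.mark.langsmith  # optionally with kwargs, e.g. output_keys=["expected"]
#     def test_something():
#         assert 1 + 1 == 2
#
# Run with the rich live display enabled (requires 'rich'):
#
#     pytest --langsmith-output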
def pytest_addoption(parser):
"""Set a boolean flag for LangSmith output.
Skip if --langsmith-output is already defined.
"""
try:
# Try to add the option, will raise if it already exists
group = parser.getgroup("langsmith", "LangSmith")
group.addoption(
"--langsmith-output",
action="store_true",
default=False,
help="Use LangSmith output (requires 'rich').",
)
except ValueError:
# Option already exists
logger.warning(
"LangSmith output flag cannot be added because it's already defined."
)
def _handle_output_args(args):
"""Handle output arguments."""
if any(opt in args for opt in ["--langsmith-output"]):
# Only add --quiet if it's not already there
if not any(a in args for a in ["-qq"]):
args.insert(0, "-qq")
# Disable built-in output capturing
if not any(a in args for a in ["-s", "--capture=no"]):
args.insert(0, "-s")
if pytest.__version__.startswith("7."):
def pytest_cmdline_preparse(config, args):
"""Call immediately after command line options are parsed (pytest v7)."""
_handle_output_args(args)
else:
def pytest_load_initial_conftests(args):
"""Handle args in pytest v8+."""
_handle_output_args(args)
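# Note: pytest_cmdline_preparse was removed in pytest 8, hence the two hook
# variants above. Either way, with --langsmith-output the effective invocation
# becomes roughly `pytest -s -qq ...`: "-s" turns off pytest's output capturing
# so the Rich live display can write to the terminal, and "-qq" keeps pytest's
# own output to a minimum.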
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
"""Apply LangSmith tracking to tests marked with @pytest.mark.langsmith."""
marker = item.get_closest_marker("langsmith")
if marker:
# Get marker kwargs if any (e.g.,
# @pytest.mark.langsmith(output_keys=["expected"]))
        kwargs = marker.kwargs
# Wrap the test function with our test decorator
original_func = item.obj
item.obj = ls_test(**kwargs)(original_func)
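        # Make the pytest `request` fixture available to the wrapped test even
        # if the original test function didn't declare it as an argument.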
request_obj = getattr(item, "_request", None)
if request_obj is not None and "request" not in item.funcargs:
item.funcargs["request"] = request_obj
if request_obj is not None and "request" not in item._fixtureinfo.argnames:
# Create a new FuncFixtureInfo instance with updated argnames
item._fixtureinfo = type(item._fixtureinfo)(
argnames=item._fixtureinfo.argnames + ("request",),
initialnames=item._fixtureinfo.initialnames,
names_closure=item._fixtureinfo.names_closure,
name2fixturedefs=item._fixtureinfo.name2fixturedefs,
)
yield
@pytest.hookimpl
def pytest_report_teststatus(report, config):
"""Remove the short test-status character outputs ("./F")."""
# The hook normally returns a 3-tuple: (short_letter, verbose_word, color)
# By returning empty strings, the progress characters won't show.
if config.getoption("--langsmith-output"):
return "", "", ""
class LangSmithPlugin:
"""Plugin for rendering LangSmith results."""
def __init__(self):
"""Initialize."""
from rich.console import Console # type: ignore[import-not-found]
from rich.live import Live # type: ignore[import-not-found]
self.test_suites = defaultdict(list)
self.test_suite_urls = {}
self.process_status = {} # Track process status
self.status_lock = Lock() # Thread-safe updates
self.console = Console()
self.live = Live(
self.generate_tables(), console=self.console, refresh_per_second=10
)
self.live.start()
self.live.console.print("Collecting tests...")
def pytest_collection_finish(self, session):
"""Call after collection phase is completed and session.items is populated."""
self.collected_nodeids = set()
for item in session.items:
self.collected_nodeids.add(item.nodeid)
def add_process_to_test_suite(self, test_suite, process_id):
"""Group a test case with its test suite."""
self.test_suites[test_suite].append(process_id)
def update_process_status(self, process_id, status):
"""Update test results."""
# First update
if not self.process_status:
self.live.console.print("Running tests...")
with self.status_lock:
current_status = self.process_status.get(process_id, {})
self.process_status[process_id] = _merge_statuses(
status,
current_status,
unpack=["feedback", "inputs", "reference_outputs", "outputs"],
)
self.live.update(self.generate_tables())
def pytest_runtest_logstart(self, nodeid):
"""Initialize live display when first test starts."""
self.update_process_status(nodeid, {"status": "running"})
def generate_tables(self):
"""Generate a collection of tables—one per suite.
Returns a 'Group' object so it can be rendered simultaneously by Rich Live.
"""
from rich.console import Group
tables = []
for suite_name in self.test_suites:
table = self._generate_table(suite_name)
tables.append(table)
group = Group(*tables)
return group
def _generate_table(self, suite_name: str):
"""Generate results table."""
from rich.table import Table # type: ignore[import-not-found]
process_ids = self.test_suites[suite_name]
title = f"""Test Suite: [bold]{suite_name}[/bold]
LangSmith URL: [bright_cyan]{self.test_suite_urls[suite_name]}[/bright_cyan]""" # noqa: E501
table = Table(title=title, title_justify="left")
table.add_column("Test")
table.add_column("Inputs")
table.add_column("Ref outputs")
table.add_column("Outputs")
table.add_column("Status")
table.add_column("Feedback")
table.add_column("Duration")
        # Track the widest status and duration strings; used below to size the
        # free-text columns (Test, Inputs, Ref outputs, Outputs, Feedback)
max_status = len("status")
max_duration = len("duration")
now = time.time()
durations = []
numeric_feedbacks = defaultdict(list)
# Gather data only for this suite
suite_statuses = {pid: self.process_status[pid] for pid in process_ids}
for pid, status in suite_statuses.items():
duration = status.get("end_time", now) - status.get("start_time", now)
durations.append(duration)
for k, v in status.get("feedback", {}).items():
if isinstance(v, (float, int, bool)):
numeric_feedbacks[k].append(v)
max_duration = max(len(f"{duration:.2f}s"), max_duration)
max_status = max(len(status.get("status", "queued")), max_status)
passed_count = sum(s.get("status") == "passed" for s in suite_statuses.values())
failed_count = sum(s.get("status") == "failed" for s in suite_statuses.values())
        # Aggregate pass rate, rendered in the Status column of the footer row:
if passed_count + failed_count:
rate = passed_count / (passed_count + failed_count)
color = "green" if rate == 1 else "red"
aggregate_status = f"[{color}]{rate:.0%}[/{color}]"
else:
aggregate_status = "Passed: --"
if durations:
aggregate_duration = f"{sum(durations) / len(durations):.2f}s"
else:
aggregate_duration = "--s"
if numeric_feedbacks:
aggregate_feedback = "\n".join(
f"{k}: {sum(v) / len(v)}" for k, v in numeric_feedbacks.items()
)
else:
aggregate_feedback = "--"
max_duration = max(max_duration, len(aggregate_duration))
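        # Split the remaining console width evenly across the five free-text
        # columns (Test, Inputs, Ref outputs, Outputs, Feedback), with a floor
        # of 8 characters per column.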
max_dynamic_col_width = (self.console.width - (max_status + max_duration)) // 5
max_dynamic_col_width = max(max_dynamic_col_width, 8)
for pid, status in suite_statuses.items():
status_color = {
"running": "yellow",
"passed": "green",
"failed": "red",
"skipped": "cyan",
}.get(status.get("status", "queued"), "white")
duration = status.get("end_time", now) - status.get("start_time", now)
feedback = "\n".join(
f"{_abbreviate(k, max_len=max_dynamic_col_width)}: {int(v) if isinstance(v, bool) else v}" # noqa: E501
for k, v in status.get("feedback", {}).items()
)
inputs = _dumps_with_fallback(status.get("inputs", {}))
reference_outputs = _dumps_with_fallback(
status.get("reference_outputs", {})
)
outputs = _dumps_with_fallback(status.get("outputs", {}))
table.add_row(
_abbreviate_test_name(str(pid), max_len=max_dynamic_col_width),
_abbreviate(inputs, max_len=max_dynamic_col_width),
_abbreviate(reference_outputs, max_len=max_dynamic_col_width),
                _abbreviate(outputs, max_len=max_dynamic_col_width),
f"[{status_color}]{status.get('status', 'queued')}[/{status_color}]",
feedback,
f"{duration:.2f}s",
)
        # Blank separator row before the footer
        table.add_row("", "", "", "", "", "", "")
        # Footer row with the aggregated results
table.add_row(
"[bold]Averages[/bold]",
"",
"",
"",
aggregate_status,
aggregate_feedback,
aggregate_duration,
)
return table
def pytest_configure(self, config):
"""Disable warning reporting and show no warnings in output."""
# Disable general warning reporting
config.option.showwarnings = False
# Disable warning summary
reporter = config.pluginmanager.get_plugin("warnings-plugin")
if reporter:
reporter.warning_summary = lambda *args, **kwargs: None
def pytest_sessionfinish(self, session):
"""Stop Rich Live rendering at the end of the session."""
self.live.stop()
self.live.console.print("\nFinishing up...")
def pytest_configure(config):
"""Register the 'langsmith' marker."""
config.addinivalue_line(
"markers", "langsmith: mark test to be tracked in LangSmith"
)
if config.getoption("--langsmith-output"):
if not importlib.util.find_spec("rich"):
msg = (
"Must have 'rich' installed to use --langsmith-output. "
"Please install with: `pip install -U 'langsmith[pytest]'`"
)
raise ValueError(msg)
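    # pytest-xdist runs tests in separate worker processes, which (presumably)
    # cannot share the single Rich Live display used by this plugin.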
if os.environ.get("PYTEST_XDIST_TESTRUNUID"):
msg = (
"--langsmith-output not supported with pytest-xdist. "
"Please remove the '--langsmith-output' option or '-n' option."
)
raise ValueError(msg)
if ls_utils.test_tracking_is_disabled():
        msg = (
            "--langsmith-output is not supported when the env var "
            "LANGSMITH_TEST_TRACKING='false'. Please remove the "
            "'--langsmith-output' option or enable test tracking."
        )
raise ValueError(msg)
config.pluginmanager.register(LangSmithPlugin(), "langsmith_output_plugin")
# Suppress warnings summary
config.option.showwarnings = False
def _abbreviate(x: str, max_len: int) -> str:
if len(x) > max_len:
return x[: max_len - 3] + "..."
else:
return x
def _abbreviate_test_name(test_name: str, max_len: int) -> str:
if len(test_name) > max_len:
        # Node IDs may contain multiple "::" (e.g. file::Class::test); keep
        # everything after the file path together.
        file, test = test_name.split("::", maxsplit=1)
if len(".py::" + test) > max_len:
return "..." + test[-(max_len - 3) :]
file_len = max_len - len("...::" + test)
return "..." + file[-file_len:] + "::" + test
else:
return test_name
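# Example of the merge semantics below (illustrative values): merging an update
# of {"feedback": {"accuracy": 1}} into a current status of
# {"status": "running", "feedback": {"latency": 0.2}} yields
# {"status": "running", "feedback": {"latency": 0.2, "accuracy": 1}}.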
def _merge_statuses(update: dict, current: dict, *, unpack: list[str]) -> dict:
for path in unpack:
if path_update := update.pop(path, None):
path_current = current.get(path, {})
if isinstance(path_update, dict) and isinstance(path_current, dict):
current[path] = {**path_current, **path_update}
else:
current[path] = path_update
return {**current, **update}
def _dumps_with_fallback(obj: Any) -> str:
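    # json.dumps raises TypeError for non-serializable objects (e.g. sets or
    # arbitrary class instances); render a placeholder instead of crashing the
    # table rendering.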
try:
return json.dumps(obj)
except Exception:
return "unserializable"