try:
import langchain # noqa: F401
except ImportError:
raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'")
import logging
import time
from dataclasses import dataclass
from typing import (
Any,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.agent import AgentAction, AgentFinish
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from pydantic import BaseModel
from posthog import default_client
from posthog.ai.utils import get_model_params, with_privacy_mode
from posthog.client import Client
log = logging.getLogger("posthog")
@dataclass
class SpanMetadata:
name: str
"""Name of the run: chain name, model name, etc."""
start_time: float
"""Start time of the run."""
end_time: Optional[float]
"""End time of the run."""
input: Optional[Any]
"""Input of the run: messages, prompt variables, etc."""
@property
def latency(self) -> float:
if not self.end_time:
return 0
return self.end_time - self.start_time
@dataclass
class GenerationMetadata(SpanMetadata):
provider: Optional[str] = None
"""Provider of the run: OpenAI, Anthropic"""
model: Optional[str] = None
"""Model used in the run"""
model_params: Optional[Dict[str, Any]] = None
"""Model parameters of the run: temperature, max_tokens, etc."""
base_url: Optional[str] = None
"""Base URL of the provider's API used in the run."""
tools: Optional[List[Dict[str, Any]]] = None
"""Tools provided to the model."""
RunMetadata = Union[SpanMetadata, GenerationMetadata]
RunMetadataStorage = Dict[UUID, RunMetadata]
class CallbackHandler(BaseCallbackHandler):
"""
The PostHog LLM observability callback handler for LangChain.
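    Pass an instance in the `callbacks` list of a LangChain invocation config to capture
    traces, spans, and generations. A minimal usage sketch (the prompt, model, and import
    path below are illustrative, not prescriptive):
        from langchain_core.prompts import ChatPromptTemplate
        from langchain_openai import ChatOpenAI
        from posthog.ai.langchain import CallbackHandler
        handler = CallbackHandler(distinct_id="user_123")
        chain = ChatPromptTemplate.from_messages([("user", "{question}")]) | ChatOpenAI()
        chain.invoke({"question": "What is PostHog?"}, config={"callbacks": [handler]})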
"""
_client: Client
"""PostHog client instance."""
_distinct_id: Optional[Union[str, int, float, UUID]]
"""Distinct ID of the user to associate the trace with."""
_trace_id: Optional[Union[str, int, float, UUID]]
"""Global trace ID to be sent with every event. Otherwise, the top-level run ID is used."""
_trace_input: Optional[Any]
"""The input at the start of the trace. Any JSON object."""
_trace_name: Optional[str]
"""Name of the trace, exposed in the UI."""
_properties: Optional[Dict[str, Any]]
"""Global properties to be sent with every event."""
_runs: RunMetadataStorage
"""Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""
_parent_tree: Dict[UUID, UUID]
"""
A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree),
so the top level can be found from a bottom-level run ID.
"""
def __init__(
self,
client: Optional[Client] = None,
*,
distinct_id: Optional[Union[str, int, float, UUID]] = None,
trace_id: Optional[Union[str, int, float, UUID]] = None,
properties: Optional[Dict[str, Any]] = None,
privacy_mode: bool = False,
groups: Optional[Dict[str, Any]] = None,
):
"""
Args:
client: PostHog client instance.
distinct_id: Optional distinct ID of the user to associate the trace with.
            trace_id: Optional trace ID to associate all captured events with.
properties: Optional additional metadata to use for the trace.
privacy_mode: Whether to redact the input and output of the trace.
groups: Optional additional PostHog groups to use for the trace.
"""
posthog_client = client or default_client
if posthog_client is None:
raise ValueError("PostHog client is required")
self._client = posthog_client
self._distinct_id = distinct_id
self._trace_id = trace_id
self._properties = properties or {}
self._privacy_mode = privacy_mode
self._groups = groups or {}
self._runs = {}
self._parent_tree = {}
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs,
):
self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
self._set_parent_of_run(run_id, parent_run_id)
self._set_trace_or_span_metadata(serialized, inputs, run_id, parent_run_id, **kwargs)
def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, outputs)
def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs,
):
self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages)
self._set_parent_of_run(run_id, parent_run_id)
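        # Flatten the batch of message lists into a single list of role/content dicts.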
input = [_convert_message_to_dict(message) for row in messages for message in row]
self._set_llm_metadata(serialized, run_id, input, **kwargs)
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts)
self._set_parent_of_run(run_id, parent_run_id)
self._set_llm_metadata(serialized, run_id, prompts, **kwargs)
def on_llm_new_token(
self,
token: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""
self._log_debug_event("on_llm_new_token", run_id, parent_run_id, token=token)
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
"""
The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
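        For example, with the OpenAI chat integration (the model name is illustrative):
            llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)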
"""
self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs)
self._pop_run_and_capture_generation(run_id, parent_run_id, response)
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
self._pop_run_and_capture_generation(run_id, parent_run_id, error)
def on_tool_start(
self,
serialized: Optional[Dict[str, Any]],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str)
self._set_parent_of_run(run_id, parent_run_id)
self._set_trace_or_span_metadata(serialized, input_str, run_id, parent_run_id, **kwargs)
def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, output)
def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
def on_retriever_start(
self,
serialized: Optional[Dict[str, Any]],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query)
self._set_parent_of_run(run_id, parent_run_id)
self._set_trace_or_span_metadata(serialized, query, run_id, parent_run_id, **kwargs)
def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
):
self._log_debug_event("on_retriever_end", run_id, parent_run_id, documents=documents)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, documents)
def on_retriever_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever errors."""
self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action."""
self._log_debug_event("on_agent_action", run_id, parent_run_id, action=action)
self._set_parent_of_run(run_id, parent_run_id)
self._set_trace_or_span_metadata(None, action, run_id, parent_run_id, **kwargs)
def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish)
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, finish)
def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None):
"""
Set the parent run ID for a chain run. If there is no parent, the run is the root.
"""
if parent_run_id is not None:
self._parent_tree[run_id] = parent_run_id
def _pop_parent_of_run(self, run_id: UUID):
"""
Remove the parent run ID for a chain run.
"""
try:
self._parent_tree.pop(run_id)
except KeyError:
pass
def _find_root_run(self, run_id: UUID) -> UUID:
"""
Finds the root ID of a chain run.
"""
id: UUID = run_id
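        # Walk up the parent pointer tree until a run without a recorded parent is reached.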
while id in self._parent_tree:
id = self._parent_tree[id]
return id
def _set_trace_or_span_metadata(
self,
serialized: Optional[Dict[str, Any]],
input: Any,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs,
):
default_name = "trace" if parent_run_id is None else "span"
run_name = _get_langchain_run_name(serialized, **kwargs) or default_name
self._runs[run_id] = SpanMetadata(name=run_name, input=input, start_time=time.time(), end_time=None)
def _set_llm_metadata(
self,
serialized: Dict[str, Any],
run_id: UUID,
messages: Union[List[Dict[str, Any]], List[str]],
metadata: Optional[Dict[str, Any]] = None,
invocation_params: Optional[Dict[str, Any]] = None,
**kwargs,
):
run_name = _get_langchain_run_name(serialized, **kwargs) or "generation"
generation = GenerationMetadata(name=run_name, input=messages, start_time=time.time(), end_time=None)
if isinstance(invocation_params, dict):
generation.model_params = get_model_params(invocation_params)
if tools := invocation_params.get("tools"):
generation.tools = tools
if isinstance(metadata, dict):
if model := metadata.get("ls_model_name"):
generation.model = model
if provider := metadata.get("ls_provider"):
generation.provider = provider
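        # The base URL is only available when the provider serializes an
        # `openai_api_base` constructor kwarg (e.g. OpenAI-compatible clients).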
try:
base_url = serialized["kwargs"]["openai_api_base"]
if base_url is not None:
generation.base_url = base_url
except KeyError:
pass
self._runs[run_id] = generation
def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadata]:
end_time = time.time()
try:
run = self._runs.pop(run_id)
except KeyError:
log.warning(f"No run metadata found for run {run_id}")
return None
run.end_time = end_time
return run
def _get_trace_id(self, run_id: UUID):
trace_id = self._trace_id or self._find_root_run(run_id)
if not trace_id:
return run_id
return trace_id
def _get_parent_run_id(self, trace_id: Any, run_id: UUID, parent_run_id: Optional[UUID]):
"""
        Return the trace ID instead of the parent run ID for second-level runs (runs whose
        parent is the root), so spans remain attached to the trace when a custom trace ID is set.
"""
if parent_run_id is not None and parent_run_id not in self._parent_tree:
return trace_id
return parent_run_id
def _pop_run_and_capture_trace_or_span(self, run_id: UUID, parent_run_id: Optional[UUID], outputs: Any):
trace_id = self._get_trace_id(run_id)
self._pop_parent_of_run(run_id)
run = self._pop_run_metadata(run_id)
if not run:
return
if isinstance(run, GenerationMetadata):
log.warning(f"Run {run_id} is a generation, but attempted to be captured as a trace or span.")
return
self._capture_trace_or_span(
trace_id,
run_id,
run,
outputs,
self._get_parent_run_id(trace_id, run_id, parent_run_id),
)
def _capture_trace_or_span(
self,
trace_id: Any,
run_id: UUID,
run: SpanMetadata,
outputs: Any,
parent_run_id: Optional[UUID],
):
event_name = "$ai_trace" if parent_run_id is None else "$ai_span"
event_properties = {
"$ai_trace_id": trace_id,
"$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.input),
"$ai_latency": run.latency,
"$ai_span_name": run.name,
"$ai_span_id": run_id,
}
if parent_run_id is not None:
event_properties["$ai_parent_id"] = parent_run_id
if self._properties:
event_properties.update(self._properties)
if isinstance(outputs, BaseException):
event_properties["$ai_error"] = _stringify_exception(outputs)
event_properties["$ai_is_error"] = True
elif outputs is not None:
event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs)
if self._distinct_id is None:
event_properties["$process_person_profile"] = False
self._client.capture(
distinct_id=self._distinct_id or run_id,
event=event_name,
properties=event_properties,
groups=self._groups,
)
def _pop_run_and_capture_generation(
self,
run_id: UUID,
parent_run_id: Optional[UUID],
response: Union[LLMResult, BaseException],
):
trace_id = self._get_trace_id(run_id)
self._pop_parent_of_run(run_id)
run = self._pop_run_metadata(run_id)
if not run:
return
if not isinstance(run, GenerationMetadata):
log.warning(f"Run {run_id} is not a generation, but attempted to be captured as a generation.")
return
self._capture_generation(
trace_id,
run_id,
run,
response,
self._get_parent_run_id(trace_id, run_id, parent_run_id),
)
def _capture_generation(
self,
trace_id: Any,
run_id: UUID,
run: GenerationMetadata,
output: Union[LLMResult, BaseException],
parent_run_id: Optional[UUID] = None,
):
event_properties = {
"$ai_trace_id": trace_id,
"$ai_span_id": run_id,
"$ai_span_name": run.name,
"$ai_parent_id": parent_run_id,
"$ai_provider": run.provider,
"$ai_model": run.model,
"$ai_model_parameters": run.model_params,
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.input),
"$ai_http_status": 200,
"$ai_latency": run.latency,
"$ai_base_url": run.base_url,
}
if run.tools:
event_properties["$ai_tools"] = with_privacy_mode(
self._client,
self._privacy_mode,
run.tools,
)
if isinstance(output, BaseException):
event_properties["$ai_http_status"] = _get_http_status(output)
event_properties["$ai_error"] = _stringify_exception(output)
event_properties["$ai_is_error"] = True
else:
# Add usage
input_tokens, output_tokens = _parse_usage(output)
event_properties["$ai_input_tokens"] = input_tokens
event_properties["$ai_output_tokens"] = output_tokens
# Generation results
generation_result = output.generations[-1]
if isinstance(generation_result[-1], ChatGeneration):
completions = [
_convert_message_to_dict(cast(ChatGeneration, generation).message)
for generation in generation_result
]
else:
                completions = [_extract_raw_response(generation) for generation in generation_result]
event_properties["$ai_output_choices"] = with_privacy_mode(self._client, self._privacy_mode, completions)
if self._properties:
event_properties.update(self._properties)
if self._distinct_id is None:
event_properties["$process_person_profile"] = False
self._client.capture(
distinct_id=self._distinct_id or trace_id,
event="$ai_generation",
properties=event_properties,
groups=self._groups,
)
def _log_debug_event(
self,
event_name: str,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs,
):
log.debug(
f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}, kwargs: {kwargs}"
)
def _extract_raw_response(last_response):
    """Extract the raw response (text or tool-call payload) from a single generation."""
# We return the text of the response if not empty
if last_response.text is not None and last_response.text.strip() != "":
return last_response.text.strip()
elif hasattr(last_response, "message"):
# Additional kwargs contains the response in case of tool usage
return last_response.message.additional_kwargs
else:
# Not tool usage, some LLM responses can be simply empty
return ""
def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
    # Map LangChain message types to OpenAI-style role/content dicts.
if isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolMessage):
message_dict = {"role": "tool", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {"role": "function", "content": message.content}
else:
message_dict = {"role": message.type, "content": str(message.content)}
if message.additional_kwargs:
message_dict.update(message.additional_kwargs)
return message_dict
def _parse_usage_model(
usage: Union[BaseModel, Dict],
) -> Tuple[Union[int, None], Union[int, None]]:
if isinstance(usage, BaseModel):
usage = usage.__dict__
conversion_list = [
# https://pypi.org/project/langchain-anthropic/ (works also for Bedrock-Anthropic)
("input_tokens", "input"),
("output_tokens", "output"),
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/get-token-count
("prompt_token_count", "input"),
("candidates_token_count", "output"),
# Bedrock: https://docs.aws.amazon.com/bedrock/latest/userguide/monitoring-cw.html#runtime-cloudwatch-metrics
("inputTokenCount", "input"),
("outputTokenCount", "output"),
# Bedrock Anthropic
("prompt_tokens", "input"),
("completion_tokens", "output"),
# langchain-ibm https://pypi.org/project/langchain-ibm/
("input_token_count", "input"),
("generated_token_count", "output"),
]
parsed_usage = {}
for model_key, type_key in conversion_list:
if model_key in usage:
captured_count = usage[model_key]
final_count = (
sum(captured_count) if isinstance(captured_count, list) else captured_count
) # For Bedrock, the token count is a list when streamed
parsed_usage[type_key] = final_count
return parsed_usage.get("input"), parsed_usage.get("output")
def _parse_usage(response: LLMResult):
# langchain-anthropic uses the usage field
llm_usage_keys = ["token_usage", "usage"]
llm_usage: Tuple[Union[int, None], Union[int, None]] = (None, None)
if response.llm_output is not None:
for key in llm_usage_keys:
if response.llm_output.get(key):
llm_usage = _parse_usage_model(response.llm_output[key])
break
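    # Fall back to usage reported on individual generations or streamed chunks
    # (e.g. Bedrock, Ollama) when llm_output does not include token counts.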
if hasattr(response, "generations"):
for generation in response.generations:
if "usage" in generation:
llm_usage = _parse_usage_model(generation["usage"])
break
for generation_chunk in generation:
if generation_chunk.generation_info and ("usage_metadata" in generation_chunk.generation_info):
llm_usage = _parse_usage_model(generation_chunk.generation_info["usage_metadata"])
break
message_chunk = getattr(generation_chunk, "message", {})
response_metadata = getattr(message_chunk, "response_metadata", {})
bedrock_anthropic_usage = (
response_metadata.get("usage", None) # for Bedrock-Anthropic
if isinstance(response_metadata, dict)
else None
)
bedrock_titan_usage = (
response_metadata.get("amazon-bedrock-invocationMetrics", None) # for Bedrock-Titan
if isinstance(response_metadata, dict)
else None
)
ollama_usage = getattr(message_chunk, "usage_metadata", None) # for Ollama
chunk_usage = bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
if chunk_usage:
llm_usage = _parse_usage_model(chunk_usage)
break
return llm_usage
def _get_http_status(error: BaseException) -> int:
# OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/_exceptions.py
# Anthropic: https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_exceptions.py
# Google: https://github.com/googleapis/python-api-core/blob/main/google/api_core/exceptions.py
status_code = getattr(error, "status_code", getattr(error, "code", 0))
return status_code
def _get_langchain_run_name(serialized: Optional[Dict[str, Any]], **kwargs: Any) -> Optional[str]:
"""Retrieve the name of a serialized LangChain runnable.
    The run name is resolved in the following order of priority:
    - The value of the "name" key in `kwargs`.
    - The value of the "name" key in `serialized`.
    - The last entry of the "id" key in `serialized`.
    - None, if none of the above is available.
    Args:
        serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
        **kwargs (Any): Additional keyword arguments, potentially including the 'name' override.
    Returns:
        Optional[str]: The determined name of the LangChain runnable, or None if it cannot be determined.
"""
if "name" in kwargs and kwargs["name"] is not None:
return kwargs["name"]
if serialized is None:
return None
try:
return serialized["name"]
except (KeyError, TypeError):
pass
try:
return serialized["id"][-1]
except (KeyError, TypeError):
pass
return None
def _stringify_exception(exception: BaseException) -> str:
description = str(exception)
if description:
return f"{exception.__class__.__name__}: {description}"
return exception.__class__.__name__