from collections import defaultdict
from contextvars import ContextVar
from typing import Any, Dict, List, Optional, Union, Tuple, Callable, Generator
from uuid import uuid4
import logging
import httpx
from langfuse.client import (
StatefulSpanClient,
StatefulTraceClient,
StatefulGenerationClient,
StateType,
)
from langfuse.utils.error_logging import (
auto_decorate_methods_with,
catch_and_log_errors,
)
from langfuse.types import TraceMetadata
from langfuse.utils.base_callback_handler import LangfuseBaseCallbackHandler
from .utils import CallbackEvent, ParsedLLMEndPayload
from pydantic import BaseModel
try:
from llama_index.core.callbacks.base_handler import (
BaseCallbackHandler as LlamaIndexBaseCallbackHandler,
)
from llama_index.core.callbacks.schema import (
CBEventType,
BASE_TRACE_EVENT,
EventPayload,
)
from llama_index.core.utilities.token_counting import TokenCounter
except ImportError:
raise ModuleNotFoundError(
"Please install llama-index to use the Langfuse llama-index integration: 'pip install llama-index'"
)
context_root: ContextVar[Optional[Union[StatefulTraceClient, StatefulSpanClient]]] = (
ContextVar("root", default=None)
)
context_trace_metadata: ContextVar[TraceMetadata] = ContextVar(
"trace_metadata",
default={
"name": None,
"user_id": None,
"session_id": None,
"version": None,
"release": None,
"metadata": None,
"tags": None,
"public": None,
},
)
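# The ContextVars above keep the root observation and trace metadata isolated
# per execution context, so concurrent runs (threads or asyncio tasks) do not
# leak trace state into each other. A minimal, standalone sketch of that
# isolation (stdlib only; not part of this module):
#
#   import asyncio
#
#   async def run(user_id: str) -> None:
#       # Each asyncio task runs in its own context copy, so this set() is
#       # invisible to the other task.
#       context_trace_metadata.set({**context_trace_metadata.get(), "user_id": user_id})
#       assert context_trace_metadata.get()["user_id"] == user_id
#
#   async def main() -> None:
#       await asyncio.gather(run("a"), run("b"))
#
#   # asyncio.run(main())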
@auto_decorate_methods_with(catch_and_log_errors, exclude=["__init__"])
class LlamaIndexCallbackHandler(
LlamaIndexBaseCallbackHandler, LangfuseBaseCallbackHandler
):
"""LlamaIndex callback handler for Langfuse. This version is in alpha and may change in the future."""
log = logging.getLogger("langfuse")
def __init__(
self,
*,
public_key: Optional[str] = None,
secret_key: Optional[str] = None,
host: Optional[str] = None,
debug: bool = False,
session_id: Optional[str] = None,
user_id: Optional[str] = None,
trace_name: Optional[str] = None,
release: Optional[str] = None,
version: Optional[str] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Any] = None,
threads: Optional[int] = None,
flush_at: Optional[int] = None,
flush_interval: Optional[int] = None,
max_retries: Optional[int] = None,
timeout: Optional[int] = None,
event_starts_to_ignore: Optional[List[CBEventType]] = None,
event_ends_to_ignore: Optional[List[CBEventType]] = None,
tokenizer: Optional[Callable[[str], list]] = None,
enabled: Optional[bool] = None,
httpx_client: Optional[httpx.Client] = None,
sdk_integration: Optional[str] = None,
sample_rate: Optional[float] = None,
) -> None:
LlamaIndexBaseCallbackHandler.__init__(
self,
event_starts_to_ignore=event_starts_to_ignore or [],
event_ends_to_ignore=event_ends_to_ignore or [],
)
LangfuseBaseCallbackHandler.__init__(
self,
public_key=public_key,
secret_key=secret_key,
host=host,
debug=debug,
session_id=session_id,
user_id=user_id,
trace_name=trace_name,
release=release,
version=version,
tags=tags,
metadata=metadata,
threads=threads,
flush_at=flush_at,
flush_interval=flush_interval,
max_retries=max_retries,
timeout=timeout,
enabled=enabled,
httpx_client=httpx_client,
sdk_integration=sdk_integration or "llama-index_callback",
sample_rate=sample_rate,
)
self.event_map: Dict[str, List[CallbackEvent]] = defaultdict(list)
self._llama_index_trace_name: Optional[str] = None
self._token_counter = TokenCounter(tokenizer)
# For stream-chat, the last LLM end_event arrives after the trace has ended
# Keep track of these orphans to upsert them with the correct trace_id after the trace has ended
self._orphaned_LLM_generations: Dict[
str, Tuple[StatefulGenerationClient, StatefulTraceClient]
] = {}
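    # A minimal registration sketch (the LlamaIndex imports and keys below are
    # assumptions, not part of this module): attach the handler globally via
    # LlamaIndex's Settings so that all subsequent runs are traced.
    #
    #   from llama_index.core import Settings
    #   from llama_index.core.callbacks import CallbackManager
    #
    #   handler = LlamaIndexCallbackHandler(
    #       public_key="pk-lf-...",
    #       secret_key="sk-lf-...",
    #       host="https://cloud.langfuse.com",
    #   )
    #   Settings.callback_manager = CallbackManager([handler])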
def set_root(
self,
root: Optional[Union[StatefulTraceClient, StatefulSpanClient]],
*,
update_root: bool = False,
) -> None:
"""Set the root trace or span for the callback handler.
Args:
root (Optional[Union[StatefulTraceClient, StatefulSpanClient]]): The root trace or observation to
be used for all following operations.
Keyword Args:
update_root (bool): If True, the root trace or observation will be updated with the outcome of the LlamaIndex run.
Returns:
None
"""
context_root.set(root)
if root is None:
self.trace = None
self.root_span = None
self._task_manager = self.langfuse.task_manager if self.langfuse else None
return
if isinstance(root, StatefulTraceClient):
self.trace = root
elif isinstance(root, StatefulSpanClient):
self.root_span = root
self.trace = StatefulTraceClient(
root.client,
root.trace_id,
StateType.TRACE,
root.trace_id,
root.task_manager,
)
self._task_manager = root.task_manager
self.update_stateful_client = update_root
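    # A minimal usage sketch (the Langfuse client setup is an assumption):
    # group several LlamaIndex executions under one externally created trace.
    #
    #   from langfuse import Langfuse
    #
    #   trace = Langfuse().trace(name="rag-session")
    #   handler.set_root(trace, update_root=True)
    #   # ... run one or more queries; their observations nest under `trace`
    #   handler.set_root(None)  # detach so later runs create their own traces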
def set_trace_params(
self,
name: Optional[str] = None,
user_id: Optional[str] = None,
session_id: Optional[str] = None,
version: Optional[str] = None,
release: Optional[str] = None,
metadata: Optional[Any] = None,
tags: Optional[List[str]] = None,
public: Optional[bool] = None,
):
"""Set the trace params that will be used for all following operations.
Allows setting params of subsequent traces at any point in the code.
Overwrites the default params set in the callback constructor.
Attention: If a root trace or span is set on the callback handler, those trace params will be used and NOT those set through this method.
        Args:
            name (Optional[str]): Identifier of the trace. Useful for sorting/filtering in the UI.
            user_id (Optional[str]): The id of the user that triggered the execution. Used to provide user-level analytics.
            session_id (Optional[str]): Used to group multiple traces into a session in Langfuse. Use your own session/thread identifier.
            version (Optional[str]): The version of the trace type. Used to understand how changes to the trace type affect metrics. Useful in debugging.
            release (Optional[str]): The release identifier of the current deployment. Used to understand how changes of different deployments affect metrics.
            metadata (Optional[Any]): Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated via the API.
            tags (Optional[List[str]]): Tags are used to categorize or label traces. Traces can be filtered by tags in the Langfuse UI and GET API.
            public (Optional[bool]): You can make a trace public to share it via a public link. This allows others to view the trace without needing to log in or be members of your Langfuse project.
Returns:
None
"""
context_trace_metadata.set(
{
"name": name,
"user_id": user_id,
"session_id": session_id,
"version": version,
"release": release,
"metadata": metadata,
"tags": tags,
"public": public,
}
)
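    # A minimal usage sketch: params set here apply to every trace created
    # afterwards in the current context (unless a root set via `set_root`
    # takes precedence).
    #
    #   handler.set_trace_params(
    #       user_id="user-123",
    #       session_id="chat-session-42",
    #       tags=["production"],
    #   )
    #   # ... subsequent runs produce traces carrying these params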
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""
self._llama_index_trace_name = trace_id
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""Run when an overall trace is exited."""
if not trace_map:
self.log.debug("No events in trace map to create the observation tree.")
return
# Generate Langfuse observations after trace has ended and full trace_map is available.
# For long-running traces this leads to events only being sent to Langfuse after the trace has ended.
# Timestamps remain accurate as they are set at the time of the event.
self._create_observations_from_trace_map(
event_id=BASE_TRACE_EVENT, trace_map=trace_map
)
self._update_trace_data(trace_map=trace_map)
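    # For reference, `trace_map` maps parent event ids to child event ids, with
    # BASE_TRACE_EVENT ("root") as the top-level key. A hypothetical shape:
    #
    #   {
    #       "root": ["query-event-id"],
    #       "query-event-id": ["retrieve-event-id", "synthesize-event-id"],
    #       "synthesize-event-id": ["llm-event-id"],
    #   }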
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
"""Run when an event starts and return id of event."""
start_event = CallbackEvent(
event_id=event_id, event_type=event_type, payload=payload
)
self.event_map[event_id].append(start_event)
return event_id
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
"""Run when an event ends."""
end_event = CallbackEvent(
event_id=event_id, event_type=event_type, payload=payload
)
self.event_map[event_id].append(end_event)
if event_type == CBEventType.LLM and event_id in self._orphaned_LLM_generations:
generation, trace = self._orphaned_LLM_generations[event_id]
self._handle_orphaned_LLM_end_event(
end_event, generation=generation, trace=trace
)
del self._orphaned_LLM_generations[event_id]
def _create_observations_from_trace_map(
self,
event_id: str,
trace_map: Dict[str, List[str]],
parent: Optional[
Union[StatefulTraceClient, StatefulSpanClient, StatefulGenerationClient]
] = None,
) -> None:
"""Recursively create langfuse observations based on the trace_map."""
if event_id != BASE_TRACE_EVENT and not self.event_map.get(event_id):
return
if event_id == BASE_TRACE_EVENT:
observation = self._get_root_observation()
else:
observation = self._create_observation(
event_id=event_id, parent=parent, trace_id=self.trace.id
)
for child_event_id in trace_map.get(event_id, []):
self._create_observations_from_trace_map(
event_id=child_event_id, parent=observation, trace_map=trace_map
)
def _get_root_observation(self) -> Union[StatefulTraceClient, StatefulSpanClient]:
user_provided_root = context_root.get()
# Get trace metadata from contextvars or use default values
trace_metadata = context_trace_metadata.get()
name = (
trace_metadata["name"]
or self.trace_name
or f"LlamaIndex_{self._llama_index_trace_name}"
)
version = trace_metadata["version"] or self.version
release = trace_metadata["release"] or self.release
session_id = trace_metadata["session_id"] or self.session_id
user_id = trace_metadata["user_id"] or self.user_id
metadata = trace_metadata["metadata"] or self.metadata
tags = trace_metadata["tags"] or self.tags
public = trace_metadata["public"] or None
        # Only reuse a user-provided root if it belongs to the current trace,
        # not if it is left over from a different trace.
if (
user_provided_root is not None
and self.trace
and self.trace.id == user_provided_root.trace_id
):
if self.update_stateful_client:
user_provided_root.update(
name=name,
version=version,
session_id=session_id,
user_id=user_id,
metadata=metadata,
tags=tags,
release=release,
public=public,
)
return user_provided_root
else:
self.trace = self.langfuse.trace(
id=str(uuid4()),
name=name,
version=version,
session_id=session_id,
user_id=user_id,
metadata=metadata,
tags=tags,
release=release,
public=public,
)
return self.trace
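    # Precedence summary for the trace params resolved above: values set via
    # `set_trace_params` win over the constructor defaults. For example:
    #
    #   handler = LlamaIndexCallbackHandler(user_id="default-user")
    #   handler.set_trace_params(user_id="override-user")
    #   # -> the next trace is created with user_id="override-user"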
def _create_observation(
self,
event_id: str,
parent: Union[
StatefulTraceClient, StatefulSpanClient, StatefulGenerationClient
],
trace_id: str,
) -> Union[StatefulSpanClient, StatefulGenerationClient]:
event_type = self.event_map[event_id][0].event_type
if event_type == CBEventType.LLM:
return self._handle_LLM_events(event_id, parent, trace_id)
elif event_type == CBEventType.EMBEDDING:
return self._handle_embedding_events(event_id, parent, trace_id)
else:
return self._handle_span_events(event_id, parent, trace_id)
def _handle_LLM_events(
self,
event_id: str,
parent: Union[
StatefulTraceClient, StatefulSpanClient, StatefulGenerationClient
],
trace_id: str,
) -> StatefulGenerationClient:
        events = self.event_map[event_id]
        start_event, end_event = events[0], events[-1]

        # Defaults guard against start events without a SERIALIZED payload,
        # which would otherwise leave these names unbound below.
        name = "LLM"
        temperature = max_tokens = timeout = None

        if start_event.payload and EventPayload.SERIALIZED in start_event.payload:
            serialized = start_event.payload.get(EventPayload.SERIALIZED, {})
            name = serialized.get("class_name", "LLM")
            temperature = serialized.get("temperature", None)
            max_tokens = serialized.get("max_tokens", None)
            timeout = serialized.get("timeout", None)
parsed_end_payload = self._parse_LLM_end_event_payload(end_event)
parsed_metadata = self._parse_metadata_from_event(end_event)
generation = parent.generation(
id=event_id,
trace_id=trace_id,
version=self.version,
name=name,
start_time=start_event.time,
metadata=parsed_metadata,
model_parameters={
"temperature": temperature,
"max_tokens": max_tokens,
"request_timeout": timeout,
},
**parsed_end_payload,
)
# Register orphaned LLM event (only start event, no end event) to be later upserted with the correct trace_id
if len(events) == 1:
self._orphaned_LLM_generations[event_id] = (generation, self.trace)
return generation
def _handle_orphaned_LLM_end_event(
self,
end_event: CallbackEvent,
generation: StatefulGenerationClient,
trace: StatefulTraceClient,
) -> None:
parsed_end_payload = self._parse_LLM_end_event_payload(end_event)
generation.update(
**parsed_end_payload,
)
if generation.trace_id != trace.id:
raise ValueError(
f"Generation trace_id {generation.trace_id} does not match trace.id {trace.id}"
)
trace.update(output=parsed_end_payload["output"])
def _parse_LLM_end_event_payload(
self, end_event: CallbackEvent
) -> ParsedLLMEndPayload:
result: ParsedLLMEndPayload = {
"input": None,
"output": None,
"usage": None,
"model": None,
"end_time": end_event.time,
}
if not end_event.payload:
return result
result["input"] = self._parse_input_from_event(end_event)
result["output"] = self._parse_output_from_event(end_event)
result["model"], result["usage"] = self._parse_usage_from_event_payload(
end_event.payload
)
return result
    def _parse_usage_from_event_payload(
        self, event_payload: Dict
    ) -> Tuple[Optional[str], Optional[Dict]]:
model = usage = None
if not (
EventPayload.MESSAGES in event_payload
and EventPayload.RESPONSE in event_payload
):
return model, usage
response = event_payload.get(EventPayload.RESPONSE)
if response and hasattr(response, "raw") and response.raw is not None:
if isinstance(response.raw, dict):
raw_dict = response.raw
elif isinstance(response.raw, BaseModel):
raw_dict = response.raw.model_dump()
else:
raw_dict = {}
model = raw_dict.get("model", None)
raw_token_usage = raw_dict.get("usage", {})
if isinstance(raw_token_usage, dict):
token_usage = raw_token_usage
elif isinstance(raw_token_usage, BaseModel):
token_usage = raw_token_usage.model_dump()
else:
token_usage = {}
if token_usage:
usage = {
"input": token_usage.get("prompt_tokens", None),
"output": token_usage.get("completion_tokens", None),
"total": token_usage.get("total_tokens", None),
}
return model, usage
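    # For reference, an OpenAI-style raw response (hypothetical values) maps
    # onto the returned (model, usage) pair like so:
    #
    #   response.raw = {
    #       "model": "gpt-3.5-turbo",
    #       "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46},
    #   }
    #   # -> model = "gpt-3.5-turbo"
    #   # -> usage = {"input": 12, "output": 34, "total": 46}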
def _handle_embedding_events(
self,
event_id: str,
parent: Union[
StatefulTraceClient, StatefulSpanClient, StatefulGenerationClient
],
trace_id: str,
) -> StatefulGenerationClient:
        events = self.event_map[event_id]
        start_event, end_event = events[0], events[-1]

        # Defaults guard against missing payloads, which would otherwise leave
        # these names unbound when building the generation below.
        name = "Embedding"
        model = timeout = usage = None

        if start_event.payload and EventPayload.SERIALIZED in start_event.payload:
            serialized = start_event.payload.get(EventPayload.SERIALIZED, {})
            name = serialized.get("class_name", "Embedding")
            model = serialized.get("model_name", None)
            timeout = serialized.get("timeout", None)

        if end_event.payload:
            chunks = end_event.payload.get(EventPayload.CHUNKS, [])
            token_count = sum(
                self._token_counter.get_string_tokens(chunk) for chunk in chunks
            )

            usage = {
                "input": 0,
                "output": 0,
                "total": token_count or None,
            }
input = self._parse_input_from_event(end_event)
output = self._parse_output_from_event(end_event)
generation = parent.generation(
id=event_id,
trace_id=trace_id,
name=name,
start_time=start_event.time,
end_time=end_event.time,
version=self.version,
model=model,
input=input,
output=output,
usage=usage or None,
model_parameters={
"request_timeout": timeout,
},
)
return generation
def _handle_span_events(
self,
event_id: str,
parent: Union[
StatefulTraceClient, StatefulSpanClient, StatefulGenerationClient
],
trace_id: str,
) -> StatefulSpanClient:
        events = self.event_map[event_id]
        start_event = events[0]
        # Guard against events that only have a start entry (no end callback).
        end_event = events[-1] if len(events) > 1 else None

        extracted_input = self._parse_input_from_event(start_event)
        extracted_output = (
            self._parse_output_from_event(end_event) if end_event else None
        )
        extracted_metadata = (
            self._parse_metadata_from_event(end_event) if end_event else None
        )
        metadata = (
            extracted_metadata if extracted_output != extracted_metadata else None
        )
name = start_event.event_type.value
        # Update the name to the actual tool name used by the OpenAI agent, if available
if (
name == "function_call"
and start_event.payload
and start_event.payload.get("tool", None)
):
name = start_event.payload.get("tool", name)
span = parent.span(
id=event_id,
trace_id=trace_id,
start_time=start_event.time,
name=name,
version=self.version,
session_id=self.session_id,
input=extracted_input,
output=extracted_output,
metadata=metadata,
)
if end_event:
span.end(end_time=end_event.time)
return span
    def _update_trace_data(self, trace_map: Dict[str, List[str]]) -> None:
context_root_value = context_root.get()
if context_root_value and not self.update_stateful_client:
return
child_event_ids = trace_map.get(BASE_TRACE_EVENT, [])
if not child_event_ids:
return
event_pair = self.event_map.get(child_event_ids[0])
if not event_pair or len(event_pair) < 2:
return
        start_event, end_event = event_pair[0], event_pair[-1]
input = self._parse_input_from_event(start_event)
output = self._parse_output_from_event(end_event)
if input or output:
if context_root_value and self.update_stateful_client:
context_root_value.update(input=input, output=output)
else:
self.trace.update(input=input, output=output)
def _parse_input_from_event(self, event: CallbackEvent):
if event.payload is None:
return
payload = event.payload.copy()
if EventPayload.SERIALIZED in payload:
# Always pop Serialized from payload as it may contain LLM api keys
payload.pop(EventPayload.SERIALIZED)
if event.event_type == CBEventType.EMBEDDING and EventPayload.CHUNKS in payload:
chunks = payload.get(EventPayload.CHUNKS)
return {"num_chunks": len(chunks)}
if (
event.event_type == CBEventType.NODE_PARSING
and EventPayload.DOCUMENTS in payload
):
documents = payload.pop(EventPayload.DOCUMENTS)
payload["documents"] = [doc.metadata for doc in documents]
return payload
for key in [EventPayload.MESSAGES, EventPayload.QUERY_STR, EventPayload.PROMPT]:
if key in payload:
return payload.get(key)
return payload or None
def _parse_output_from_event(self, event: CallbackEvent):
if event.payload is None:
return
payload = event.payload.copy()
if EventPayload.SERIALIZED in payload:
# Always pop Serialized from payload as it may contain LLM api keys
payload.pop(EventPayload.SERIALIZED)
if (
event.event_type == CBEventType.EMBEDDING
and EventPayload.EMBEDDINGS in payload
):
embeddings = payload.get(EventPayload.EMBEDDINGS)
return {"num_embeddings": len(embeddings)}
if (
event.event_type == CBEventType.NODE_PARSING
and EventPayload.NODES in payload
):
nodes = payload.pop(EventPayload.NODES)
payload["num_nodes"] = len(nodes)
return payload
if event.event_type == CBEventType.CHUNKING and EventPayload.CHUNKS in payload:
chunks = payload.pop(EventPayload.CHUNKS)
payload["num_chunks"] = len(chunks)
if EventPayload.COMPLETION in payload:
return payload.get(EventPayload.COMPLETION)
if EventPayload.RESPONSE in payload:
response = payload.get(EventPayload.RESPONSE)
# Skip streaming responses as consuming them would block the user's execution path
if "Streaming" in type(response).__name__:
return None
if hasattr(response, "response"):
return response.response
if hasattr(response, "message"):
output = dict(response.message)
if "additional_kwargs" in output:
if "tool_calls" in output["additional_kwargs"]:
output["tool_calls"] = output["additional_kwargs"]["tool_calls"]
del output["additional_kwargs"]
return output
return payload or None
def _parse_metadata_from_event(self, event: CallbackEvent):
if event.payload is None:
return
metadata = {}
for key in event.payload.keys():
if key not in [
EventPayload.MESSAGES,
EventPayload.QUERY_STR,
EventPayload.PROMPT,
EventPayload.COMPLETION,
EventPayload.SERIALIZED,
"additional_kwargs",
]:
if key != EventPayload.RESPONSE:
metadata[key] = event.payload[key]
else:
response = event.payload.get(EventPayload.RESPONSE)
if "Streaming" in type(response).__name__:
continue
for res_key, value in vars(response).items():
if (
not res_key.startswith("_")
and res_key
not in [
"response",
"response_txt",
"message",
"additional_kwargs",
"delta",
"raw",
]
and not isinstance(value, Generator)
):
metadata[res_key] = value
return metadata or None