import httpx import logging import typing import warnings import pydantic try: # Test that langchain is installed before proceeding import langchain # noqa except ImportError as e: log = logging.getLogger("langfuse") log.error( f"Could not import langchain. The langchain integration will not work. {e}" ) from typing import Any, Dict, List, Optional, Sequence, Union from uuid import UUID, uuid4 from langfuse.api.resources.ingestion.types.sdk_log_body import SdkLogBody from langfuse.client import ( StatefulSpanClient, StatefulTraceClient, ) from langfuse.extract_model import _extract_model_name from langfuse.utils import _get_timestamp from langfuse.utils.base_callback_handler import LangfuseBaseCallbackHandler try: from langchain.callbacks.base import ( BaseCallbackHandler as LangchainBaseCallbackHandler, ) from langchain.schema.agent import AgentAction, AgentFinish from langchain.schema.document import Document from langchain_core.outputs import ( ChatGeneration, LLMResult, ) from langchain_core.messages import ( AIMessage, BaseMessage, ChatMessage, HumanMessage, SystemMessage, ToolMessage, FunctionMessage, ) except ImportError: raise ModuleNotFoundError( "Please install langchain to use the Langfuse langchain integration: 'pip install langchain'" ) class LangchainCallbackHandler( LangchainBaseCallbackHandler, LangfuseBaseCallbackHandler ): log = logging.getLogger("langfuse") next_span_id: Optional[str] = None def __init__( self, public_key: Optional[str] = None, secret_key: Optional[str] = None, host: Optional[str] = None, debug: bool = False, stateful_client: Optional[ Union[StatefulTraceClient, StatefulSpanClient] ] = None, update_stateful_client: bool = False, session_id: Optional[str] = None, user_id: Optional[str] = None, trace_name: Optional[str] = None, release: Optional[str] = None, version: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, threads: Optional[int] = None, flush_at: Optional[int] = None, flush_interval: Optional[int] = None, max_retries: Optional[int] = None, timeout: Optional[int] = None, enabled: Optional[bool] = None, httpx_client: Optional[httpx.Client] = None, sdk_integration: Optional[str] = None, sample_rate: Optional[float] = None, ) -> None: LangfuseBaseCallbackHandler.__init__( self, public_key=public_key, secret_key=secret_key, host=host, debug=debug, stateful_client=stateful_client, update_stateful_client=update_stateful_client, session_id=session_id, user_id=user_id, trace_name=trace_name, release=release, version=version, metadata=metadata, tags=tags, threads=threads, flush_at=flush_at, flush_interval=flush_interval, max_retries=max_retries, timeout=timeout, enabled=enabled, httpx_client=httpx_client, sdk_integration=sdk_integration or "langchain", sample_rate=sample_rate, ) self.runs = {} if stateful_client and isinstance(stateful_client, StatefulSpanClient): self.runs[stateful_client.id] = stateful_client def setNextSpan(self, id: str): warnings.warn( "setNextSpan is deprecated, use span.get_langchain_handler() instead", DeprecationWarning, ) self.next_span_id = id def on_llm_new_token( self, token: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run on new LLM token. Only available when streaming is enabled.""" # Nothing needs to happen here for langfuse. Once the streaming is done, self.log.debug( f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}" ) def get_langchain_run_name(self, serialized: Dict[str, Any], **kwargs: Any) -> str: """Retrieves the 'run_name' for an entity based on Langchain convention, prioritizing the 'name' key in 'kwargs' or falling back to the 'name' or 'id' in 'serialized'. Defaults to "<unknown>" if none are available. Args: serialized (Dict[str, Any]): A dictionary containing the entity's serialized data. **kwargs (Any): Additional keyword arguments, potentially including the 'name' override. Returns: str: The determined Langchain run name for the entity. """ # Check if 'name' is in kwargs and not None, otherwise use default fallback logic if "name" in kwargs and kwargs["name"] is not None: return kwargs["name"] # Fallback to serialized 'name', 'id', or "<unknown>" return serialized.get("name", serialized.get("id", ["<unknown>"])[-1]) def on_retriever_error( self, error: Union[Exception, KeyboardInterrupt], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run when Retriever errors.""" try: self._log_debug_event( "on_retriever_error", run_id, parent_run_id, error=error ) if run_id is None or run_id not in self.runs: raise Exception("run not found") self.runs[run_id] = self.runs[run_id].end( level="ERROR", status_message=str(error), version=self.version, input=kwargs.get("inputs"), ) except Exception as e: self.log.exception(e) def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event( "on_chain_start", run_id, parent_run_id, inputs=inputs ) self.__generate_trace_and_parent( serialized=serialized, inputs=inputs, run_id=run_id, parent_run_id=parent_run_id, tags=tags, metadata=metadata, version=self.version, **kwargs, ) content = { "id": self.next_span_id, "trace_id": self.trace.id, "name": self.get_langchain_run_name(serialized, **kwargs), "metadata": self.__join_tags_and_metadata(tags, metadata), "input": inputs, "version": self.version, } if parent_run_id is None: if self.root_span is None: self.runs[run_id] = self.trace.span(**content) else: self.runs[run_id] = self.root_span.span(**content) if parent_run_id is not None: self.runs[run_id] = self.runs[parent_run_id].span(**content) except Exception as e: self.log.exception(e) def __generate_trace_and_parent( self, serialized: Dict[str, Any], inputs: Union[Dict[str, Any], List[str], str, None], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ): try: class_name = self.get_langchain_run_name(serialized, **kwargs) # on a new invocation, and not user provided root, we want to initialise a new traceo # parent_run_id is None when we are at the root of a langchain execution if ( self.trace is not None and parent_run_id is None and self.langfuse is not None ): self.trace = None if ( self.trace is not None and parent_run_id is None # We are at the root of a langchain execution and self.langfuse is None # StatefulClient was provided by user and self.update_stateful_client ): params = { "name": self.trace_name if self.trace_name is not None else class_name, "metadata": self.__join_tags_and_metadata( tags, metadata, trace_metadata=self.metadata ), "version": self.version, "session_id": self.session_id, "user_id": self.user_id, "tags": self.tags, "input": inputs, } if self.root_span: self.root_span.update(**params) else: self.trace.update(**params) # if we are at a root, but langfuse exists, it means we do not have a # root provided by a user. Initialise it by creating a trace and root span. if self.trace is None and self.langfuse is not None: trace = self.langfuse.trace( id=str(run_id), name=self.trace_name if self.trace_name is not None else class_name, metadata=self.__join_tags_and_metadata( tags, metadata, trace_metadata=self.metadata ), version=self.version, session_id=self.session_id, user_id=self.user_id, tags=self.tags, input=inputs, ) self.trace = trace if parent_run_id is not None and parent_run_id in self.runs: self.runs[run_id] = self.trace.span( id=self.next_span_id, trace_id=self.trace.id, name=class_name, metadata=self.__join_tags_and_metadata(tags, metadata), input=inputs, version=self.version, ) return except Exception as e: self.log.exception(e) def on_agent_action( self, action: AgentAction, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run on agent action.""" try: self._log_debug_event( "on_agent_action", run_id, parent_run_id, action=action ) if run_id not in self.runs: raise Exception("run not found") self.runs[run_id] = self.runs[run_id].end( output=action, version=self.version, input=kwargs.get("inputs"), ) except Exception as e: self.log.exception(e) def on_agent_finish( self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event( "on_agent_finish", run_id, parent_run_id, finish=finish ) if run_id not in self.runs: raise Exception("run not found") self.runs[run_id] = self.runs[run_id].end( output=finish, version=self.version, input=kwargs.get("inputs"), ) # langchain sends same run_id for agent_finish and chain_end for the same agent interaction. # Hence, we only delete at chain_end and not here. self._update_trace_and_remove_state( run_id, parent_run_id, finish, keep_state=True ) except Exception as e: self.log.exception(e) def on_chain_end( self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event( "on_chain_end", run_id, parent_run_id, outputs=outputs ) if run_id not in self.runs: raise Exception("run not found") self.runs[run_id] = self.runs[run_id].end( output=outputs, version=self.version, input=kwargs.get("inputs"), ) self._update_trace_and_remove_state( run_id, parent_run_id, outputs, input=kwargs.get("inputs") ) except Exception as e: self.log.exception(e) def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: try: self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error) self.runs[run_id] = self.runs[run_id].end( level="ERROR", status_message=str(error), version=self.version, input=kwargs.get("inputs"), ) self._update_trace_and_remove_state( run_id, parent_run_id, error, input=kwargs.get("inputs") ) except Exception as e: self.log.exception(e) def on_chat_model_start( self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event( "on_chat_model_start", run_id, parent_run_id, messages=messages ) self.__on_llm_action( serialized, run_id, _flatten_comprehension( [self._create_message_dicts(m) for m in messages] ), parent_run_id, tags=tags, metadata=metadata, **kwargs, ) except Exception as e: self.log.exception(e) def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event( "on_llm_start", run_id, parent_run_id, prompts=prompts ) self.__on_llm_action( serialized, run_id, prompts[0] if len(prompts) == 1 else prompts, parent_run_id, tags=tags, metadata=metadata, **kwargs, ) except Exception as e: self.log.exception(e) def on_tool_start( self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event( "on_tool_start", run_id, parent_run_id, input_str=input_str ) if parent_run_id is None or parent_run_id not in self.runs: raise Exception("parent run not found") meta = self.__join_tags_and_metadata(tags, metadata) if not meta: meta = {} meta.update( {key: value for key, value in kwargs.items() if value is not None} ) self.runs[run_id] = self.runs[parent_run_id].span( id=self.next_span_id, name=self.get_langchain_run_name(serialized, **kwargs), input=input_str, metadata=meta, version=self.version, ) self.next_span_id = None except Exception as e: self.log.exception(e) def on_retriever_start( self, serialized: Dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event( "on_retriever_start", run_id, parent_run_id, query=query ) if parent_run_id is None or parent_run_id not in self.runs: raise Exception("parent run not found") self.runs[run_id] = self.runs[parent_run_id].span( id=self.next_span_id, name=self.get_langchain_run_name(serialized, **kwargs), input=query, metadata=self.__join_tags_and_metadata(tags, metadata), version=self.version, ) self.next_span_id = None except Exception as e: self.log.exception(e) def on_retriever_end( self, documents: Sequence[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event( "on_retriever_end", run_id, parent_run_id, documents=documents ) if run_id is None or run_id not in self.runs: raise Exception("run not found") self.runs[run_id] = self.runs[run_id].end( output=documents, version=self.version, input=kwargs.get("inputs"), ) self._update_trace_and_remove_state(run_id, parent_run_id, documents) except Exception as e: self.log.exception(e) def on_tool_end( self, output: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output) if run_id is None or run_id not in self.runs: raise Exception("run not found") self.runs[run_id] = self.runs[run_id].end( output=output, version=self.version, input=kwargs.get("inputs"), ) self._update_trace_and_remove_state(run_id, parent_run_id, output) except Exception as e: self.log.exception(e) def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error) if run_id is None or run_id not in self.runs: raise Exception("run not found") self.runs[run_id] = self.runs[run_id].end( status_message=str(error), level="ERROR", version=self.version, input=kwargs.get("inputs"), ) self._update_trace_and_remove_state(run_id, parent_run_id, error) except Exception as e: self.log.exception(e) def __on_llm_action( self, serialized: Dict[str, Any], run_id: UUID, prompts: List[str], parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ): try: self.__generate_trace_and_parent( serialized, inputs=prompts[0] if len(prompts) == 1 else prompts, run_id=run_id, parent_run_id=parent_run_id, tags=tags, metadata=metadata, version=self.version, kwargs=kwargs, ) model_name = None model_name = self._parse_model_and_log_errors(serialized, kwargs) content = { "name": self.get_langchain_run_name(serialized, **kwargs), "input": prompts, "metadata": self.__join_tags_and_metadata(tags, metadata), "model": model_name, "model_parameters": { key: value for key, value in { "temperature": kwargs["invocation_params"].get("temperature"), "max_tokens": kwargs["invocation_params"].get("max_tokens"), "top_p": kwargs["invocation_params"].get("top_p"), "frequency_penalty": kwargs["invocation_params"].get( "frequency_penalty" ), "presence_penalty": kwargs["invocation_params"].get( "presence_penalty" ), "request_timeout": kwargs["invocation_params"].get( "request_timeout" ), }.items() if value is not None }, "version": self.version, } if parent_run_id in self.runs: self.runs[run_id] = self.runs[parent_run_id].generation(**content) elif self.root_span is not None and parent_run_id is None: self.runs[run_id] = self.root_span.generation(**content) else: self.runs[run_id] = self.trace.generation(**content) except Exception as e: self.log.exception(e) def _parse_model_and_log_errors(self, serialized, kwargs): """Parse the model name from the serialized object or kwargs. If it fails, send the error log to the server and return None.""" try: model_name = _extract_model_name(serialized, **kwargs) if model_name: return model_name if model_name is None: self.log.warning( "Langfuse was not able to parse the LLM model. The LLM call will be recorded without model name. Please create an issue so we can fix your integration: https://github.com/langfuse/langfuse/issues/new/choose" ) self._report_error( { "log": "unable to parse model name", "kwargs": str(kwargs), "serialized": str(serialized), } ) except Exception as e: self.log.exception(e) self.log.warning( "Langfuse was not able to parse the LLM model. The LLM call will be recorded without model name. Please create an issue so we can fix your integration: https://github.com/langfuse/langfuse/issues/new/choose" ) self._report_error( { "log": "unable to parse model name", "kwargs": str(kwargs), "serialized": str(serialized), "exception": str(e), } ) return None def on_llm_end( self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event( "on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs ) if run_id not in self.runs: raise Exception("Run not found, see docs what to do in this case.") else: generation = response.generations[-1][-1] extracted_response = ( self._convert_message_to_dict(generation.message) if isinstance(generation, ChatGeneration) else _extract_raw_esponse(generation) ) llm_usage = _parse_usage(response) # e.g. azure returns the model name in the response model = _parse_model(response) self.runs[run_id] = self.runs[run_id].end( output=extracted_response, usage=llm_usage, version=self.version, input=kwargs.get("inputs"), model=model, ) self._update_trace_and_remove_state( run_id, parent_run_id, extracted_response ) except Exception as e: self.log.exception(e) def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: try: self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error) self.runs[run_id] = self.runs[run_id].end( status_message=str(error), level="ERROR", version=self.version, input=kwargs.get("inputs"), ) self._update_trace_and_remove_state(run_id, parent_run_id, error) except Exception as e: self.log.exception(e) def __join_tags_and_metadata( self, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, trace_metadata: Optional[Dict[str, Any]] = None, ) -> Optional[Dict[str, Any]]: final_dict = {} if tags is not None and len(tags) > 0: final_dict["tags"] = tags if metadata is not None: final_dict.update(metadata) if trace_metadata is not None: final_dict.update(trace_metadata) return final_dict if final_dict != {} else None def _report_error(self, error: dict): event = SdkLogBody(log=error) self._task_manager.add_task( { "id": str(uuid4()), "type": "sdk-log", "timestamp": _get_timestamp(), "body": event.dict(), } ) def _update_trace_and_remove_state( self, run_id: str, parent_run_id: Optional[str], output: any, *, keep_state: bool = False, **kwargs: Any, ): """Update the trace with the output of the current run. Called at every finish callback event.""" if ( parent_run_id is None # If we are at the root of the langchain execution -> reached the end of the root and self.trace is not None # We do have a trace available and self.trace.id == str(run_id) # The trace was generated by langchain and not by the user ): self.trace = self.trace.update(output=output, **kwargs) elif ( parent_run_id is None and self.trace is not None # We have a user-provided parent and self.update_stateful_client ): if self.root_span is not None: self.root_span = self.root_span.update(output=output, **kwargs) else: self.trace = self.trace.update(output=output, **kwargs) if not keep_state: del self.runs[run_id] def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]: # assistant message if isinstance(message, HumanMessage): message_dict = {"role": "user", "content": message.content} elif isinstance(message, AIMessage): message_dict = {"role": "assistant", "content": message.content} elif isinstance(message, SystemMessage): message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolMessage): message_dict = {"role": "tool", "content": message.content} elif isinstance(message, FunctionMessage): message_dict = {"role": "function", "content": message.content} elif isinstance(message, ChatMessage): message_dict = {"role": message.role, "content": message.content} else: raise ValueError(f"Got unknown type {message}") if "name" in message.additional_kwargs: message_dict["name"] = message.additional_kwargs["name"] if message.additional_kwargs: message_dict["additional_kwargs"] = message.additional_kwargs return message_dict def _create_message_dicts( self, messages: List[BaseMessage] ) -> List[Dict[str, Any]]: return [self._convert_message_to_dict(m) for m in messages] def _log_debug_event( self, event_name: str, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs, ): kwargs_log = ( ", " + ", ".join([f"{key}: {value}" for key, value in kwargs.items()]) if len(kwargs) > 0 else "" ) self.log.debug( f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}" + kwargs_log ) def _extract_raw_esponse(last_response): """Extract the response from the last response of the LLM call.""" # We return the text of the response if not empty if last_response.text is not None and last_response.text.strip() != "": return last_response.text.strip() elif hasattr(last_response, "message"): # Additional kwargs contains the response in case of tool usage return last_response.message.additional_kwargs else: # Not tool usage, some LLM responses can be simply empty return "" def _flatten_comprehension(matrix): return [item for row in matrix for item in row] def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]): # maintains a list of key translations. For each key, the usage model is checked # and a new object will be created with the new key if the key exists in the usage model # All non matched keys will remain on the object. if hasattr(usage, "__dict__"): usage = usage.__dict__ conversion_list = [ # https://pypi.org/project/langchain-anthropic/ (works also for Bedrock-Anthropic) ("input_tokens", "input"), ("output_tokens", "output"), # https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/get-token-count ("prompt_token_count", "input"), ("candidates_token_count", "output"), # Bedrock: https://docs.aws.amazon.com/bedrock/latest/userguide/monitoring-cw.html#runtime-cloudwatch-metrics ("inputTokenCount", "input"), ("outputTokenCount", "output"), ] usage_model = usage.copy() # Copy all existing key-value pairs for model_key, langfuse_key in conversion_list: if model_key in usage_model: captured_count = usage_model.pop(model_key) final_count = ( sum(captured_count) if isinstance(captured_count, list) else captured_count ) # For Bedrock, the token count is a list when streamed usage_model[langfuse_key] = final_count # Translate key and keep the value return usage_model if usage_model else None def _parse_usage(response: LLMResult): # langchain-anthropic uses the usage field llm_usage_keys = ["token_usage", "usage"] llm_usage = None if response.llm_output is not None: for key in llm_usage_keys: if key in response.llm_output and response.llm_output[key]: llm_usage = _parse_usage_model(response.llm_output[key]) break if hasattr(response, "generations"): for generation in response.generations: for generation_chunk in generation: if generation_chunk.generation_info and ( "usage_metadata" in generation_chunk.generation_info ): llm_usage = _parse_usage_model( generation_chunk.generation_info["usage_metadata"] ) break message_chunk = getattr(generation_chunk, "message", {}) response_metadata = getattr(message_chunk, "response_metadata", {}) chunk_usage = ( response_metadata.get("usage", None) # for Bedrock-Anthropic if isinstance(response_metadata, dict) else None ) or ( response_metadata.get( "amazon-bedrock-invocationMetrics", None ) # for Bedrock-Titan if isinstance(response_metadata, dict) else None ) if chunk_usage: llm_usage = _parse_usage_model(chunk_usage) break return llm_usage def _parse_model(response: LLMResult): # langchain-anthropic uses the usage field llm_model_keys = ["model_name"] llm_model = None if response.llm_output is not None: for key in llm_model_keys: if key in response.llm_output and response.llm_output[key]: llm_model = response.llm_output[key] break return llm_model
Memory