import time import uuid from typing import Any, Dict, List, Optional try: import openai import openai.resources except ImportError: raise ModuleNotFoundError("Please install the OpenAI SDK to use this feature: 'pip install openai'") from posthog.ai.utils import call_llm_and_track_usage_async, get_model_params, with_privacy_mode from posthog.client import Client as PostHogClient class AsyncOpenAI(openai.AsyncOpenAI): """ An async wrapper around the OpenAI SDK that automatically sends LLM usage events to PostHog. """ _ph_client: PostHogClient def __init__(self, posthog_client: PostHogClient, **kwargs): """ Args: api_key: OpenAI API key. posthog_client: If provided, events will be captured via this client instance. **openai_config: Additional keyword args (e.g. organization="xxx"). """ super().__init__(**kwargs) self._ph_client = posthog_client self.chat = WrappedChat(self) self.embeddings = WrappedEmbeddings(self) self.beta = WrappedBeta(self) self.responses = WrappedResponses(self) class WrappedResponses(openai.resources.responses.Responses): _client: AsyncOpenAI async def create( self, posthog_distinct_id: Optional[str] = None, posthog_trace_id: Optional[str] = None, posthog_properties: Optional[Dict[str, Any]] = None, posthog_privacy_mode: bool = False, posthog_groups: Optional[Dict[str, Any]] = None, **kwargs: Any, ): if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) if kwargs.get("stream", False): return await self._create_streaming( posthog_distinct_id, posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups, **kwargs, ) return await call_llm_and_track_usage_async( posthog_distinct_id, self._client._ph_client, "openai", posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups, self._client.base_url, super().create, **kwargs, ) async def _create_streaming( self, posthog_distinct_id: Optional[str], posthog_trace_id: Optional[str], posthog_properties: Optional[Dict[str, Any]], posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], **kwargs: Any, ): start_time = time.time() usage_stats: Dict[str, int] = {} final_content = [] response = await super().create(**kwargs) async def async_generator(): nonlocal usage_stats nonlocal final_content # noqa: F824 try: async for chunk in response: if hasattr(chunk, "type") and chunk.type == "response.completed": res = chunk.response if res.output and len(res.output) > 0: final_content.append(res.output[0]) if hasattr(chunk, "usage") and chunk.usage: usage_stats = { k: getattr(chunk.usage, k, 0) for k in [ "input_tokens", "output_tokens", "total_tokens", ] } # Add support for cached tokens if hasattr(chunk.usage, "output_tokens_details") and hasattr( chunk.usage.output_tokens_details, "reasoning_tokens" ): usage_stats["reasoning_tokens"] = chunk.usage.output_tokens_details.reasoning_tokens if hasattr(chunk.usage, "input_tokens_details") and hasattr( chunk.usage.input_tokens_details, "cached_tokens" ): usage_stats["cache_read_input_tokens"] = chunk.usage.input_tokens_details.cached_tokens yield chunk finally: end_time = time.time() latency = end_time - start_time output = final_content await self._capture_streaming_event( posthog_distinct_id, posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups, kwargs, usage_stats, latency, output, ) return async_generator() async def _capture_streaming_event( self, posthog_distinct_id: Optional[str], posthog_trace_id: Optional[str], posthog_properties: Optional[Dict[str, Any]], posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], kwargs: Dict[str, Any], usage_stats: Dict[str, int], latency: float, output: Any, tool_calls: Optional[List[Dict[str, Any]]] = None, ): if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) event_properties = { "$ai_provider": "openai", "$ai_model": kwargs.get("model"), "$ai_model_parameters": get_model_params(kwargs), "$ai_input": with_privacy_mode(self._client._ph_client, posthog_privacy_mode, kwargs.get("input")), "$ai_output_choices": with_privacy_mode( self._client._ph_client, posthog_privacy_mode, output, ), "$ai_http_status": 200, "$ai_input_tokens": usage_stats.get("input_tokens", 0), "$ai_output_tokens": usage_stats.get("output_tokens", 0), "$ai_cache_read_input_tokens": usage_stats.get("cache_read_input_tokens", 0), "$ai_reasoning_tokens": usage_stats.get("reasoning_tokens", 0), "$ai_latency": latency, "$ai_trace_id": posthog_trace_id, "$ai_base_url": str(self._client.base_url), **(posthog_properties or {}), } if tool_calls: event_properties["$ai_tools"] = with_privacy_mode( self._client._ph_client, posthog_privacy_mode, tool_calls, ) if posthog_distinct_id is None: event_properties["$process_person_profile"] = False if hasattr(self._client._ph_client, "capture"): await self._client._ph_client.capture( distinct_id=posthog_distinct_id or posthog_trace_id, event="$ai_generation", properties=event_properties, groups=posthog_groups, ) class WrappedChat(openai.resources.chat.AsyncChat): _client: AsyncOpenAI @property def completions(self): return WrappedCompletions(self._client) class WrappedCompletions(openai.resources.chat.completions.AsyncCompletions): _client: AsyncOpenAI async def create( self, posthog_distinct_id: Optional[str] = None, posthog_trace_id: Optional[str] = None, posthog_properties: Optional[Dict[str, Any]] = None, posthog_privacy_mode: bool = False, posthog_groups: Optional[Dict[str, Any]] = None, **kwargs: Any, ): if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) # If streaming, handle streaming specifically if kwargs.get("stream", False): return await self._create_streaming( posthog_distinct_id, posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups, **kwargs, ) response = await call_llm_and_track_usage_async( posthog_distinct_id, self._client._ph_client, "openai", posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups, self._client.base_url, super().create, **kwargs, ) return response async def _create_streaming( self, posthog_distinct_id: Optional[str], posthog_trace_id: Optional[str], posthog_properties: Optional[Dict[str, Any]], posthog_privacy_mode: bool = False, posthog_groups: Optional[Dict[str, Any]] = None, **kwargs: Any, ): start_time = time.time() usage_stats: Dict[str, int] = {} accumulated_content = [] accumulated_tools = {} if "stream_options" not in kwargs: kwargs["stream_options"] = {} kwargs["stream_options"]["include_usage"] = True response = await super().create(**kwargs) async def async_generator(): nonlocal usage_stats, accumulated_content, accumulated_tools # noqa: F824 try: async for chunk in response: if hasattr(chunk, "usage") and chunk.usage: usage_stats = { k: getattr(chunk.usage, k, 0) for k in [ "prompt_tokens", "completion_tokens", "total_tokens", ] } # Add support for cached tokens if hasattr(chunk.usage, "prompt_tokens_details") and hasattr( chunk.usage.prompt_tokens_details, "cached_tokens" ): usage_stats["cache_read_input_tokens"] = chunk.usage.prompt_tokens_details.cached_tokens if hasattr(chunk, "choices") and chunk.choices and len(chunk.choices) > 0: if chunk.choices[0].delta and chunk.choices[0].delta.content: content = chunk.choices[0].delta.content if content: accumulated_content.append(content) # Process tool calls tool_calls = getattr(chunk.choices[0].delta, "tool_calls", None) if tool_calls: for tool_call in tool_calls: index = tool_call.index if index not in accumulated_tools: accumulated_tools[index] = tool_call else: # Append arguments for existing tool calls if hasattr(tool_call, "function") and hasattr(tool_call.function, "arguments"): accumulated_tools[index].function.arguments += tool_call.function.arguments yield chunk finally: end_time = time.time() latency = end_time - start_time output = "".join(accumulated_content) tools = list(accumulated_tools.values()) if accumulated_tools else None await self._capture_streaming_event( posthog_distinct_id, posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups, kwargs, usage_stats, latency, output, tools, ) return async_generator() async def _capture_streaming_event( self, posthog_distinct_id: Optional[str], posthog_trace_id: Optional[str], posthog_properties: Optional[Dict[str, Any]], posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], kwargs: Dict[str, Any], usage_stats: Dict[str, int], latency: float, output: Any, tool_calls: Optional[List[Dict[str, Any]]] = None, ): if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) event_properties = { "$ai_provider": "openai", "$ai_model": kwargs.get("model"), "$ai_model_parameters": get_model_params(kwargs), "$ai_input": with_privacy_mode(self._client._ph_client, posthog_privacy_mode, kwargs.get("messages")), "$ai_output_choices": with_privacy_mode( self._client._ph_client, posthog_privacy_mode, [{"content": output, "role": "assistant"}], ), "$ai_http_status": 200, "$ai_input_tokens": usage_stats.get("prompt_tokens", 0), "$ai_output_tokens": usage_stats.get("completion_tokens", 0), "$ai_cache_read_input_tokens": usage_stats.get("cache_read_input_tokens", 0), "$ai_latency": latency, "$ai_trace_id": posthog_trace_id, "$ai_base_url": str(self._client.base_url), **(posthog_properties or {}), } if tool_calls: event_properties["$ai_tools"] = with_privacy_mode( self._client._ph_client, posthog_privacy_mode, tool_calls, ) if posthog_distinct_id is None: event_properties["$process_person_profile"] = False if hasattr(self._client._ph_client, "capture"): await self._client._ph_client.capture( distinct_id=posthog_distinct_id or posthog_trace_id, event="$ai_generation", properties=event_properties, groups=posthog_groups, ) class WrappedEmbeddings(openai.resources.embeddings.AsyncEmbeddings): _client: AsyncOpenAI async def create( self, posthog_distinct_id: Optional[str] = None, posthog_trace_id: Optional[str] = None, posthog_properties: Optional[Dict[str, Any]] = None, posthog_privacy_mode: bool = False, posthog_groups: Optional[Dict[str, Any]] = None, **kwargs: Any, ): """ Create an embedding using OpenAI's 'embeddings.create' method, but also track usage in PostHog. Args: posthog_distinct_id: Optional ID to associate with the usage event. posthog_trace_id: Optional trace UUID for linking events. posthog_properties: Optional dictionary of extra properties to include in the event. posthog_privacy_mode: Whether to store input and output in PostHog. posthog_groups: Optional dictionary of groups to include in the event. **kwargs: Any additional parameters for the OpenAI Embeddings API. Returns: The response from OpenAI's embeddings.create call. """ if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) start_time = time.time() response = await super().create(**kwargs) end_time = time.time() # Extract usage statistics if available usage_stats = {} if hasattr(response, "usage") and response.usage: usage_stats = { "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), "total_tokens": getattr(response.usage, "total_tokens", 0), } latency = end_time - start_time # Build the event properties event_properties = { "$ai_provider": "openai", "$ai_model": kwargs.get("model"), "$ai_input": with_privacy_mode(self._client._ph_client, posthog_privacy_mode, kwargs.get("input")), "$ai_http_status": 200, "$ai_input_tokens": usage_stats.get("prompt_tokens", 0), "$ai_latency": latency, "$ai_trace_id": posthog_trace_id, "$ai_base_url": str(self._client.base_url), **(posthog_properties or {}), } if posthog_distinct_id is None: event_properties["$process_person_profile"] = False # Send capture event for embeddings if hasattr(self._client._ph_client, "capture"): self._client._ph_client.capture( distinct_id=posthog_distinct_id or posthog_trace_id, event="$ai_embedding", properties=event_properties, groups=posthog_groups, ) return response class WrappedBeta(openai.resources.beta.AsyncBeta): _client: AsyncOpenAI @property def chat(self): return WrappedBetaChat(self._client) class WrappedBetaChat(openai.resources.beta.chat.AsyncChat): _client: AsyncOpenAI @property def completions(self): return WrappedBetaCompletions(self._client) class WrappedBetaCompletions(openai.resources.beta.chat.completions.AsyncCompletions): _client: AsyncOpenAI async def parse( self, posthog_distinct_id: Optional[str] = None, posthog_trace_id: Optional[str] = None, posthog_properties: Optional[Dict[str, Any]] = None, posthog_privacy_mode: bool = False, posthog_groups: Optional[Dict[str, Any]] = None, **kwargs: Any, ): return await call_llm_and_track_usage_async( posthog_distinct_id, self._client._ph_client, "openai", posthog_trace_id, posthog_properties, posthog_privacy_mode, posthog_groups, self._client.base_url, super().parse, **kwargs, )
Memory