from __future__ import annotations import inspect from types import TracebackType from typing import Any, List, Generic, Iterable, Awaitable, cast from typing_extensions import Self, Callable, Iterator, AsyncIterator from ._types import ParsedResponseSnapshot from ._events import ( ResponseStreamEvent, ResponseTextDoneEvent, ResponseCompletedEvent, ResponseTextDeltaEvent, ResponseFunctionCallArgumentsDeltaEvent, ) from ...._types import NOT_GIVEN, NotGiven from ...._utils import is_given, consume_sync_iterator, consume_async_iterator from ...._models import build, construct_type_unchecked from ...._streaming import Stream, AsyncStream from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent from ..._parsing._responses import TextFormatT, parse_text, parse_response from ....types.responses.tool_param import ToolParam from ....types.responses.parsed_response import ( ParsedContent, ParsedResponseOutputMessage, ParsedResponseFunctionToolCall, ) class ResponseStream(Generic[TextFormatT]): def __init__( self, *, raw_stream: Stream[RawResponseStreamEvent], text_format: type[TextFormatT] | NotGiven, input_tools: Iterable[ToolParam] | NotGiven, ) -> None: self._raw_stream = raw_stream self._response = raw_stream.response self._iterator = self.__stream__() self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools) def __next__(self) -> ResponseStreamEvent[TextFormatT]: return self._iterator.__next__() def __iter__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]: for item in self._iterator: yield item def __enter__(self) -> Self: return self def __stream__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]: for sse_event in self._raw_stream: events_to_fire = self._state.handle_event(sse_event) for event in events_to_fire: yield event def __exit__( self, exc_type: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None, ) -> None: self.close() def close(self) -> None: """ Close the response and release the connection. Automatically called if the response body is read to completion. """ self._response.close() def get_final_response(self) -> ParsedResponse[TextFormatT]: """Waits until the stream has been read to completion and returns the accumulated `ParsedResponse` object. """ self.until_done() response = self._state._completed_response if not response: raise RuntimeError("Didn't receive a `response.completed` event.") return response def until_done(self) -> Self: """Blocks until the stream has been consumed.""" consume_sync_iterator(self) return self class ResponseStreamManager(Generic[TextFormatT]): def __init__( self, api_request: Callable[[], Stream[RawResponseStreamEvent]], *, text_format: type[TextFormatT] | NotGiven, input_tools: Iterable[ToolParam] | NotGiven, ) -> None: self.__stream: ResponseStream[TextFormatT] | None = None self.__api_request = api_request self.__text_format = text_format self.__input_tools = input_tools def __enter__(self) -> ResponseStream[TextFormatT]: raw_stream = self.__api_request() self.__stream = ResponseStream( raw_stream=raw_stream, text_format=self.__text_format, input_tools=self.__input_tools, ) return self.__stream def __exit__( self, exc_type: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None, ) -> None: if self.__stream is not None: self.__stream.close() class AsyncResponseStream(Generic[TextFormatT]): def __init__( self, *, raw_stream: AsyncStream[RawResponseStreamEvent], text_format: type[TextFormatT] | NotGiven, input_tools: Iterable[ToolParam] | NotGiven, ) -> None: self._raw_stream = raw_stream self._response = raw_stream.response self._iterator = self.__stream__() self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools) async def __anext__(self) -> ResponseStreamEvent[TextFormatT]: return await self._iterator.__anext__() async def __aiter__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]: async for item in self._iterator: yield item async def __stream__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]: async for sse_event in self._raw_stream: events_to_fire = self._state.handle_event(sse_event) for event in events_to_fire: yield event async def __aenter__(self) -> Self: return self async def __aexit__( self, exc_type: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None, ) -> None: await self.close() async def close(self) -> None: """ Close the response and release the connection. Automatically called if the response body is read to completion. """ await self._response.aclose() async def get_final_response(self) -> ParsedResponse[TextFormatT]: """Waits until the stream has been read to completion and returns the accumulated `ParsedResponse` object. """ await self.until_done() response = self._state._completed_response if not response: raise RuntimeError("Didn't receive a `response.completed` event.") return response async def until_done(self) -> Self: """Blocks until the stream has been consumed.""" await consume_async_iterator(self) return self class AsyncResponseStreamManager(Generic[TextFormatT]): def __init__( self, api_request: Awaitable[AsyncStream[RawResponseStreamEvent]], *, text_format: type[TextFormatT] | NotGiven, input_tools: Iterable[ToolParam] | NotGiven, ) -> None: self.__stream: AsyncResponseStream[TextFormatT] | None = None self.__api_request = api_request self.__text_format = text_format self.__input_tools = input_tools async def __aenter__(self) -> AsyncResponseStream[TextFormatT]: raw_stream = await self.__api_request self.__stream = AsyncResponseStream( raw_stream=raw_stream, text_format=self.__text_format, input_tools=self.__input_tools, ) return self.__stream async def __aexit__( self, exc_type: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None, ) -> None: if self.__stream is not None: await self.__stream.close() class ResponseStreamState(Generic[TextFormatT]): def __init__( self, *, input_tools: Iterable[ToolParam] | NotGiven, text_format: type[TextFormatT] | NotGiven, ) -> None: self.__current_snapshot: ParsedResponseSnapshot | None = None self._completed_response: ParsedResponse[TextFormatT] | None = None self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else [] self._text_format = text_format self._rich_text_format: type | NotGiven = text_format if inspect.isclass(text_format) else NOT_GIVEN def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]: self.__current_snapshot = snapshot = self.accumulate_event(event) events: List[ResponseStreamEvent[TextFormatT]] = [] if event.type == "response.output_text.delta": output = snapshot.output[event.output_index] assert output.type == "message" content = output.content[event.content_index] assert content.type == "output_text" events.append( build( ResponseTextDeltaEvent, content_index=event.content_index, delta=event.delta, item_id=event.item_id, output_index=event.output_index, type="response.output_text.delta", snapshot=content.text, ) ) elif event.type == "response.output_text.done": output = snapshot.output[event.output_index] assert output.type == "message" content = output.content[event.content_index] assert content.type == "output_text" events.append( build( ResponseTextDoneEvent[TextFormatT], content_index=event.content_index, item_id=event.item_id, output_index=event.output_index, type="response.output_text.done", text=event.text, parsed=parse_text(event.text, text_format=self._text_format), ) ) elif event.type == "response.function_call_arguments.delta": output = snapshot.output[event.output_index] assert output.type == "function_call" events.append( build( ResponseFunctionCallArgumentsDeltaEvent, delta=event.delta, item_id=event.item_id, output_index=event.output_index, type="response.function_call_arguments.delta", snapshot=output.arguments, ) ) elif event.type == "response.completed": response = self._completed_response assert response is not None events.append( build( ResponseCompletedEvent, type="response.completed", response=response, ) ) else: events.append(event) return events def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot: snapshot = self.__current_snapshot if snapshot is None: return self._create_initial_response(event) if event.type == "response.output_item.added": if event.item.type == "function_call": snapshot.output.append( construct_type_unchecked( type_=cast(Any, ParsedResponseFunctionToolCall), value=event.item.to_dict() ) ) elif event.item.type == "message": snapshot.output.append( construct_type_unchecked(type_=cast(Any, ParsedResponseOutputMessage), value=event.item.to_dict()) ) else: snapshot.output.append(event.item) elif event.type == "response.content_part.added": output = snapshot.output[event.output_index] if output.type == "message": output.content.append( construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict()) ) elif event.type == "response.output_text.delta": output = snapshot.output[event.output_index] if output.type == "message": content = output.content[event.content_index] assert content.type == "output_text" content.text += event.delta elif event.type == "response.function_call_arguments.delta": output = snapshot.output[event.output_index] if output.type == "function_call": output.arguments += event.delta elif event.type == "response.completed": self._completed_response = parse_response( text_format=self._text_format, response=event.response, input_tools=self._input_tools, ) return snapshot def _create_initial_response(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot: if event.type != "response.created": raise RuntimeError(f"Expected to have received `response.created` before `{event.type}`") return construct_type_unchecked(type_=ParsedResponseSnapshot, value=event.response.to_dict())
Memory