from __future__ import annotations
from types import TracebackType
from typing import TYPE_CHECKING, Any, Type, Callable, cast
from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never
import httpx
from pydantic import BaseModel
from ._types import (
TextEvent,
CitationEvent,
ThinkingEvent,
InputJsonEvent,
SignatureEvent,
MessageStopEvent,
MessageStreamEvent,
ContentBlockStopEvent,
)
from ...types import Message, ContentBlock, RawMessageStreamEvent
from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._models import build, construct_type, construct_type_unchecked
from ..._streaming import Stream, AsyncStream
class MessageStream:
    """Synchronous view over a raw stream of message events.

    Iterating the stream yields `MessageStreamEvent` objects. Every raw event is
    also folded into a `Message` snapshot, readable through
    `current_message_snapshot` while streaming and through
    `get_final_message()` once the stream has been fully consumed.
    """

    text_stream: Iterator[str]
    """Iterator over just the text deltas in the stream.

    ```py
    for text in stream.text_stream:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self, raw_stream: Stream[RawMessageStreamEvent]) -> None:
        self._raw_stream = raw_stream
        # Both generators are lazy. `text_stream` iterates `self`, which draws
        # from `self._iterator`, so the two share one underlying event source.
        self.text_stream = self.__stream_text__()
        self._iterator = self.__stream__()
        # Populated by `__stream__` as soon as the first raw event is accumulated.
        self.__final_message_snapshot: Message | None = None

    @property
    def response(self) -> httpx.Response:
        """The HTTP response object of the underlying raw stream."""
        return self._raw_stream.response

    @property
    def request_id(self) -> str | None:
        """The `request-id` response header, or `None` if it is absent."""
        return self.response.headers.get("request-id")  # type: ignore[no-any-return]

    def __next__(self) -> MessageStreamEvent:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[MessageStreamEvent]:
        for item in self._iterator:
            yield item

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self._raw_stream.close()

    def get_final_message(self) -> Message:
        """Waits until the stream has been read to completion and returns
        the accumulated `Message` object.

        Raises:
            RuntimeError: if the stream finished without yielding any event,
                so no message snapshot was ever built.
        """
        self.until_done()
        return self.current_message_snapshot

    def get_final_text(self) -> str:
        """Returns all `text` content blocks concatenated together.

        Waits for the stream to complete first. Content blocks of other types
        (for example `tool_use` or `thinking`) are skipped.

        Raises:
            RuntimeError: if no `text` content blocks were returned, or if the
                stream yielded no events at all.
        """
        message = self.get_final_message()
        text_blocks = [block.text for block in message.content if block.type == "text"]
        if not text_blocks:
            raise RuntimeError("Expected to have received at least 1 text block")
        return "".join(text_blocks)

    def until_done(self) -> None:
        """Blocks until the stream has been consumed"""
        consume_sync_iterator(self)

    # properties
    @property
    def current_message_snapshot(self) -> Message:
        """The `Message` accumulated from the events received so far.

        Raises:
            RuntimeError: if no event has been received yet.
        """
        snapshot = self.__final_message_snapshot
        if snapshot is None:
            # Explicit check rather than `assert`: asserts are stripped under
            # `python -O`, which would silently hand callers `None`.
            raise RuntimeError("No message snapshot available: no stream events have been received yet")
        return snapshot

    def __stream__(self) -> Iterator[MessageStreamEvent]:
        for sse_event in self._raw_stream:
            snapshot = accumulate_event(
                event=sse_event,
                current_snapshot=self.__final_message_snapshot,
            )
            self.__final_message_snapshot = snapshot
            yield from build_events(event=sse_event, message_snapshot=snapshot)

    def __stream_text__(self) -> Iterator[str]:
        for chunk in self:
            if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
                yield chunk.delta.text
class MessageStreamManager:
    """Context manager returned by `.stream()` that owns a `MessageStream`.

    The ``api_request`` callable is not invoked until the ``with`` block is
    entered, and the resulting stream is closed when the block exits.

    ```py
    with client.messages.stream(...) as stream:
        for chunk in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Callable[[], Stream[RawMessageStreamEvent]],
    ) -> None:
        self.__api_request = api_request
        # Remains None until `__enter__` has created the stream.
        self.__stream: MessageStream | None = None

    def __enter__(self) -> MessageStream:
        opened = MessageStream(self.__api_request())
        self.__stream = opened
        return opened

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        opened = self.__stream
        if opened is None:
            # Nothing was opened, so there is nothing to release.
            return
        opened.close()
class AsyncMessageStream:
    """Asynchronous view over a raw stream of message events.

    Iterating the stream with ``async for`` yields `MessageStreamEvent`
    objects. Every raw event is also folded into a `Message` snapshot,
    readable through `current_message_snapshot` while streaming and through
    `get_final_message()` once the stream has been fully consumed.
    """

    text_stream: AsyncIterator[str]
    """Async iterator over just the text deltas in the stream.

    ```py
    async for text in stream.text_stream:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self, raw_stream: AsyncStream[RawMessageStreamEvent]) -> None:
        self._raw_stream = raw_stream
        # Both async generators are lazy. `text_stream` iterates `self`, which
        # draws from `self._iterator`, so the two share one event source.
        self.text_stream = self.__stream_text__()
        self._iterator = self.__stream__()
        # Populated by `__stream__` as soon as the first raw event is accumulated.
        self.__final_message_snapshot: Message | None = None

    @property
    def response(self) -> httpx.Response:
        """The HTTP response object of the underlying raw stream."""
        return self._raw_stream.response

    @property
    def request_id(self) -> str | None:
        """The `request-id` response header, or `None` if it is absent."""
        return self.response.headers.get("request-id")  # type: ignore[no-any-return]

    async def __anext__(self) -> MessageStreamEvent:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
        async for item in self._iterator:
            yield item

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self._raw_stream.close()

    async def get_final_message(self) -> Message:
        """Waits until the stream has been read to completion and returns
        the accumulated `Message` object.

        Raises:
            RuntimeError: if the stream finished without yielding any event,
                so no message snapshot was ever built.
        """
        await self.until_done()
        return self.current_message_snapshot

    async def get_final_text(self) -> str:
        """Returns all `text` content blocks concatenated together.

        Waits for the stream to complete first. Content blocks of other types
        (for example `tool_use` or `thinking`) are skipped.

        Raises:
            RuntimeError: if no `text` content blocks were returned, or if the
                stream yielded no events at all.
        """
        message = await self.get_final_message()
        text_blocks = [block.text for block in message.content if block.type == "text"]
        if not text_blocks:
            raise RuntimeError("Expected to have received at least 1 text block")
        return "".join(text_blocks)

    async def until_done(self) -> None:
        """Waits until the stream has been consumed"""
        await consume_async_iterator(self)

    # properties
    @property
    def current_message_snapshot(self) -> Message:
        """The `Message` accumulated from the events received so far.

        Raises:
            RuntimeError: if no event has been received yet.
        """
        snapshot = self.__final_message_snapshot
        if snapshot is None:
            # Explicit check rather than `assert`: asserts are stripped under
            # `python -O`, which would silently hand callers `None`.
            raise RuntimeError("No message snapshot available: no stream events have been received yet")
        return snapshot

    async def __stream__(self) -> AsyncIterator[MessageStreamEvent]:
        async for sse_event in self._raw_stream:
            snapshot = accumulate_event(
                event=sse_event,
                current_snapshot=self.__final_message_snapshot,
            )
            self.__final_message_snapshot = snapshot
            # `yield from` is not allowed in async generators.
            for event in build_events(event=sse_event, message_snapshot=snapshot):
                yield event

    async def __stream_text__(self) -> AsyncIterator[str]:
        async for chunk in self:
            if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
                yield chunk.delta.text
class AsyncMessageStreamManager:
    """Async context manager returned by `.stream()` that owns an
    `AsyncMessageStream`.

    Lets callers use ``async with`` directly on the result of the client call
    without ``await``ing it first: the ``api_request`` awaitable is only
    awaited when the block is entered, and the stream is closed on exit.

    ```py
    async with client.messages.stream(...) as stream:
        async for chunk in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Awaitable[AsyncStream[RawMessageStreamEvent]],
    ) -> None:
        self.__api_request = api_request
        # Remains None until `__aenter__` has created the stream.
        self.__stream: AsyncMessageStream | None = None

    async def __aenter__(self) -> AsyncMessageStream:
        opened = AsyncMessageStream(await self.__api_request)
        self.__stream = opened
        return opened

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        opened = self.__stream
        if opened is None:
            # Nothing was opened, so there is nothing to release.
            return
        await opened.close()
def build_events(
    *,
    event: RawMessageStreamEvent,
    message_snapshot: Message,
) -> list[MessageStreamEvent]:
    """Translate one raw stream event into the list of events to emit.

    - ``message_start``, ``message_delta`` and ``content_block_start`` are
      passed through unchanged.
    - ``message_stop`` is replaced by a ``MessageStopEvent`` that carries
      ``message_snapshot``.
    - ``content_block_stop`` is replaced by a ``ContentBlockStopEvent`` that
      carries the content block at ``event.index``.
    - ``content_block_delta`` is passed through. When the delta kind matches
      the type of the targeted content block, it is followed by a convenience
      event (text / input_json / citation / thinking / signature) that
      includes the accumulated value from the snapshot.

    Callers in this module pass a snapshot into which ``event`` has already
    been folded via ``accumulate_event``.
    """
    fired: list[MessageStreamEvent] = []

    if event.type == "message_start" or event.type == "message_delta" or event.type == "content_block_start":
        fired.append(event)
    elif event.type == "message_stop":
        fired.append(build(MessageStopEvent, type="message_stop", message=message_snapshot))
    elif event.type == "content_block_delta":
        fired.append(event)
        block = message_snapshot.content[event.index]
        delta = event.delta
        if delta.type == "text_delta":
            if block.type == "text":
                fired.append(build(TextEvent, type="text", text=delta.text, snapshot=block.text))
        elif delta.type == "input_json_delta":
            if block.type == "tool_use":
                fired.append(
                    build(
                        InputJsonEvent,
                        type="input_json",
                        partial_json=delta.partial_json,
                        snapshot=block.input,
                    )
                )
        elif delta.type == "citations_delta":
            if block.type == "text":
                fired.append(
                    build(
                        CitationEvent,
                        type="citation",
                        citation=delta.citation,
                        snapshot=block.citations or [],
                    )
                )
        elif delta.type == "thinking_delta":
            if block.type == "thinking":
                fired.append(
                    build(
                        ThinkingEvent,
                        type="thinking",
                        thinking=delta.thinking,
                        snapshot=block.thinking,
                    )
                )
        elif delta.type == "signature_delta":
            if block.type == "thinking":
                # The snapshot already holds the signature from this delta.
                fired.append(build(SignatureEvent, type="signature", signature=block.signature))
        else:
            # we only want exhaustive checking for linters, not at runtime
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(delta)
    elif event.type == "content_block_stop":
        fired.append(
            build(
                ContentBlockStopEvent,
                type="content_block_stop",
                index=event.index,
                content_block=message_snapshot.content[event.index],
            ),
        )
    else:
        # we only want exhaustive checking for linters, not at runtime
        if TYPE_CHECKING:  # type: ignore[unreachable]
            assert_never(event)

    return fired
# Name of the untyped attribute stashed on a streaming `tool_use` block that
# holds the raw JSON bytes received so far for its `input`.
JSON_BUF_PROPERTY = "__json_buf"


def accumulate_event(
    *,
    event: RawMessageStreamEvent,
    current_snapshot: Message | None,
) -> Message:
    """Fold a single raw stream event into the running `Message` snapshot.

    The first event must be ``message_start``; it produces a fresh snapshot.
    Subsequent events mutate ``current_snapshot`` in place and return it:
    ``content_block_start`` appends a block, ``content_block_delta`` extends
    the block at ``event.index``, and ``message_delta`` updates the stop
    reason/sequence and output token usage. Other event types leave the
    snapshot unchanged.

    Raises:
        TypeError: if ``event`` is not a model even after coercion.
        RuntimeError: if the first event is not ``message_start``.
    """
    if not isinstance(cast(Any, event), BaseModel):
        # Raw data rather than a model: coerce it without validation.
        event = cast(  # pyright: ignore[reportUnnecessaryCast]
            RawMessageStreamEvent,
            construct_type_unchecked(
                type_=cast(Type[RawMessageStreamEvent], RawMessageStreamEvent),
                value=event,
            ),
        )
        if not isinstance(cast(Any, event), BaseModel):
            raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")

    if current_snapshot is None:
        if event.type != "message_start":
            raise RuntimeError(f'Unexpected event order, got {event.type} before "message_start"')
        return Message.construct(**cast(Any, event.message.to_dict()))

    if event.type == "content_block_start":
        # TODO: check index
        new_block = construct_type(type_=ContentBlock, value=event.content_block.model_dump())
        current_snapshot.content.append(cast(ContentBlock, new_block))
    elif event.type == "content_block_delta":
        block = current_snapshot.content[event.index]
        delta = event.delta
        if delta.type == "text_delta":
            if block.type == "text":
                block.text += delta.text
        elif delta.type == "input_json_delta":
            if block.type == "tool_use":
                from jiter import from_json

                # The partial JSON is re-parsed from the full raw buffer on every
                # delta, so the raw bytes are kept on the block as an untyped
                # attribute alongside the parsed `input`.
                buf = cast(bytes, getattr(block, JSON_BUF_PROPERTY, b""))
                buf += bytes(delta.partial_json, "utf-8")
                if buf:
                    block.input = from_json(buf, partial_mode=True)
                setattr(block, JSON_BUF_PROPERTY, buf)
        elif delta.type == "citations_delta":
            if block.type == "text":
                if block.citations:
                    block.citations.append(delta.citation)
                else:
                    block.citations = [delta.citation]
        elif delta.type == "thinking_delta":
            if block.type == "thinking":
                block.thinking += delta.thinking
        elif delta.type == "signature_delta":
            if block.type == "thinking":
                block.signature = delta.signature
        else:
            # we only want exhaustive checking for linters, not at runtime
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(delta)
    elif event.type == "message_delta":
        current_snapshot.stop_reason = event.delta.stop_reason
        current_snapshot.stop_sequence = event.delta.stop_sequence
        current_snapshot.usage.output_tokens = event.usage.output_tokens

    return current_snapshot