# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
from queue import Queue
from typing import TYPE_CHECKING, Optional


if TYPE_CHECKING:
    from ..models.auto import AutoTokenizer


# Inclusive codepoint ranges treated as "CJK characters" by `TextStreamer._is_chinese_char`.
_CJK_CODEPOINT_RANGES = (
    (0x4E00, 0x9FFF),
    (0x3400, 0x4DBF),
    (0x20000, 0x2A6DF),
    (0x2A700, 0x2B73F),
    (0x2B740, 0x2B81F),
    (0x2B820, 0x2CEAF),
    (0xF900, 0xFAFF),
    (0x2F800, 0x2FA1F),
)


class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()


class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []  # token ids received since the last cache flush
        self.print_len = 0  # number of characters of the decoded cache already emitted
        self.next_tokens_are_prompt = True  # the first `put()` call of a generation carries the prompt

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.

        Args:
            value: The new token ids. Must expose `.shape` and `.tolist()`; either 1-D, or 2-D with a leading
                batch dimension of size 1.

        Raises:
            ValueError: If `value` has a batch dimension larger than 1.
        """
        if len(value.shape) > 1:
            if value.shape[0] > 1:
                raise ValueError("TextStreamer only supports batch size 1")
            value = value[0]

        # The first call of a generation is treated as the prompt and dropped when `skip_prompt` is set.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        if text.endswith("\n"):
            # After the symbol for a new line, we flush the cache.
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        elif text and self._is_chinese_char(ord(text[-1])):
            # If the last token is a CJK character, we print the characters.
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        else:
            # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
            # which may change with the subsequent token -- there are probably smarter ways to do this!)
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        if self.token_cache:
            # Flush whatever is still pending in the cache.
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # Re-arm prompt skipping so the streamer can be reused for another generation.
        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        # `end=None` makes `print` fall back to its default terminator (a newline).
        print(text, flush=True, end="" if not stream_end else None)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        return any(low <= cp <= high for low, high in _CJK_CODEPOINT_RANGES)


class TextIteratorStreamer(TextStreamer):
    """
    Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
    useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an interactive
    Gradio demo).

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        timeout (`float`, *optional*):
            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
            in `.generate()`, when it is called in a separate thread.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
        >>> from threading import Thread

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextIteratorStreamer(tok)

        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
        >>> thread.start()
        >>> generated_text = ""
        >>> for new_text in streamer:
        ...     generated_text += new_text
        >>> generated_text
        'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
        ```
    """

    def __init__(
        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None  # sentinel placed in the queue to mark the end of the stream
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next text chunk from the queue; stop iterating once the stop signal is received."""
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        return value


class AsyncTextIteratorStreamer(TextStreamer):
    """
    Streamer that stores print-ready text in a queue, to be used by a downstream application as an async iterator.
    This is useful for applications that benefit from accessing the generated text asynchronously (e.g. in an
    interactive Gradio demo).

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        timeout (`float`, *optional*):
            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
            in `.generate()`, when it is called in a separate thread.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Raises:
        TimeoutError: If token generation time exceeds timeout value.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, AsyncTextIteratorStreamer
        >>> from threading import Thread
        >>> import asyncio

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")

        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        >>> async def main():
        ...     # Important: AsyncTextIteratorStreamer must be initialized inside a coroutine!
        ...     streamer = AsyncTextIteratorStreamer(tok)
        ...     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        ...     thread = Thread(target=model.generate, kwargs=generation_kwargs)
        ...     thread.start()
        ...     generated_text = ""
        ...     async for new_text in streamer:
        ...         generated_text += new_text
        ...     print(generated_text)
        >>> asyncio.run(main())
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(
        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = asyncio.Queue()
        self.stop_signal = None  # sentinel placed in the queue to mark the end of the stream
        self.timeout = timeout
        # Requires a running event loop, hence the streamer must be created inside a coroutine.
        self.loop = asyncio.get_running_loop()
        # Prefer `asyncio.timeout` when this Python version provides it; fall back to `asyncio.wait_for`.
        self.has_asyncio_timeout = hasattr(asyncio, "timeout")

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        # Items are handed to the event loop captured at construction time via `call_soon_threadsafe`.
        self.loop.call_soon_threadsafe(self.text_queue.put_nowait, text)
        if stream_end:
            self.loop.call_soon_threadsafe(self.text_queue.put_nowait, self.stop_signal)

    def __aiter__(self):
        return self

    async def __anext__(self):
        """
        Await the next text chunk from the queue; stop iterating once the stop signal is received.

        Raises:
            TimeoutError: If no item arrives within `self.timeout` seconds.
        """
        try:
            if self.has_asyncio_timeout:
                async with asyncio.timeout(self.timeout):
                    value = await self.text_queue.get()
            else:
                value = await asyncio.wait_for(self.text_queue.get(), timeout=self.timeout)
        except asyncio.TimeoutError as exc:
            # Surface the builtin `TimeoutError`, keeping the asyncio error as the explicit cause.
            raise TimeoutError() from exc

        if value == self.stop_signal:
            raise StopAsyncIteration()
        return value
Memory