# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
import sys
import textwrap

from typing import Any, Iterable, List, Optional, Union

import google.ai.generativelanguage as glm

from google.generativeai.client import get_default_discuss_client
from google.generativeai.client import get_default_discuss_async_client
from google.generativeai import string_utils
from google.generativeai.types import discuss_types
from google.generativeai.types import model_types
from google.generativeai.types import safety_types


def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
    """Creates a `glm.Message` object from the provided content.

    A `glm.Message` is returned as-is (not copied), a `str` becomes the
    message's `content`, and anything else is passed to the `glm.Message`
    constructor.
    """
    if isinstance(content, glm.Message):
        return content
    if isinstance(content, str):
        return glm.Message(content=content)
    return glm.Message(content)


def _make_messages(
    messages: discuss_types.MessagesOptions,
) -> List[glm.Message]:
    """
    Creates a list of `glm.Message` objects from the provided messages.

    This function takes a variety of message content inputs, such as strings, dictionaries,
    or `glm.Message` objects, and creates a list of `glm.Message` objects. It ensures that
    the authors of the messages alternate appropriately. If authors are not provided,
    default authors are assigned based on their position in the list.

    Args:
        messages: The messages to convert.

    Returns:
        A list of `glm.Message` objects with alternating authors. When every
        message already has an author the converted messages are returned
        unchanged; otherwise new (copied) messages are returned, so the
        caller's `glm.Message` objects are never modified.

    Raises:
        discuss_types.AuthorError: If the messages at even positions, or the
            messages at odd positions, carry more than one distinct author.
    """
    if isinstance(messages, (str, dict, glm.Message)):
        messages = [_make_message(messages)]
    else:
        messages = [_make_message(message) for message in messages]

    # Messages at even indices must share one author, and those at odd
    # indices another; missing authors default to "0" / "1".
    even_authors = {msg.author for msg in messages[::2] if msg.author}
    if not even_authors:
        even_author = "0"
    elif len(even_authors) == 1:
        even_author = even_authors.pop()
    else:
        raise discuss_types.AuthorError("Authors are not strictly alternating")

    odd_authors = {msg.author for msg in messages[1::2] if msg.author}
    if not odd_authors:
        odd_author = "1"
    elif len(odd_authors) == 1:
        odd_author = odd_authors.pop()
    else:
        raise discuss_types.AuthorError("Authors are not strictly alternating")

    if all(msg.author for msg in messages):
        return messages

    authors = [even_author, odd_author]
    result = []
    for i, msg in enumerate(messages):
        # Copy before assigning the author: `_make_message` returns
        # `glm.Message` inputs unchanged, so writing `msg.author` directly
        # would mutate the caller's objects.
        msg = glm.Message(msg)
        msg.author = authors[i % 2]
        result.append(msg)
    return result


def _make_example(item: discuss_types.ExampleOptions) -> glm.Example:
    """Creates a `glm.Example` object from the provided item.

    Accepts a `glm.Example` (returned as-is), a dict with `input`/`output`
    keys, or an iterable of exactly two messages (input, output).
    """
    if isinstance(item, glm.Example):
        return item

    if isinstance(item, dict):
        # Copy so the caller's dict is not modified.
        item = item.copy()
        item["input"] = _make_message(item["input"])
        item["output"] = _make_message(item["output"])
        return glm.Example(item)

    if isinstance(item, Iterable):
        input, output = list(item)
        return glm.Example(input=_make_message(input), output=_make_message(output))

    # try anyway
    return glm.Example(item)


def _make_examples_from_flat(
    examples: List[discuss_types.MessageOptions],
) -> List[glm.Example]:
    """
    Creates a list of `glm.Example` objects from a list of message options.

    This function takes a list of `discuss_types.MessageOptions` and pairs them into
    `glm.Example` objects. The input examples must be in pairs to create valid examples.

    Args:
        examples: The list of `discuss_types.MessageOptions`.

    Returns:
        A list of `glm.Example` objects created by pairing up the provided messages.

    Raises:
        ValueError: If the provided list of examples is not of even length.
    """
    if len(examples) % 2 != 0:
        raise ValueError(
            textwrap.dedent(
                f"""\
            You must pass `Primer` objects, pairs of messages, or an *even* number of messages, got:
              {len(examples)} messages"""
            )
        )
    result = []
    pair = []
    for n, item in enumerate(examples):
        msg = _make_message(item)
        pair.append(msg)
        # Even positions start a pair; odd positions complete it.
        if n % 2 == 0:
            continue
        primer = glm.Example(
            input=pair[0],
            output=pair[1],
        )
        result.append(primer)
        pair = []
    return result


def _make_examples(
    examples: discuss_types.ExamplesOptions,
) -> List[glm.Example]:
    """
    Creates a list of `glm.Example` objects from the provided examples.

    This function takes various types of example content inputs and creates a list
    of `glm.Example` objects. It handles the conversion of different input types and ensures
    the appropriate structure for creating valid examples.

    Args:
        examples: The examples to convert.

    Returns:
        A list of `glm.Example` objects created from the provided examples.

    Raises:
        TypeError: If the first item is a dict that has neither a `content`
            key nor both `input` and `output` keys.
    """
    if isinstance(examples, glm.Example):
        return [examples]

    if isinstance(examples, dict):
        return [_make_example(examples)]

    examples = list(examples)

    if not examples:
        return examples

    # The first item decides whether this is a flat list of messages (to be
    # paired up) or a list of individual examples.
    first = examples[0]

    if isinstance(first, dict):
        if "content" in first:
            # These are `Messages`
            return _make_examples_from_flat(examples)
        if not ("input" in first and "output" in first):
            raise TypeError(
                "To create an `Example` from a dict you must supply both `input` and an `output` keys"
            )
    elif isinstance(first, discuss_types.MESSAGE_OPTIONS):
        return _make_examples_from_flat(examples)

    return [_make_example(item) for item in examples]


def _make_message_prompt_dict(
    prompt: discuss_types.MessagePromptOptions = None,
    *,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
) -> dict[str, Any] | glm.MessagePrompt:
    """
    Creates a dict of `glm.MessagePrompt` fields from the provided prompt components.

    This function normalizes the provided `context`, `examples`, or `messages`
    into the structure expected by `glm.MessagePrompt`.

    Either pass a `prompt` or its components `context`, `examples`, `messages`.

    Args:
        prompt: The complete prompt components.
        context: The context for the prompt.
        examples: The examples for the prompt.
        messages: The messages for the prompt.

    Returns:
        A dict of prompt fields with `None` values dropped, or the
        `glm.MessagePrompt` itself if one was passed as `prompt`. A dict
        passed as `prompt` is copied, never modified.

    Raises:
        ValueError: If `prompt` is combined with `context`, `examples`, or `messages`.
        KeyError: If a dict prompt contains keys outside `discuss_types.MESSAGE_PROMPT_KEYS`.
    """
    if prompt is None:
        prompt = dict(
            context=context,
            examples=examples,
            messages=messages,
        )
    else:
        flat_prompt = (context is not None) or (examples is not None) or (messages is not None)
        if flat_prompt:
            raise ValueError(
                "You can't set `prompt`, and its fields `(context, examples, messages)`"
                " at the same time"
            )
        if isinstance(prompt, glm.MessagePrompt):
            return prompt
        elif isinstance(prompt, dict):  # Always check dict before Iterable.
            # Shallow copy: the `examples`/`messages` rewrites below must not
            # leak into the caller's dict.
            prompt = dict(prompt)
        else:
            prompt = {"messages": prompt}

    keys = set(prompt.keys())
    if not keys.issubset(discuss_types.MESSAGE_PROMPT_KEYS):
        raise KeyError(
            f"Found extra entries in the prompt dictionary: {keys - discuss_types.MESSAGE_PROMPT_KEYS}"
        )

    examples = prompt.get("examples", None)
    if examples is not None:
        prompt["examples"] = _make_examples(examples)
    messages = prompt.get("messages", None)
    if messages is not None:
        prompt["messages"] = _make_messages(messages)

    prompt = {k: v for k, v in prompt.items() if v is not None}
    return prompt


def _make_message_prompt(
    prompt: discuss_types.MessagePromptOptions = None,
    *,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
) -> glm.MessagePrompt:
    """Creates a `glm.MessagePrompt` object from the provided prompt components."""
    prompt = _make_message_prompt_dict(
        prompt=prompt, context=context, examples=examples, messages=messages
    )
    return glm.MessagePrompt(prompt)


def _make_generate_message_request(
    *,
    model: model_types.AnyModelNameOptions | None,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
    temperature: float | None = None,
    candidate_count: int | None = None,
    top_p: float | None = None,
    top_k: float | None = None,
    prompt: discuss_types.MessagePromptOptions | None = None,
) -> glm.GenerateMessageRequest:
    """Creates a `glm.GenerateMessageRequest` object for generating messages."""
    model = model_types.make_model_name(model)

    prompt = _make_message_prompt(
        prompt=prompt, context=context, examples=examples, messages=messages
    )

    return glm.GenerateMessageRequest(
        model=model,
        prompt=prompt,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        candidate_count=candidate_count,
    )


DEFAULT_DISCUSS_MODEL = "models/chat-bison-001"


def chat(
    *,
    model: model_types.AnyModelNameOptions | None = DEFAULT_DISCUSS_MODEL,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
    temperature: float | None = None,
    candidate_count: int | None = None,
    top_p: float | None = None,
    top_k: float | None = None,
    prompt: discuss_types.MessagePromptOptions | None = None,
    client: glm.DiscussServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> discuss_types.ChatResponse:
    """Calls the API and returns a `types.ChatResponse` containing the response.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        context: Text that should be provided to the model first, to ground the response.

            If not empty, this `context` will be given to the model first before the
            `examples` and `messages`.

            This field can be a description of your prompt to the model to help provide
            context and guide the responses.

            Examples:

            * "Translate the phrase from English to French."
            * "Given a statement, classify the sentiment as happy, sad or neutral."

            Anything included in this field will take precedence over history in `messages`
            if the total input size exceeds the model's `Model.input_token_limit`.
        examples: Examples of what the model should generate.

            This includes both the user input and the response that the model should
            emulate.

            These `examples` are treated identically to conversation messages except
            that they take precedence over the history in `messages`:
            If the total input size exceeds the model's `input_token_limit` the input
            will be truncated. Items will be dropped from `messages` before `examples`.
        messages: A snapshot of the conversation history sorted chronologically.

            Turns alternate between two authors.

            If the total input size exceeds the model's `input_token_limit` the input
            will be truncated: The oldest items will be dropped from `messages`.
        temperature: Controls the randomness of the output. Must be non-negative.

            Typical values are in the range: `[0.0,1.0]`. Higher values produce a
            more random and varied response. A temperature of zero will be deterministic.
        candidate_count: The **maximum** number of generated response messages to return.

            This value must be between `[1, 8]`, inclusive. If unset, this
            will default to `1`.

            Note: Only unique candidates are returned. Higher temperatures are more
            likely to produce unique candidates. Setting `temperature=0.0` will always
            return 1 candidate regardless of the `candidate_count`.
        top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and
            top-k sampling.

            `top_k` sets the maximum number of tokens to sample from on each step.
        top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and
            top-k sampling.

            `top_p` configures the nucleus sampling. It sets the maximum cumulative
            probability of tokens to sample from.

            For example, if the sorted probabilities are
            `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample
            as `[0.625, 0.25, 0.125, 0, 0, 0]`.

            Typical values are in the `[0.9, 1.0]` range.
        prompt: You may pass a `types.MessagePromptOptions` **instead** of
            setting `context`/`examples`/`messages`, but not both.
        client: If you're not relying on the default client, you may pass a
            `glm.DiscussServiceClient` instead.
        request_options: Options for the request.

    Returns:
        A `types.ChatResponse` containing the model's reply.

    Raises:
        ValueError: If `prompt` is combined with `context`, `examples`, or `messages`.
    """
    request = _make_generate_message_request(
        model=model,
        context=context,
        examples=examples,
        messages=messages,
        temperature=temperature,
        candidate_count=candidate_count,
        top_p=top_p,
        top_k=top_k,
        prompt=prompt,
    )

    return _generate_response(client=client, request=request, request_options=request_options)


@string_utils.set_doc(chat.__doc__)
async def chat_async(
    *,
    model: model_types.AnyModelNameOptions | None = DEFAULT_DISCUSS_MODEL,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
    temperature: float | None = None,
    candidate_count: int | None = None,
    top_p: float | None = None,
    top_k: float | None = None,
    prompt: discuss_types.MessagePromptOptions | None = None,
    client: glm.DiscussServiceAsyncClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> discuss_types.ChatResponse:
    request = _make_generate_message_request(
        model=model,
        context=context,
        examples=examples,
        messages=messages,
        temperature=temperature,
        candidate_count=candidate_count,
        top_p=top_p,
        top_k=top_k,
        prompt=prompt,
    )

    return await _generate_response_async(
        client=client, request=request, request_options=request_options
    )


# `kw_only` dataclasses are only available on Python 3.10+.
if sys.version_info >= (3, 10):
    DATACLASS_KWARGS = {"kw_only": True}
else:
    DATACLASS_KWARGS = {}


@string_utils.prettyprint
@string_utils.set_doc(discuss_types.ChatResponse.__doc__)
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
class ChatResponse(discuss_types.ChatResponse):
    # The client that produced this response; `reply`/`reply_async` reuse it.
    # The default must be `None` (not a callable): `_generate_response` and
    # `_generate_response_async` only fall back to the default client when
    # `client is None`.
    _client: glm.DiscussServiceClient | None = dataclasses.field(default=None, repr=False)

    def __init__(self, **kwargs):
        # Every keyword argument becomes an attribute; no validation is done here.
        for key, value in kwargs.items():
            setattr(self, key, value)

    @property
    @string_utils.set_doc(discuss_types.ChatResponse.last.__doc__)
    def last(self) -> str | None:
        if self.messages[-1]:
            return self.messages[-1]["content"]
        else:
            return None

    @last.setter
    def last(self, message: discuss_types.MessageOptions):
        message = _make_message(message)
        message = type(message).to_dict(message)
        self.messages[-1] = message

    def _make_reply_request(
        self, message: discuss_types.MessageOptions
    ) -> glm.GenerateMessageRequest:
        """Builds the request for the next turn: this conversation plus `message`.

        Shared by `reply` and `reply_async`.

        Raises:
            ValueError: If the last response from the model had no candidates.
        """
        if self.last is None:
            raise ValueError(
                "The last response from the model did not return any candidates.\n"
                "Check the `.filters` attribute to see why the responses were filtered:\n"
                f"{self.filters}"
            )

        request = self.to_dict()
        # Response-only fields are not valid request arguments.
        request.pop("candidates")
        request.pop("filters", None)
        request["messages"] = list(request["messages"])
        request["messages"].append(_make_message(message))
        return _make_generate_message_request(**request)

    @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
    def reply(
        self,
        message: discuss_types.MessageOptions,
        request_options: dict[str, Any] | None = None,
    ) -> discuss_types.ChatResponse:
        if isinstance(self._client, glm.DiscussServiceAsyncClient):
            raise TypeError("reply can't be called on an async client, use reply_async instead.")
        request = self._make_reply_request(message)
        return _generate_response(
            request=request, client=self._client, request_options=request_options
        )

    @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
    async def reply_async(
        self,
        message: discuss_types.MessageOptions,
        request_options: dict[str, Any] | None = None,
    ) -> discuss_types.ChatResponse:
        if isinstance(self._client, glm.DiscussServiceClient):
            raise TypeError(
                "reply_async can't be called on a non-async client, use reply instead."
            )
        request = self._make_reply_request(message)
        return await _generate_response_async(
            request=request, client=self._client, request_options=request_options
        )


def _build_chat_response(
    request: glm.GenerateMessageRequest,
    response: glm.GenerateMessageResponse,
    client: glm.DiscussServiceClient | glm.DiscussServiceAsyncClient,
) -> ChatResponse:
    """Combines a request and its response into a `ChatResponse`.

    The request's prompt fields (`examples`, `context`, `messages`) are lifted
    to top-level keys, the response's filters are converted to enums, and the
    first candidate (or `None` if there are no candidates) is appended to
    `messages`. The returned object keeps `client` for follow-up replies.
    """
    request = type(request).to_dict(request)
    prompt = request.pop("prompt")
    request["examples"] = prompt["examples"]
    request["context"] = prompt["context"]
    request["messages"] = prompt["messages"]

    response = type(response).to_dict(response)
    response.pop("messages")

    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])

    if response["candidates"]:
        last = response["candidates"][0]
    else:
        last = None
    request["messages"].append(last)
    request.setdefault("temperature", None)
    request.setdefault("candidate_count", None)

    return ChatResponse(_client=client, **response, **request)  # pytype: disable=missing-parameter


def _generate_response(
    request: glm.GenerateMessageRequest,
    client: glm.DiscussServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> ChatResponse:
    """Sends `request` with a sync client (the default one if `client` is None)."""
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_discuss_client()

    response = client.generate_message(request, **request_options)

    return _build_chat_response(request, response, client)


async def _generate_response_async(
    request: glm.GenerateMessageRequest,
    client: glm.DiscussServiceAsyncClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> ChatResponse:
    """Sends `request` with an async client (the default one if `client` is None)."""
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_discuss_async_client()

    response = await client.generate_message(request, **request_options)

    return _build_chat_response(request, response, client)


def count_message_tokens(
    *,
    prompt: discuss_types.MessagePromptOptions = None,
    context: str | None = None,
    examples: discuss_types.ExamplesOptions | None = None,
    messages: discuss_types.MessagesOptions | None = None,
    model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
    client: glm.DiscussServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> discuss_types.TokenCount:
    """Calls the API's `count_message_tokens` for a prompt and returns the result as a dict.

    Args:
        prompt: A complete prompt; pass this **or** `context`/`examples`/`messages`.
        context: The context for the prompt.
        examples: The examples for the prompt.
        messages: The messages for the prompt.
        model: Which model to use, as a string or a `types.Model`.
        client: A sync `glm.DiscussServiceClient`; the default client is used if `None`.
        request_options: Options for the request.

    Returns:
        The API response converted to a dict.

    Raises:
        ValueError: If `prompt` is combined with `context`, `examples`, or `messages`.
    """
    model = model_types.make_model_name(model)
    prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)

    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_discuss_client()

    result = client.count_message_tokens(model=model, prompt=prompt, **request_options)

    return type(result).to_dict(result)