# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
import sys
import textwrap
from typing import Any, Iterable, List, Optional, Union
import google.ai.generativelanguage as glm
from google.generativeai.client import get_default_discuss_client
from google.generativeai.client import get_default_discuss_async_client
from google.generativeai import string_utils
from google.generativeai.types import discuss_types
from google.generativeai.types import model_types
from google.generativeai.types import safety_types
def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
"""Creates a `glm.Message` object from the provided content."""
if isinstance(content, glm.Message):
return content
if isinstance(content, str):
return glm.Message(content=content)
else:
return glm.Message(content)
def _make_messages(
messages: discuss_types.MessagesOptions,
) -> List[glm.Message]:
"""
Creates a list of `glm.Message` objects from the provided messages.
This function takes a variety of message content inputs, such as strings, dictionaries,
or `glm.Message` objects, and creates a list of `glm.Message` objects. It ensures that
the authors of the messages alternate appropriately. If authors are not provided,
default authors are assigned based on their position in the list.
Args:
messages: The messages to convert.
Returns:
A list of `glm.Message` objects with alternating authors.
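
    Example (illustrative; authors are filled in by list position):

        messages = _make_messages(["Hello", "Hi there!", "How are you?"])
        # Messages without an explicit author alternate between "0" and "1":
        # messages[0].author == "0", messages[1].author == "1", messages[2].author == "0"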
"""
if isinstance(messages, (str, dict, glm.Message)):
messages = [_make_message(messages)]
else:
messages = [_make_message(message) for message in messages]
even_authors = set(msg.author for msg in messages[::2] if msg.author)
if not even_authors:
even_author = "0"
elif len(even_authors) == 1:
even_author = even_authors.pop()
else:
raise discuss_types.AuthorError("Authors are not strictly alternating")
odd_authors = set(msg.author for msg in messages[1::2] if msg.author)
if not odd_authors:
odd_author = "1"
elif len(odd_authors) == 1:
odd_author = odd_authors.pop()
else:
raise discuss_types.AuthorError("Authors are not strictly alternating")
if all(msg.author for msg in messages):
return messages
authors = [even_author, odd_author]
for i, msg in enumerate(messages):
msg.author = authors[i % 2]
return messages
def _make_example(item: discuss_types.ExampleOptions) -> glm.Example:
"""Creates a `glm.Example` object from the provided item."""
if isinstance(item, glm.Example):
return item
if isinstance(item, dict):
item = item.copy()
item["input"] = _make_message(item["input"])
item["output"] = _make_message(item["output"])
return glm.Example(item)
if isinstance(item, Iterable):
input, output = list(item)
return glm.Example(input=_make_message(input), output=_make_message(output))
# try anyway
return glm.Example(item)
def _make_examples_from_flat(
examples: List[discuss_types.MessageOptions],
) -> List[glm.Example]:
"""
Creates a list of `glm.Example` objects from a list of message options.
This function takes a list of `discuss_types.MessageOptions` and pairs them into
`glm.Example` objects. The input examples must be in pairs to create valid examples.
Args:
examples: The list of `discuss_types.MessageOptions`.
Returns:
        A list of `glm.Example` objects created by pairing up the provided messages.
Raises:
ValueError: If the provided list of examples is not of even length.
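
    Example (illustrative):

        examples = _make_examples_from_flat(
            ["Hello", "Ahoy!", "How are you?", "Shipshape!"]
        )
        # Consecutive messages are paired into two `glm.Example` objects:
        # ("Hello" -> "Ahoy!") and ("How are you?" -> "Shipshape!").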
"""
if len(examples) % 2 != 0:
raise ValueError(
textwrap.dedent(
f"""\
                You must pass `glm.Example` objects, pairs of messages, or an *even* number of messages, got:
{len(examples)} messages"""
)
)
result = []
pair = []
for n, item in enumerate(examples):
msg = _make_message(item)
pair.append(msg)
if n % 2 == 0:
continue
primer = glm.Example(
input=pair[0],
output=pair[1],
)
result.append(primer)
pair = []
return result
def _make_examples(
examples: discuss_types.ExamplesOptions,
) -> List[glm.Example]:
"""
Creates a list of `glm.Example` objects from the provided examples.
This function takes various types of example content inputs and creates a list
of `glm.Example` objects. It handles the conversion of different input types and ensures
the appropriate structure for creating valid examples.
Args:
examples: The examples to convert.
Returns:
A list of `glm.Example` objects created from the provided examples.
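
    Example (illustrative):

        examples = _make_examples([
            {"input": "Hello", "output": "Ahoy!"},
            {"input": "How are you?", "output": "Shipshape, thanks!"},
        ])
        # -> a list of two `glm.Example` objects, each pairing an input and an output message.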
"""
if isinstance(examples, glm.Example):
return [examples]
if isinstance(examples, dict):
return [_make_example(examples)]
examples = list(examples)
if not examples:
return examples
first = examples[0]
if isinstance(first, dict):
if "content" in first:
# These are `Messages`
return _make_examples_from_flat(examples)
else:
if not ("input" in first and "output" in first):
                raise TypeError(
                    "To create an `Example` from a dict you must supply both `input` and `output` keys"
                )
else:
if isinstance(first, discuss_types.MESSAGE_OPTIONS):
return _make_examples_from_flat(examples)
result = []
for item in examples:
result.append(_make_example(item))
return result
def _make_message_prompt_dict(
    prompt: discuss_types.MessagePromptOptions | None = None,
*,
context: str | None = None,
examples: discuss_types.ExamplesOptions | None = None,
messages: discuss_types.MessagesOptions | None = None,
) -> dict | glm.MessagePrompt:
"""
    Creates the field dict for a `glm.MessagePrompt` from the provided prompt components.

    This function assembles the provided `context`, `examples`, and `messages` into a
    dictionary of `glm.MessagePrompt` fields (a `glm.MessagePrompt` passed as `prompt`
    is returned unchanged). Either pass a complete `prompt` or its components
    `context`, `examples`, and `messages`, but not both.
Args:
prompt: The complete prompt components.
context: The context for the prompt.
examples: The examples for the prompt.
messages: The messages for the prompt.
Returns:
        A dict of `glm.MessagePrompt` fields created from the provided prompt components.
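
    Example (illustrative):

        prompt = _make_message_prompt_dict(
            context="Answer as concisely as possible.",
            messages=["What is one plus one?"],
        )
        # -> {"context": "...", "messages": [glm.Message(...)]}; fields left as `None` are dropped.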
"""
if prompt is None:
prompt = dict(
context=context,
examples=examples,
messages=messages,
)
else:
flat_prompt = (context is not None) or (examples is not None) or (messages is not None)
if flat_prompt:
raise ValueError(
"You can't set `prompt`, and its fields `(context, examples, messages)`"
" at the same time"
)
if isinstance(prompt, glm.MessagePrompt):
return prompt
elif isinstance(prompt, dict): # Always check dict before Iterable.
pass
else:
prompt = {"messages": prompt}
keys = set(prompt.keys())
if not keys.issubset(discuss_types.MESSAGE_PROMPT_KEYS):
raise KeyError(
f"Found extra entries in the prompt dictionary: {keys - discuss_types.MESSAGE_PROMPT_KEYS}"
)
examples = prompt.get("examples", None)
if examples is not None:
prompt["examples"] = _make_examples(examples)
messages = prompt.get("messages", None)
if messages is not None:
prompt["messages"] = _make_messages(messages)
prompt = {k: v for k, v in prompt.items() if v is not None}
return prompt
def _make_message_prompt(
    prompt: discuss_types.MessagePromptOptions | None = None,
*,
context: str | None = None,
examples: discuss_types.ExamplesOptions | None = None,
messages: discuss_types.MessagesOptions | None = None,
) -> glm.MessagePrompt:
"""Creates a `glm.MessagePrompt` object from the provided prompt components."""
prompt = _make_message_prompt_dict(
prompt=prompt, context=context, examples=examples, messages=messages
)
return glm.MessagePrompt(prompt)
def _make_generate_message_request(
*,
model: model_types.AnyModelNameOptions | None,
context: str | None = None,
examples: discuss_types.ExamplesOptions | None = None,
messages: discuss_types.MessagesOptions | None = None,
temperature: float | None = None,
candidate_count: int | None = None,
top_p: float | None = None,
top_k: float | None = None,
prompt: discuss_types.MessagePromptOptions | None = None,
) -> glm.GenerateMessageRequest:
"""Creates a `glm.GenerateMessageRequest` object for generating messages."""
model = model_types.make_model_name(model)
prompt = _make_message_prompt(
prompt=prompt, context=context, examples=examples, messages=messages
)
return glm.GenerateMessageRequest(
model=model,
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
candidate_count=candidate_count,
)
DEFAULT_DISCUSS_MODEL = "models/chat-bison-001"
def chat(
*,
    model: model_types.AnyModelNameOptions | None = DEFAULT_DISCUSS_MODEL,
context: str | None = None,
examples: discuss_types.ExamplesOptions | None = None,
messages: discuss_types.MessagesOptions | None = None,
temperature: float | None = None,
candidate_count: int | None = None,
top_p: float | None = None,
top_k: float | None = None,
prompt: discuss_types.MessagePromptOptions | None = None,
client: glm.DiscussServiceClient | None = None,
request_options: dict[str, Any] | None = None,
) -> discuss_types.ChatResponse:
"""Calls the API and returns a `types.ChatResponse` containing the response.
Args:
model: Which model to call, as a string or a `types.Model`.
context: Text that should be provided to the model first, to ground the response.
If not empty, this `context` will be given to the model first before the
`examples` and `messages`.
This field can be a description of your prompt to the model to help provide
context and guide the responses.
Examples:
* "Translate the phrase from English to French."
* "Given a statement, classify the sentiment as happy, sad or neutral."
Anything included in this field will take precedence over history in `messages`
if the total input size exceeds the model's `Model.input_token_limit`.
examples: Examples of what the model should generate.
This includes both the user input and the response that the model should
emulate.
These `examples` are treated identically to conversation messages except
that they take precedence over the history in `messages`:
If the total input size exceeds the model's `input_token_limit` the input
            will be truncated. Items will be dropped from `messages` before `examples`.
messages: A snapshot of the conversation history sorted chronologically.
Turns alternate between two authors.
If the total input size exceeds the model's `input_token_limit` the input
will be truncated: The oldest items will be dropped from `messages`.
        temperature: Controls the randomness of the output. Must be non-negative.
            Typical values are in the range `[0.0, 1.0]`. Higher values produce a
more random and varied response. A temperature of zero will be deterministic.
candidate_count: The **maximum** number of generated response messages to return.
This value must be between `[1, 8]`, inclusive. If unset, this
will default to `1`.
Note: Only unique candidates are returned. Higher temperatures are more
likely to produce unique candidates. Setting `temperature=0.0` will always
return 1 candidate regardless of the `candidate_count`.
top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and
top-k sampling.
`top_k` sets the maximum number of tokens to sample from on each step.
top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and
top-k sampling.
`top_p` configures the nucleus sampling. It sets the maximum cumulative
probability of tokens to sample from.
For example, if the sorted probabilities are
`[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample
as `[0.625, 0.25, 0.125, 0, 0, 0]`.
Typical values are in the `[0.9, 1.0]` range.
        prompt: You may pass a `types.MessagePromptOptions` **instead** of
            setting `context`/`examples`/`messages`, but not both.
        client: If you're not relying on the default client, pass a
            `glm.DiscussServiceClient` here instead.
        request_options: Options passed through to the underlying client call
            (for example, `retry` or `timeout`).
Returns:
A `types.ChatResponse` containing the model's reply.
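
    Example (illustrative; assumes an API key has already been configured, e.g. via `genai.configure`):

        import google.generativeai as genai

        response = genai.chat(
            context="Answer as concisely as possible.",
            messages=["What is one plus one?"],
        )
        print(response.last)               # Text of the latest model reply.
        response = response.reply("Why?")  # Continue the same conversation.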
"""
request = _make_generate_message_request(
model=model,
context=context,
examples=examples,
messages=messages,
temperature=temperature,
candidate_count=candidate_count,
top_p=top_p,
top_k=top_k,
prompt=prompt,
)
return _generate_response(client=client, request=request, request_options=request_options)
@string_utils.set_doc(chat.__doc__)
async def chat_async(
*,
    model: model_types.AnyModelNameOptions | None = DEFAULT_DISCUSS_MODEL,
context: str | None = None,
examples: discuss_types.ExamplesOptions | None = None,
messages: discuss_types.MessagesOptions | None = None,
temperature: float | None = None,
candidate_count: int | None = None,
top_p: float | None = None,
top_k: float | None = None,
prompt: discuss_types.MessagePromptOptions | None = None,
client: glm.DiscussServiceAsyncClient | None = None,
request_options: dict[str, Any] | None = None,
) -> discuss_types.ChatResponse:
request = _make_generate_message_request(
model=model,
context=context,
examples=examples,
messages=messages,
temperature=temperature,
candidate_count=candidate_count,
top_p=top_p,
top_k=top_k,
prompt=prompt,
)
return await _generate_response_async(
client=client, request=request, request_options=request_options
)
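# `kw_only` dataclass fields are only supported on Python 3.10+, so fall back to
# plain (positional-allowed) fields on older interpreters.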
if (sys.version_info.major, sys.version_info.minor) >= (3, 10):
DATACLASS_KWARGS = {"kw_only": True}
else:
DATACLASS_KWARGS = {}
@string_utils.prettyprint
@string_utils.set_doc(discuss_types.ChatResponse.__doc__)
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
class ChatResponse(discuss_types.ChatResponse):
    _client: glm.DiscussServiceClient | None = dataclasses.field(default=None, repr=False)
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
@property
@string_utils.set_doc(discuss_types.ChatResponse.last.__doc__)
def last(self) -> str | None:
if self.messages[-1]:
return self.messages[-1]["content"]
else:
return None
@last.setter
def last(self, message: discuss_types.MessageOptions):
message = _make_message(message)
message = type(message).to_dict(message)
self.messages[-1] = message
@string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
def reply(
self,
message: discuss_types.MessageOptions,
request_options: dict[str, Any] | None = None,
) -> discuss_types.ChatResponse:
if isinstance(self._client, glm.DiscussServiceAsyncClient):
            raise TypeError("reply can't be called on an async client, use reply_async instead.")
if self.last is None:
raise ValueError(
"The last response from the model did not return any candidates.\n"
"Check the `.filters` attribute to see why the responses were filtered:\n"
f"{self.filters}"
)
request = self.to_dict()
request.pop("candidates")
request.pop("filters", None)
request["messages"] = list(request["messages"])
request["messages"].append(_make_message(message))
request = _make_generate_message_request(**request)
return _generate_response(
request=request, client=self._client, request_options=request_options
)
@string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
    async def reply_async(
        self,
        message: discuss_types.MessageOptions,
        request_options: dict[str, Any] | None = None,
    ) -> discuss_types.ChatResponse:
if isinstance(self._client, glm.DiscussServiceClient):
            raise TypeError(
                "reply_async can't be called on a non-async client, use reply instead."
            )
request = self.to_dict()
request.pop("candidates")
request.pop("filters", None)
request["messages"] = list(request["messages"])
request["messages"].append(_make_message(message))
request = _make_generate_message_request(**request)
        return await _generate_response_async(
            request=request, client=self._client, request_options=request_options
        )
def _build_chat_response(
request: glm.GenerateMessageRequest,
response: glm.GenerateMessageResponse,
client: glm.DiscussServiceClient | glm.DiscussServiceAsyncClient,
) -> ChatResponse:
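    """Merges the request and response protos into a single `ChatResponse`.

    The prompt fields (`context`, `examples`, `messages`) are flattened back to the
    top level, and the first response candidate (if any) is appended to `messages`
    so that `ChatResponse.last` reflects the latest model reply.
    """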
request = type(request).to_dict(request)
prompt = request.pop("prompt")
request["examples"] = prompt["examples"]
request["context"] = prompt["context"]
request["messages"] = prompt["messages"]
response = type(response).to_dict(response)
response.pop("messages")
response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
if response["candidates"]:
last = response["candidates"][0]
else:
last = None
request["messages"].append(last)
request.setdefault("temperature", None)
request.setdefault("candidate_count", None)
return ChatResponse(_client=client, **response, **request) # pytype: disable=missing-parameter
def _generate_response(
request: glm.GenerateMessageRequest,
client: glm.DiscussServiceClient | None = None,
request_options: dict[str, Any] | None = None,
) -> ChatResponse:
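    """Sends the request with the sync client (the default client if `client` is None) and wraps the result."""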
if request_options is None:
request_options = {}
if client is None:
client = get_default_discuss_client()
response = client.generate_message(request, **request_options)
return _build_chat_response(request, response, client)
async def _generate_response_async(
request: glm.GenerateMessageRequest,
client: glm.DiscussServiceAsyncClient | None = None,
request_options: dict[str, Any] | None = None,
) -> ChatResponse:
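    """Sends the request with the async client (the default async client if `client` is None) and wraps the result."""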
if request_options is None:
request_options = {}
if client is None:
client = get_default_discuss_async_client()
response = await client.generate_message(request, **request_options)
return _build_chat_response(request, response, client)
def count_message_tokens(
*,
    prompt: discuss_types.MessagePromptOptions | None = None,
context: str | None = None,
examples: discuss_types.ExamplesOptions | None = None,
messages: discuss_types.MessagesOptions | None = None,
model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
    client: glm.DiscussServiceClient | None = None,
request_options: dict[str, Any] | None = None,
) -> discuss_types.TokenCount:
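    """Calls the API to count the number of tokens in the chat prompt.

    Accepts either a complete `prompt` or its components (`context`, `examples`,
    `messages`), mirroring `chat()`.

    Returns:
        A dict describing the prompt size, e.g. `{'token_count': ...}`.
    """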
model = model_types.make_model_name(model)
prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)
if request_options is None:
request_options = {}
if client is None:
client = get_default_discuss_client()
result = client.count_message_tokens(model=model, prompt=prompt, **request_options)
return type(result).to_dict(result)