# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
from collections.abc import Iterable, Sequence
import itertools
from typing import Any, Iterable, overload, TypeVar
import google.ai.generativelanguage as glm
from google.generativeai.client import get_default_text_client
from google.generativeai import string_utils
from google.generativeai.types import text_types
from google.generativeai.types import model_types
from google.generativeai import models
from google.generativeai.types import safety_types
# Model name used by `generate_text` (and its request builder) when the caller
# does not pass `model`.
DEFAULT_TEXT_MODEL = "models/text-bison-001"
# Number of texts `generate_embeddings` places in a single
# `glm.BatchEmbedTextRequest`; longer inputs are split into several requests.
EMBEDDING_MAX_BATCH_SIZE = 100
try:
    # python 3.12+
    _batched = itertools.batched  # type: ignore
except AttributeError:
    T = TypeVar("T")

    def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
        """Fallback for `itertools.batched` on Python versions older than 3.12.

        Splits `iterable` into consecutive batches of `n` items; the final batch
        may be shorter when the input length is not a multiple of `n`.

        Args:
            iterable: The items to group.
            n: The maximum number of items per batch. Must be at least 1.

        Yields:
            Lists of up to `n` consecutive items, in input order.

        Raises:
            ValueError: If `n` is smaller than 1. Because this is a generator,
                the error surfaces when iteration starts.
        """
        if n < 1:
            # A batch size of exactly 1 is valid, so the bound reported is `>= 1`.
            raise ValueError(f"Batch size `n` must be >= 1, got: {n}")
        batch = []
        for item in iterable:
            batch.append(item)
            if len(batch) == n:
                yield batch
                batch = []
        # Emit the trailing partial batch, if any.
        if batch:
            yield batch
def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt:
    """
    Creates a `glm.TextPrompt` object based on the provided prompt input.

    Args:
        prompt: The prompt input, either a string or a dictionary. A string is
            used as the prompt's `text`; a dictionary is handed to
            `glm.TextPrompt` unchanged.

    Returns:
        glm.TextPrompt: A TextPrompt object containing the prompt text.

    Raises:
        TypeError: If the provided prompt is neither a string nor a dictionary.
    """
    if isinstance(prompt, str):
        return glm.TextPrompt(text=prompt)
    if isinstance(prompt, dict):
        return glm.TextPrompt(prompt)
    # The exception has to be raised, not just constructed; otherwise an
    # unsupported prompt would silently produce `None`.
    raise TypeError("Expected string or dictionary for text prompt.")
def _make_generate_text_request(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_TEXT_MODEL,
    prompt: str | dict[str, str] | None = None,
    temperature: float | None = None,
    candidate_count: int | None = None,
    max_output_tokens: int | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    stop_sequences: str | Iterable[str] | None = None,
) -> glm.GenerateTextRequest:
    """
    Creates a `glm.GenerateTextRequest` object based on the provided parameters.

    The inputs are normalized before the request is built: the model is turned
    into a model name, the prompt into a `glm.TextPrompt`, the safety settings
    are passed through `safety_types.normalize_safety_settings`, and the stop
    sequences are collected into a list.

    Args:
        model: The model to use for text generation.
        prompt: The prompt for text generation, as a string or a dictionary
            accepted by `glm.TextPrompt`. Defaults to None, in which case the
            request is built without a prompt.
        temperature: The temperature for randomness in generation. Defaults to None.
        candidate_count: The number of candidates to consider. Defaults to None.
        max_output_tokens: The maximum number of output tokens. Defaults to None.
        top_p: The nucleus sampling probability threshold. Defaults to None.
        top_k: The top-k sampling parameter. Defaults to None.
        safety_settings: Safety settings for generated text. Defaults to None.
        stop_sequences: Stop sequences to halt text generation. Can be a string
            or iterable of strings. Defaults to None.

    Returns:
        `glm.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters.
    """
    model = model_types.make_model_name(model)

    # `_make_text_prompt` only understands strings and dictionaries, so the
    # "no prompt" default is handled here rather than being forwarded to it.
    if prompt is not None:
        prompt = _make_text_prompt(prompt=prompt)

    safety_settings = safety_types.normalize_safety_settings(
        safety_settings, harm_category_set="old"
    )

    # A lone string is a single stop sequence, not an iterable of characters.
    if isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]
    if stop_sequences:
        stop_sequences = list(stop_sequences)

    return glm.GenerateTextRequest(
        model=model,
        prompt=prompt,
        temperature=temperature,
        candidate_count=candidate_count,
        max_output_tokens=max_output_tokens,
        top_p=top_p,
        top_k=top_k,
        safety_settings=safety_settings,
        stop_sequences=stop_sequences,
    )
def generate_text(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_TEXT_MODEL,
    prompt: str,
    temperature: float | None = None,
    candidate_count: int | None = None,
    max_output_tokens: int | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    stop_sequences: str | Iterable[str] | None = None,
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.Completion:
    """Calls the API and returns a `types.Completion` containing the response.

    The arguments are packed into a `glm.GenerateTextRequest`, which is then
    sent through `client` (or the default text client when `client` is None).

    Args:
        model: Which model to call, as a string or a `types.Model`.
        prompt: Free-form input text given to the model. Given a prompt, the model will
            generate text that completes the input text.
        temperature: Controls the randomness of the output. Must be positive.
            Typical values are in the range: `[0.0,1.0]`. Higher values produce a
            more random and varied response. A temperature of zero will be deterministic.
        candidate_count: The **maximum** number of generated response messages to return.
            This value must be between `[1, 8]`, inclusive. If unset, this
            will default to `1`.

            Note: Only unique candidates are returned. Higher temperatures are more
            likely to produce unique candidates. Setting `temperature=0.0` will always
            return 1 candidate regardless of the `candidate_count`.
        max_output_tokens: Maximum number of tokens to include in a candidate. Must be greater
            than zero. If unset, will default to 64.
        top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling.
            `top_k` sets the maximum number of tokens to sample from on each step.
        top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling.
            `top_p` configures the nucleus sampling. It sets the maximum cumulative
            probability of tokens to sample from.

            For example, if the sorted probabilities are
            `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample
            as `[0.625, 0.25, 0.125, 0, 0, 0]`.
        safety_settings: A list of unique `types.SafetySetting` instances for blocking unsafe content.
            These will be enforced on the `prompt` and
            `candidates`. There should not be more than one
            setting for each `types.SafetyCategory` type. The API will block any prompts and
            responses that fail to meet the thresholds set by these settings. This list
            overrides the default settings for each `SafetyCategory` specified in the
            safety_settings. If there is no `types.SafetySetting` for a given
            `SafetyCategory` provided in the list, the API will use the default safety
            setting for that category.
        stop_sequences: A set of up to 5 character sequences that will stop output generation.
            If specified, the API will stop at the first appearance of a stop
            sequence. The stop sequence will not be included as part of the response.
        client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead.
        request_options: Options for the request.

    Returns:
        A `types.Completion` containing the model's text completion response.
    """
    request = _make_generate_text_request(
        model=model,
        prompt=prompt,
        temperature=temperature,
        candidate_count=candidate_count,
        max_output_tokens=max_output_tokens,
        top_p=top_p,
        top_k=top_k,
        safety_settings=safety_settings,
        stop_sequences=stop_sequences,
    )

    return _generate_response(client=client, request=request, request_options=request_options)
@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class Completion(text_types.Completion):
    """Concrete `text_types.Completion` populated from arbitrary keyword arguments."""

    def __init__(self, **kwargs):
        # Copy every supplied field onto the instance as-is.
        for field_name in kwargs:
            setattr(self, field_name, kwargs[field_name])

        # `result` is the first candidate's output, or None when there are
        # no candidates.
        self.result = None
        if not self.candidates:
            return
        self.result = self.candidates[0]["output"]
def _generate_response(
    request: glm.GenerateTextRequest,
    client: glm.TextServiceClient = None,
    request_options: dict[str, Any] | None = None,
) -> Completion:
    """
    Sends a `glm.GenerateTextRequest` and wraps the reply in a `Completion`.

    Args:
        request: The text generation request.
        client: The client used to send the request. When None, the default
            text client is used.
        request_options: Extra keyword arguments forwarded to
            `client.generate_text`. None is treated as no options.

    Returns:
        `Completion`: The response converted to a dictionary, with its
        `filters`, `safety_feedback` and `candidates` entries run through the
        `safety_types` enum converters, plus the client that made the call.
    """
    options = {} if request_options is None else request_options
    text_client = get_default_text_client() if client is None else client

    raw_response = text_client.generate_text(request, **options)
    payload = type(raw_response).to_dict(raw_response)

    # Convert the safety-related sections of the dictionary in place.
    payload["filters"] = safety_types.convert_filters_to_enums(payload["filters"])
    payload["safety_feedback"] = safety_types.convert_safety_feedback_to_enums(
        payload["safety_feedback"]
    )
    payload["candidates"] = safety_types.convert_candidate_enums(payload["candidates"])

    return Completion(_client=text_client, **payload)
def count_text_tokens(
    model: model_types.AnyModelNameOptions,
    prompt: str,
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.TokenCount:
    """Calls the API to count the tokens of a text prompt.

    Args:
        model: The model whose base model name is resolved with
            `models.get_base_model_name` and used for the request.
        prompt: The text to count tokens for; sent as the prompt's `text`.
        client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead.
        request_options: Options for the request.

    Returns:
        The `count_text_tokens` response converted to a dictionary.
    """
    base_model = models.get_base_model_name(model)

    options = {} if request_options is None else request_options
    text_client = get_default_text_client() if client is None else client

    token_request = glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt})
    token_response = text_client.count_text_tokens(token_request, **options)

    return type(token_response).to_dict(token_response)
@overload
def generate_embeddings(
    model: model_types.BaseModelNameOptions,
    text: str,
    client: glm.TextServiceClient = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.EmbeddingDict: ...


@overload
def generate_embeddings(
    model: model_types.BaseModelNameOptions,
    text: Sequence[str],
    client: glm.TextServiceClient = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.BatchEmbeddingDict: ...


def generate_embeddings(
    model: model_types.BaseModelNameOptions,
    text: str | Sequence[str],
    client: glm.TextServiceClient = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
    """Calls the API to create an embedding for the text passed in.

    A single string is embedded with one `embed_text` call. Any other input is
    treated as a collection of strings and embedded with `batch_embed_text`,
    in batches of at most `EMBEDDING_MAX_BATCH_SIZE` texts.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        text: Free-form input text given to the model. Given a string, the model will
            generate an embedding based on the input text.
        client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead.
        request_options: Options for the request.

    Returns:
        Dictionary containing the embedding (list of float values) for the input text.
        For a collection of texts, the `"embedding"` entry holds one embedding
        per text, in input order.
    """
    model = model_types.make_model_name(model)
    options = {} if request_options is None else request_options
    text_client = get_default_text_client() if client is None else client

    # Single text: one request, and the embedding values are lifted out of
    # their wrapper so callers get the plain value under "embedding".
    if isinstance(text, str):
        single_response = text_client.embed_text(
            glm.EmbedTextRequest(model=model, text=text),
            **options,
        )
        single_dict = type(single_response).to_dict(single_response)
        single_dict["embedding"] = single_dict["embedding"]["value"]
        return single_dict

    # Multiple texts: issue one request per batch and concatenate the results.
    all_embeddings = []
    for chunk in _batched(text, EMBEDDING_MAX_BATCH_SIZE):
        # TODO(markdaoust): This could use an option for returning an iterator or wait-bar.
        batch_response = text_client.batch_embed_text(
            glm.BatchEmbedTextRequest(model=model, texts=chunk),
            **options,
        )
        batch_dict = type(batch_response).to_dict(batch_response)
        all_embeddings.extend([entry["value"] for entry in batch_dict["embeddings"]])
    return {"embedding": all_embeddings}