# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
from collections.abc import Iterable, Sequence
import itertools
from typing import Any, Iterable, overload, TypeVar

import google.ai.generativelanguage as glm

from google.generativeai.client import get_default_text_client
from google.generativeai import string_utils
from google.generativeai.types import text_types
from google.generativeai.types import model_types
from google.generativeai import models
from google.generativeai.types import safety_types

DEFAULT_TEXT_MODEL = "models/text-bison-001"
# Maximum number of texts sent in a single `BatchEmbedTextRequest`.
EMBEDDING_MAX_BATCH_SIZE = 100

try:
    # python 3.12+
    _batched = itertools.batched  # type: ignore
except AttributeError:
    T = TypeVar("T")

    def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
        """Yield successive lists of up to `n` items from `iterable`.

        Fallback for `itertools.batched` on Python < 3.12. Note that this
        version yields lists whereas the stdlib version yields tuples; the
        only caller in this module forwards each batch to
        `glm.BatchEmbedTextRequest(texts=...)`.

        Raises:
            ValueError: If `n` is less than 1.
        """
        if n < 1:
            raise ValueError(f"Batch size `n` must be >= 1, got: {n}")
        batch = []
        for item in iterable:
            batch.append(item)
            if len(batch) == n:
                yield batch
                batch = []
        # Emit the trailing partial batch, if any.
        if batch:
            yield batch


def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt:
    """
    Creates a `glm.TextPrompt` object based on the provided prompt input.

    Args:
        prompt: The prompt input, either a string or a dictionary.

    Returns:
        glm.TextPrompt: A TextPrompt object containing the prompt text.

    Raises:
        TypeError: If the provided prompt is neither a string nor a dictionary.
    """
    if isinstance(prompt, str):
        return glm.TextPrompt(text=prompt)
    elif isinstance(prompt, dict):
        return glm.TextPrompt(prompt)
    else:
        # Raise (not merely construct) the error so bad input fails here
        # instead of yielding `None` and an empty prompt in the request.
        raise TypeError("Expected string or dictionary for text prompt.")


def _make_generate_text_request(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_TEXT_MODEL,
    prompt: str | None = None,
    temperature: float | None = None,
    candidate_count: int | None = None,
    max_output_tokens: int | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    stop_sequences: str | Iterable[str] | None = None,
) -> glm.GenerateTextRequest:
    """
    Creates a `glm.GenerateTextRequest` object based on the provided parameters.

    This function generates a `glm.GenerateTextRequest` object with the specified
    parameters. It prepares the input parameters and creates a request that can be
    used for generating text using the chosen model.

    Args:
        model: The model to use for text generation.
        prompt: The prompt for text generation. Must be a string (or a dict
            accepted by `glm.TextPrompt`); anything else, including `None`,
            raises `TypeError`.
        temperature: The temperature for randomness in generation. Defaults to None.
        candidate_count: The number of candidates to consider. Defaults to None.
        max_output_tokens: The maximum number of output tokens. Defaults to None.
        top_p: The nucleus sampling probability threshold. Defaults to None.
        top_k: The top-k sampling parameter. Defaults to None.
        safety_settings: Safety settings for generated text. Defaults to None.
        stop_sequences: Stop sequences to halt text generation. Can be a string
             or iterable of strings. Defaults to None.

    Returns:
        `glm.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters.

    Raises:
        TypeError: If `prompt` is neither a string nor a dictionary.
    """
    model = model_types.make_model_name(model)
    prompt = _make_text_prompt(prompt=prompt)
    safety_settings = safety_types.normalize_safety_settings(
        safety_settings, harm_category_set="old"
    )
    # A bare string is a single stop sequence, not an iterable of characters.
    if isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]
    if stop_sequences:
        stop_sequences = list(stop_sequences)

    return glm.GenerateTextRequest(
        model=model,
        prompt=prompt,
        temperature=temperature,
        candidate_count=candidate_count,
        max_output_tokens=max_output_tokens,
        top_p=top_p,
        top_k=top_k,
        safety_settings=safety_settings,
        stop_sequences=stop_sequences,
    )


def generate_text(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_TEXT_MODEL,
    prompt: str,
    temperature: float | None = None,
    candidate_count: int | None = None,
    max_output_tokens: int | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    stop_sequences: str | Iterable[str] | None = None,
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.Completion:
    """Calls the API and returns a `types.Completion` containing the response.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        prompt: Free-form input text given to the model. Given a prompt, the model will
                generate text that completes the input text.
        temperature: Controls the randomness of the output. Must be positive.
            Typical values are in the range: `[0.0,1.0]`. Higher values produce a
            more random and varied response. A temperature of zero will be deterministic.
        candidate_count: The **maximum** number of generated response messages to return.
            This value must be between `[1, 8]`, inclusive. If unset, this
            will default to `1`.

            Note: Only unique candidates are returned. Higher temperatures are more
            likely to produce unique candidates. Setting `temperature=0.0` will always
            return 1 candidate regardless of the `candidate_count`.
        max_output_tokens: Maximum number of tokens to include in a candidate. Must be greater
                           than zero. If unset, will default to 64.
        top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling.
            `top_k` sets the maximum number of tokens to sample from on each step.
        top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling.
            `top_p` configures the nucleus sampling. It sets the maximum cumulative
            probability of tokens to sample from.
            For example, if the sorted probabilities are
            `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample
            as `[0.625, 0.25, 0.125, 0, 0, 0]`.
        safety_settings: A list of unique `types.SafetySetting` instances for blocking unsafe content.
           These will be enforced on the `prompt` and
           `candidates`. There should not be more than one
           setting for each `types.SafetyCategory` type. The API will block any prompts and
           responses that fail to meet the thresholds set by these settings. This list
           overrides the default settings for each `SafetyCategory` specified in the
           safety_settings. If there is no `types.SafetySetting` for a given
           `SafetyCategory` provided in the list, the API will use the default safety
           setting for that category.
        stop_sequences: A set of up to 5 character sequences that will stop output generation.
          If specified, the API will stop at the first appearance of a stop
          sequence. The stop sequence will not be included as part of the response.
        client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead.
        request_options: Options for the request.

    Returns:
        A `types.Completion` containing the model's text completion response.
    """
    request = _make_generate_text_request(
        model=model,
        prompt=prompt,
        temperature=temperature,
        candidate_count=candidate_count,
        max_output_tokens=max_output_tokens,
        top_p=top_p,
        top_k=top_k,
        safety_settings=safety_settings,
        stop_sequences=stop_sequences,
    )

    return _generate_response(client=client, request=request, request_options=request_options)


@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class Completion(text_types.Completion):
    """A text completion returned by `generate_text`.

    Every keyword argument passed to the constructor is stored as an attribute.
    `result` is set to the `"output"` of the first candidate, or `None` when
    `candidates` is empty.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

        self.result = None
        if self.candidates:
            self.result = self.candidates[0]["output"]


def _generate_response(
    request: glm.GenerateTextRequest,
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> Completion:
    """
    Generates a response using the provided `glm.GenerateTextRequest` and client.

    Args:
        request: The text generation request.
        client: The client to use for text generation. Defaults to None, in which
            case the default text client is used.
        request_options: Options for the request, forwarded as keyword
            arguments to `client.generate_text`.

    Returns:
        `Completion`: A `Completion` object with the generated text and response information.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_text_client()

    response = client.generate_text(request, **request_options)
    response = type(response).to_dict(response)

    # Convert raw enum values in the response dict into the SDK's enum types.
    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
    response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums(
        response["safety_feedback"]
    )
    response["candidates"] = safety_types.convert_candidate_enums(response["candidates"])

    return Completion(_client=client, **response)


def count_text_tokens(
    model: model_types.AnyModelNameOptions,
    prompt: str,
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.TokenCount:
    """Calls the API to count the tokens in `prompt`.

    Args:
        model: The model whose tokenizer is used; it is resolved to a base
            model name with `models.get_base_model_name`.
        prompt: Free-form input text whose tokens are counted.
        client: If you're not relying on a default client, you pass a
            `glm.TextServiceClient` instead.
        request_options: Options for the request, forwarded as keyword
            arguments to `client.count_text_tokens`.

    Returns:
        The `count_text_tokens` response converted to a dictionary.
    """
    base_model = models.get_base_model_name(model)

    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_text_client()

    result = client.count_text_tokens(
        glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt}),
        **request_options,
    )

    return type(result).to_dict(result)


@overload
def generate_embeddings(
    model: model_types.BaseModelNameOptions,
    text: str,
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.EmbeddingDict: ...


@overload
def generate_embeddings(
    model: model_types.BaseModelNameOptions,
    text: Sequence[str],
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.BatchEmbeddingDict: ...


def generate_embeddings(
    model: model_types.BaseModelNameOptions,
    text: str | Sequence[str],
    client: glm.TextServiceClient | None = None,
    request_options: dict[str, Any] | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
    """Calls the API to create an embedding for the text passed in.

    Args:
        model: Which model to call, as a string or a `types.Model`.

        text: Free-form input text given to the model. Given a string, the model will
              generate an embedding based on the input text. Given a sequence of
              strings, the texts are sent in batches of at most
              `EMBEDDING_MAX_BATCH_SIZE` per request.

        client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead.

        request_options: Options for the request.

    Returns:
        For a single string: a dictionary containing the embedding (list of
        float values) for the input text. For a sequence of strings:
        `{"embedding": [...]}` holding one list of floats per returned
        embedding, concatenated in batch order.
    """
    model = model_types.make_model_name(model)

    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_text_client()

    if isinstance(text, str):
        embedding_request = glm.EmbedTextRequest(model=model, text=text)
        embedding_response = client.embed_text(
            embedding_request,
            **request_options,
        )
        embedding_dict = type(embedding_response).to_dict(embedding_response)
        # Flatten `{"embedding": {"value": [...]}}` to `{"embedding": [...]}`.
        embedding_dict["embedding"] = embedding_dict["embedding"]["value"]
    else:
        result = {"embedding": []}
        for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE):
            # TODO(markdaoust): This could use an option for returning an iterator or wait-bar.
            embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch)
            embedding_response = client.batch_embed_text(
                embedding_request,
                **request_options,
            )
            embedding_dict = type(embedding_response).to_dict(embedding_response)
            result["embedding"].extend(e["value"] for e in embedding_dict["embeddings"])
        return result

    return embedding_dict
Memory