from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any, Union, Mapping, TypeVar
from typing_extensions import Self, override
import httpx
from ... import _exceptions
from ._auth import load_auth, refresh_auth
from ._beta import Beta, AsyncBeta
from ..._types import NOT_GIVEN, NotGiven
from ..._utils import is_dict, asyncify, is_given
from ..._compat import model_copy, typed_cached_property
from ..._models import FinalRequestOptions
from ..._version import __version__
from ..._streaming import Stream, AsyncStream
from ..._exceptions import AnthropicError, APIStatusError
from ..._base_client import (
DEFAULT_MAX_RETRIES,
BaseClient,
SyncAPIClient,
AsyncAPIClient,
)
from ...resources.messages import Messages, AsyncMessages
if TYPE_CHECKING:
from google.auth.credentials import Credentials as GoogleCredentials # type: ignore
# Value written into a request's JSON body under "anthropic_version" when the
# caller has not supplied one (see `_prepare_options` below).
DEFAULT_VERSION = "vertex-2023-10-16"
# Type parameters for `BaseVertexClient`: the httpx client flavour (sync or
# async) and the matching default stream class.
_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])
class BaseVertexClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
    """Pieces shared by the synchronous and asynchronous Vertex clients.

    Provides the fallback `region` / `project_id` properties and the mapping
    from HTTP status codes to the SDK's exception classes.
    """

    @typed_cached_property
    def region(self) -> str:
        """Placeholder that always raises.

        The concrete clients in this module assign `self.region` in their
        constructors.

        Raises:
            RuntimeError: whenever this getter itself is evaluated.
        """
        raise RuntimeError("region not set")

    @typed_cached_property
    def project_id(self) -> str | None:
        """Project id read from the `ANTHROPIC_VERTEX_PROJECT_ID` environment variable.

        Returns:
            The variable's value, or `None` when it is unset or empty.
        """
        return os.environ.get("ANTHROPIC_VERTEX_PROJECT_ID") or None

    @override
    def _make_status_error(
        self,
        err_msg: str,
        *,
        body: object,
        response: httpx.Response,
    ) -> APIStatusError:
        """Build the exception instance that corresponds to an HTTP error response.

        Args:
            err_msg: Message passed to the exception.
            body: Response body passed through to the exception.
            response: The httpx response whose status code selects the class.

        Returns:
            A specific exception for the status codes listed below,
            `InternalServerError` for any other status >= 500, and a plain
            `APIStatusError` for everything else.
        """
        status = response.status_code

        # Status codes that have a dedicated exception class.
        by_status = {
            400: _exceptions.BadRequestError,
            401: _exceptions.AuthenticationError,
            403: _exceptions.PermissionDeniedError,
            404: _exceptions.NotFoundError,
            409: _exceptions.ConflictError,
            422: _exceptions.UnprocessableEntityError,
            429: _exceptions.RateLimitError,
            503: _exceptions.ServiceUnavailableError,
            504: _exceptions.DeadlineExceededError,
        }

        error_cls = by_status.get(status)
        if error_cls is None:
            # 503/504 were handled above; every other 5xx is a generic server error.
            error_cls = _exceptions.InternalServerError if status >= 500 else APIStatusError

        return error_cls(err_msg, response=response, body=body)
class AnthropicVertex(BaseVertexClient[httpx.Client, Stream[Any]], SyncAPIClient):
    """Synchronous client for Anthropic models served through Google Cloud Vertex AI.

    Requests are authenticated with a bearer token: either the explicit
    `access_token`, or a token taken from `credentials` (loaded through
    `load_auth` on first use when none were supplied).
    """

    messages: Messages
    beta: Beta

    def __init__(
        self,
        *,
        region: str | NotGiven = NOT_GIVEN,
        project_id: str | NotGiven = NOT_GIVEN,
        access_token: str | None = None,
        credentials: GoogleCredentials | None = None,
        base_url: str | httpx.URL | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
        http_client: httpx.Client | None = None,
        _strict_response_validation: bool = False,
    ) -> None:
        """Create a synchronous Vertex client.

        Args:
            region: Vertex location. Falls back to the `CLOUD_ML_REGION`
                environment variable.
            project_id: Google Cloud project id. When omitted, the
                `project_id` property of the base class is used.
            access_token: Bearer token to send as-is. When set, `credentials`
                are never consulted.
            credentials: Google credentials object used to obtain a token.
            base_url: API root. Falls back to `ANTHROPIC_VERTEX_BASE_URL`, then
                to the Vertex endpoint derived from `region`.
            timeout: Forwarded to the base client.
            max_retries: Forwarded to the base client.
            default_headers: Forwarded to the base client as `custom_headers`.
            default_query: Forwarded to the base client as `custom_query`.
            http_client: Custom httpx client, forwarded to the base client.
            _strict_response_validation: Forwarded to the base client.

        Raises:
            ValueError: If no region is given and `CLOUD_ML_REGION` is unset.
        """
        if not is_given(region):
            region = os.environ.get("CLOUD_ML_REGION", NOT_GIVEN)
        if not is_given(region):
            raise ValueError(
                "No region was given. The client should be instantiated with the `region` argument or the `CLOUD_ML_REGION` environment variable should be set."
            )

        if base_url is None:
            base_url = os.environ.get("ANTHROPIC_VERTEX_BASE_URL")
        if base_url is None:
            if region == "global":
                # The Vertex AI global endpoint has no region prefix in its
                # hostname; "global-aiplatform.googleapis.com" does not exist.
                base_url = "https://aiplatform.googleapis.com/v1"
            else:
                base_url = f"https://{region}-aiplatform.googleapis.com/v1"

        super().__init__(
            version=__version__,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            custom_headers=default_headers,
            custom_query=default_query,
            http_client=http_client,
            _strict_response_validation=_strict_response_validation,
        )

        # Only override the base-class `project_id` property when a value was
        # explicitly supplied.
        if is_given(project_id):
            self.project_id = project_id

        self.region = region
        self.access_token = access_token
        self.credentials = credentials

        self.messages = Messages(self)
        self.beta = Beta(self)

    @override
    def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
        """Delegate to the module-level `_prepare_options` with this client's project and region."""
        return _prepare_options(options, project_id=self.project_id, region=self.region)

    @override
    def _prepare_request(self, request: httpx.Request) -> None:
        """Add a bearer `Authorization` header unless the request already has one."""
        if request.headers.get("Authorization"):
            # already authenticated, nothing for us to do
            return

        request.headers["Authorization"] = f"Bearer {self._ensure_access_token()}"

    def _ensure_access_token(self) -> str:
        """Return a bearer token for the next request.

        The explicit `access_token` wins. Otherwise credentials are loaded
        with `load_auth` if none are set yet, and refreshed with
        `refresh_auth` when they report `expired` or carry no token.

        Raises:
            RuntimeError: If no token is available even after a refresh.
        """
        if self.access_token is not None:
            return self.access_token

        if not self.credentials:
            self.credentials, project_id = load_auth(project_id=self.project_id)
            # Adopt the project id returned alongside the credentials only
            # when none is known yet.
            if not self.project_id:
                self.project_id = project_id

        if self.credentials.expired or not self.credentials.token:
            refresh_auth(self.credentials)

            if not self.credentials.token:
                raise RuntimeError("Could not resolve API token from the environment")

        assert isinstance(self.credentials.token, str)
        return self.credentials.token

    def copy(
        self,
        *,
        region: str | NotGiven = NOT_GIVEN,
        project_id: str | NotGiven = NOT_GIVEN,
        access_token: str | None = None,
        credentials: GoogleCredentials | None = None,
        base_url: str | httpx.URL | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        http_client: httpx.Client | None = None,
        max_retries: int | NotGiven = NOT_GIVEN,
        default_headers: Mapping[str, str] | None = None,
        set_default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        set_default_query: Mapping[str, object] | None = None,
        _extra_kwargs: Mapping[str, Any] = {},
    ) -> Self:
        """
        Create a new client instance re-using the same options given to the current client with optional overriding.

        `default_headers` / `default_query` are merged on top of the current
        custom values, while `set_default_headers` / `set_default_query`
        replace them outright. Every other argument falls back to this
        client's current value when it is not supplied.

        Raises:
            ValueError: If both the merging and the replacing variant of
                headers (or of query parameters) are passed.
        """
        if default_headers is not None and set_default_headers is not None:
            raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")

        if default_query is not None and set_default_query is not None:
            raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")

        headers = self._custom_headers
        if default_headers is not None:
            headers = {**headers, **default_headers}
        elif set_default_headers is not None:
            headers = set_default_headers

        params = self._custom_query
        if default_query is not None:
            params = {**params, **default_query}
        elif set_default_query is not None:
            params = set_default_query

        # The copy shares this client's httpx client unless another is given.
        http_client = http_client or self._client

        return self.__class__(
            region=region if is_given(region) else self.region,
            project_id=project_id if is_given(project_id) else self.project_id or NOT_GIVEN,
            access_token=access_token or self.access_token,
            credentials=credentials or self.credentials,
            base_url=base_url or self.base_url,
            timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
            http_client=http_client,
            max_retries=max_retries if is_given(max_retries) else self.max_retries,
            default_headers=headers,
            default_query=params,
            **_extra_kwargs,
        )

    # Alias for `copy` for nicer inline usage, e.g.
    # client.with_options(timeout=10).foo.create(...)
    with_options = copy
class AsyncAnthropicVertex(BaseVertexClient[httpx.AsyncClient, AsyncStream[Any]], AsyncAPIClient):
    """Asynchronous client for Anthropic models served through Google Cloud Vertex AI.

    Requests are authenticated with a bearer token: either the explicit
    `access_token`, or a token taken from `credentials` (loaded through
    `load_auth` on first use when none were supplied). The `load_auth` and
    `refresh_auth` helpers are invoked through `asyncify`.
    """

    messages: AsyncMessages
    beta: AsyncBeta

    def __init__(
        self,
        *,
        region: str | NotGiven = NOT_GIVEN,
        project_id: str | NotGiven = NOT_GIVEN,
        access_token: str | None = None,
        credentials: GoogleCredentials | None = None,
        base_url: str | httpx.URL | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
        http_client: httpx.AsyncClient | None = None,
        _strict_response_validation: bool = False,
    ) -> None:
        """Create an asynchronous Vertex client.

        Args:
            region: Vertex location. Falls back to the `CLOUD_ML_REGION`
                environment variable.
            project_id: Google Cloud project id. When omitted, the
                `project_id` property of the base class is used.
            access_token: Bearer token to send as-is. When set, `credentials`
                are never consulted.
            credentials: Google credentials object used to obtain a token.
            base_url: API root. Falls back to `ANTHROPIC_VERTEX_BASE_URL`, then
                to the Vertex endpoint derived from `region`.
            timeout: Forwarded to the base client.
            max_retries: Forwarded to the base client.
            default_headers: Forwarded to the base client as `custom_headers`.
            default_query: Forwarded to the base client as `custom_query`.
            http_client: Custom httpx async client, forwarded to the base client.
            _strict_response_validation: Forwarded to the base client.

        Raises:
            ValueError: If no region is given and `CLOUD_ML_REGION` is unset.
        """
        if not is_given(region):
            region = os.environ.get("CLOUD_ML_REGION", NOT_GIVEN)
        if not is_given(region):
            raise ValueError(
                "No region was given. The client should be instantiated with the `region` argument or the `CLOUD_ML_REGION` environment variable should be set."
            )

        if base_url is None:
            base_url = os.environ.get("ANTHROPIC_VERTEX_BASE_URL")
        if base_url is None:
            if region == "global":
                # The Vertex AI global endpoint has no region prefix in its
                # hostname; "global-aiplatform.googleapis.com" does not exist.
                base_url = "https://aiplatform.googleapis.com/v1"
            else:
                base_url = f"https://{region}-aiplatform.googleapis.com/v1"

        super().__init__(
            version=__version__,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            custom_headers=default_headers,
            custom_query=default_query,
            http_client=http_client,
            _strict_response_validation=_strict_response_validation,
        )

        # Only override the base-class `project_id` property when a value was
        # explicitly supplied.
        if is_given(project_id):
            self.project_id = project_id

        self.region = region
        self.access_token = access_token
        self.credentials = credentials

        self.messages = AsyncMessages(self)
        self.beta = AsyncBeta(self)

    @override
    async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
        """Delegate to the module-level `_prepare_options` with this client's project and region."""
        return _prepare_options(options, project_id=self.project_id, region=self.region)

    @override
    async def _prepare_request(self, request: httpx.Request) -> None:
        """Add a bearer `Authorization` header unless the request already has one."""
        if request.headers.get("Authorization"):
            # already authenticated, nothing for us to do
            return

        request.headers["Authorization"] = f"Bearer {await self._ensure_access_token()}"

    async def _ensure_access_token(self) -> str:
        """Return a bearer token for the next request.

        The explicit `access_token` wins. Otherwise credentials are loaded
        with `load_auth` if none are set yet, and refreshed with
        `refresh_auth` when they report `expired` or carry no token. Both
        helpers are wrapped in `asyncify` and awaited.

        Raises:
            RuntimeError: If no token is available even after a refresh.
        """
        if self.access_token is not None:
            return self.access_token

        if not self.credentials:
            self.credentials, project_id = await asyncify(load_auth)(project_id=self.project_id)
            # Adopt the project id returned alongside the credentials only
            # when none is known yet.
            if not self.project_id:
                self.project_id = project_id

        if self.credentials.expired or not self.credentials.token:
            await asyncify(refresh_auth)(self.credentials)

            if not self.credentials.token:
                raise RuntimeError("Could not resolve API token from the environment")

        assert isinstance(self.credentials.token, str)
        return self.credentials.token

    def copy(
        self,
        *,
        region: str | NotGiven = NOT_GIVEN,
        project_id: str | NotGiven = NOT_GIVEN,
        access_token: str | None = None,
        credentials: GoogleCredentials | None = None,
        base_url: str | httpx.URL | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        http_client: httpx.AsyncClient | None = None,
        max_retries: int | NotGiven = NOT_GIVEN,
        default_headers: Mapping[str, str] | None = None,
        set_default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        set_default_query: Mapping[str, object] | None = None,
        _extra_kwargs: Mapping[str, Any] = {},
    ) -> Self:
        """
        Create a new client instance re-using the same options given to the current client with optional overriding.

        `default_headers` / `default_query` are merged on top of the current
        custom values, while `set_default_headers` / `set_default_query`
        replace them outright. Every other argument falls back to this
        client's current value when it is not supplied.

        Raises:
            ValueError: If both the merging and the replacing variant of
                headers (or of query parameters) are passed.
        """
        if default_headers is not None and set_default_headers is not None:
            raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")

        if default_query is not None and set_default_query is not None:
            raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")

        headers = self._custom_headers
        if default_headers is not None:
            headers = {**headers, **default_headers}
        elif set_default_headers is not None:
            headers = set_default_headers

        params = self._custom_query
        if default_query is not None:
            params = {**params, **default_query}
        elif set_default_query is not None:
            params = set_default_query

        # The copy shares this client's httpx client unless another is given.
        http_client = http_client or self._client

        return self.__class__(
            region=region if is_given(region) else self.region,
            project_id=project_id if is_given(project_id) else self.project_id or NOT_GIVEN,
            access_token=access_token or self.access_token,
            credentials=credentials or self.credentials,
            base_url=base_url or self.base_url,
            timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
            http_client=http_client,
            max_retries=max_retries if is_given(max_retries) else self.max_retries,
            default_headers=headers,
            default_query=params,
            **_extra_kwargs,
        )

    # Alias for `copy` for nicer inline usage, e.g.
    # client.with_options(timeout=10).foo.create(...)
    with_options = copy
def _prepare_options(input_options: FinalRequestOptions, *, project_id: str | None, region: str) -> FinalRequestOptions:
    """Translate Anthropic-style request options into their Vertex AI form.

    Works on a deep copy, so `input_options` is left untouched. A dict JSON
    body gets a default "anthropic_version". POSTs to the messages and
    count-tokens routes are redirected to the matching Vertex publisher-model
    URL; for messages the "model" entry is removed from the body and placed
    in the URL instead.

    Args:
        input_options: The options to translate.
        project_id: Google Cloud project id used in the rewritten URL.
        region: Vertex location used in the rewritten URL.

    Returns:
        The translated copy of `input_options`.

    Raises:
        RuntimeError: If a rewritten route needs `project_id` and it is
            `None`, or if a messages POST carries a non-dict JSON body.
        AnthropicError: If the URL targets the Batch API.
    """
    missing_project_msg = "No project_id was given and it could not be resolved from credentials. The client should be instantiated with the `project_id` argument or the `ANTHROPIC_VERTEX_PROJECT_ID` environment variable should be set."
    messages_routes = {"/v1/messages", "/v1/messages?beta=true"}
    count_tokens_routes = {"/v1/messages/count_tokens", "/v1/messages/count_tokens?beta=true"}

    prepared = model_copy(input_options, deep=True)

    if is_dict(prepared.json_data):
        prepared.json_data.setdefault("anthropic_version", DEFAULT_VERSION)

    if prepared.url in messages_routes and prepared.method == "post":
        if project_id is None:
            raise RuntimeError(missing_project_msg)

        payload = prepared.json_data
        if not is_dict(payload):
            raise RuntimeError("Expected json data to be a dictionary for post /v1/messages")

        # Vertex takes the model from the URL, not from the body.
        model = payload.pop("model")
        if payload.get("stream", False):
            specifier = "streamRawPredict"
        else:
            specifier = "rawPredict"

        prepared.url = f"/projects/{project_id}/locations/{region}/publishers/anthropic/models/{model}:{specifier}"
    elif prepared.url in count_tokens_routes and prepared.method == "post":
        # `elif` is safe: a URL rewritten above starts with "/projects/" and
        # so can never match a count-tokens route.
        if project_id is None:
            raise RuntimeError(missing_project_msg)

        prepared.url = f"/projects/{project_id}/locations/{region}/publishers/anthropic/models/count-tokens:rawPredict"

    if prepared.url.startswith("/v1/messages/batches"):
        raise AnthropicError("The Batch API is not supported in the Vertex client yet")

    return prepared