import asyncio
import inspect
from typing import (
    AsyncGenerator,
    Awaitable,
    Callable,
    Generator,
    Optional,
    Union,
)

import httpx


class BearerAuth(httpx.Auth):
    """httpx auth flow that sets ``Authorization: Bearer <token>`` on requests.

    The token provider is called once per request, so it may refresh or
    rotate tokens between calls.

    The provider may be either:

    * a synchronous zero-argument callable returning the token — usable with
      both ``httpx.Client`` and ``httpx.AsyncClient``; or
    * an asynchronous zero-argument callable (``async def`` function, a
      ``functools.partial`` of one, or an object with ``async def __call__``)
      — usable only with ``httpx.AsyncClient``.

    Args:
        auth_token_provider: Callable returning the bearer token, or returning
            an awaitable that resolves to it.

    Raises:
        ValueError: If ``auth_token_provider`` is not callable.
    """

    def __init__(
        self,
        auth_token_provider: Union[Callable[[], str], Callable[[], Awaitable[str]]],
    ):
        # Exactly one of these is set after construction.
        self.async_token: Optional[Callable[[], Awaitable[str]]] = None
        self.sync_token: Optional[Callable[[], str]] = None
        if not callable(auth_token_provider):
            raise ValueError(
                "auth_token_provider must be a callable returning a token "
                "or an awaitable of one"
            )
        if self._is_async_callable(auth_token_provider):
            self.async_token = auth_token_provider  # type: ignore[assignment]
        else:
            self.sync_token = auth_token_provider  # type: ignore[assignment]

    @staticmethod
    def _is_async_callable(func: object) -> bool:
        """Return True if calling *func* is known to produce a coroutine.

        ``inspect.iscoroutinefunction`` alone misses callable instances whose
        ``__call__`` is ``async def``, so that attribute is checked too.
        Classes are excluded because calling a class constructs an instance
        rather than invoking ``__call__``. Detection is best-effort: the async
        flow still awaits any awaitable that a "sync" provider returns.
        """
        if inspect.iscoroutinefunction(func):
            return True
        if isinstance(func, type):
            return False
        return inspect.iscoroutinefunction(getattr(func, "__call__", None))

    def _sync_get_token(self) -> str:
        """Return a token from the synchronous provider.

        Raises:
            ValueError: If no synchronous provider is configured (e.g. an
                async provider was given but a sync ``httpx.Client`` is in
                use), or if the provider returned an awaitable, which cannot
                be resolved from synchronous code.
        """
        if self.sync_token is None:
            raise ValueError("Synchronous token provider is not set.")
        token = self.sync_token()
        if inspect.isawaitable(token):
            # An async provider that escaped detection (e.g. a lambda wrapping
            # a coroutine function). Interpolating it would send
            # "Bearer <coroutine object ...>"; close it to avoid a
            # "never awaited" RuntimeWarning and fail loudly instead.
            if inspect.iscoroutine(token):
                token.close()
            raise ValueError(
                "auth_token_provider returned an awaitable; use an async client "
                "(httpx.AsyncClient) with an asynchronous token provider."
            )
        return token

    def sync_auth_flow(
        self, request: httpx.Request
    ) -> Generator[httpx.Request, httpx.Response, None]:
        """Attach the bearer token to *request* (synchronous clients)."""
        token = self._sync_get_token()
        request.headers["Authorization"] = f"Bearer {token}"
        yield request

    async def _async_get_token(self) -> str:
        """Return a token, awaiting the provider's result when it is awaitable.

        Prefers the async provider. Otherwise it falls back to the synchronous
        one, so sync providers also work with ``httpx.AsyncClient``.

        Raises:
            ValueError: If no provider is configured at all.
        """
        provider = self.async_token if self.async_token is not None else self.sync_token
        if provider is None:
            raise ValueError("Synchronous token provider is not set.")
        token = provider()
        # Also resolves awaitables returned by callables classified as sync.
        if inspect.isawaitable(token):
            token = await token
        return token  # type: ignore[return-value]

    async def async_auth_flow(
        self, request: httpx.Request
    ) -> AsyncGenerator[httpx.Request, httpx.Response]:
        """Attach the bearer token to *request* (asynchronous clients)."""
        token = await self._async_get_token()
        request.headers["Authorization"] = f"Bearer {token}"
        yield request
Memory