# ------------------------------------ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ import functools import os from typing import Any, Optional, Dict, Mapping from azure.core.pipeline.transport import HttpRequest from .._constants import EnvironmentVariables from .._internal import within_dac from .._internal.managed_identity_client import ManagedIdentityClient from .._internal.managed_identity_base import ManagedIdentityBase def validate_client_id_and_config(client_id: Optional[str], identity_config: Optional[Mapping[str, str]]) -> None: if within_dac.get(): return if client_id: raise ValueError("client_id should not be set for cloud shell managed identity.") if identity_config: valid_keys = {"object_id", "resource_id", "client_id"} if len(identity_config.keys() & valid_keys) > 0: raise ValueError(f"identity_config must not contain the following keys: {', '.join(valid_keys)}") class CloudShellCredential(ManagedIdentityBase): def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]: client_id = kwargs.get("client_id") identity_config = kwargs.get("identity_config") validate_client_id_and_config(client_id, identity_config) url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT) if url: return ManagedIdentityClient( request_factory=functools.partial(_get_request, url), base_headers={"Metadata": "true"}, **kwargs ) return None def get_unavailable_message(self, desc: str = "") -> str: return f"Cloud Shell managed identity configuration not found in environment. {desc}" def _get_request(url: str, scope: str, identity_config: Dict) -> HttpRequest: request = HttpRequest("POST", url, data=dict({"resource": scope}, **identity_config)) return request
Memory