# Copyright (c) Microsoft Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
from pathlib import Path
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Pattern,
Sequence,
Set,
Union,
cast,
)
from playwright._impl._api_structures import (
Cookie,
Geolocation,
SetCookieParam,
StorageState,
)
from playwright._impl._artifact import Artifact
from playwright._impl._cdp_session import CDPSession
from playwright._impl._clock import Clock
from playwright._impl._connection import (
ChannelOwner,
from_channel,
from_nullable_channel,
)
from playwright._impl._console_message import ConsoleMessage
from playwright._impl._dialog import Dialog
from playwright._impl._errors import Error, TargetClosedError
from playwright._impl._event_context_manager import EventContextManagerImpl
from playwright._impl._fetch import APIRequestContext
from playwright._impl._frame import Frame
from playwright._impl._har_router import HarRouter
from playwright._impl._helper import (
HarContentPolicy,
HarMode,
HarRecordingMetadata,
RouteFromHarNotFoundPolicy,
RouteHandler,
RouteHandlerCallback,
TimeoutSettings,
URLMatch,
WebSocketRouteHandlerCallback,
async_readfile,
async_writefile,
locals_to_params,
parse_error,
prepare_record_har_options,
to_impl,
)
from playwright._impl._network import (
Request,
Response,
Route,
WebSocketRoute,
WebSocketRouteHandler,
serialize_headers,
)
from playwright._impl._page import BindingCall, Page, Worker
from playwright._impl._str_utils import escape_regex_flags
from playwright._impl._tracing import Tracing
from playwright._impl._waiter import Waiter
from playwright._impl._web_error import WebError
# Imported for type annotations only; a runtime import would be circular
# (see the "circular import workaround" in BrowserContext.__init__).
if TYPE_CHECKING:  # pragma: no cover
    from playwright._impl._browser import Browser
class BrowserContext(ChannelOwner):
    """Client-side object for a browser context.

    Subscribes to the events sent over the context's channel, re-emits them
    as ``BrowserContext.Events`` (and, where a page is attached, as the
    matching ``Page.Events``), and keeps the local bookkeeping for pages,
    routes, bindings and HAR recordings that belong to this context.
    """

    # Names of the events emitted by this object.
    Events = SimpleNamespace(
        BackgroundPage="backgroundpage",
        Close="close",
        Console="console",
        Dialog="dialog",
        Page="page",
        WebError="weberror",
        ServiceWorker="serviceworker",
        Request="request",
        Response="response",
        RequestFailed="requestfailed",
        RequestFinished="requestfinished",
    )

    def __init__(
        self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
    ) -> None:
        """Initialise local state and subscribe to the channel events.

        Reads ``initializer["tracing"]`` and ``initializer["requestContext"]``.
        When ``parent`` is a ``Browser``, the context registers itself in the
        browser's ``_contexts`` list.
        """
        super().__init__(parent, type, guid, initializer)
        # circular import workaround:
        self._browser: Optional["Browser"] = None
        if parent.__class__.__name__ == "Browser":
            self._browser = cast("Browser", parent)
            self._browser._contexts.append(self)
        self._pages: List[Page] = []
        self._routes: List[RouteHandler] = []
        self._web_socket_routes: List[WebSocketRouteHandler] = []
        self._bindings: Dict[str, Any] = {}
        self._timeout_settings = TimeoutSettings(None)
        self._owner_page: Optional[Page] = None
        self._options: Dict[str, Any] = {}
        self._background_pages: Set[Page] = set()
        self._service_workers: Set[Worker] = set()
        self._tracing = cast(Tracing, from_channel(initializer["tracing"]))
        # Maps HAR id -> recording metadata; "" is used by _set_options for
        # the recording configured through the context options.
        self._har_recorders: Dict[str, HarRecordingMetadata] = {}
        self._request: APIRequestContext = from_channel(initializer["requestContext"])
        self._clock = Clock(self)

        # Channel event subscriptions. Route events are handled by coroutines,
        # so they are scheduled as tasks on the connection loop.
        self._channel.on(
            "bindingCall",
            lambda params: self._on_binding(from_channel(params["binding"])),
        )
        self._channel.on("close", lambda _: self._on_close())
        self._channel.on(
            "page", lambda params: self._on_page(from_channel(params["page"]))
        )
        self._channel.on(
            "route",
            lambda params: self._loop.create_task(
                self._on_route(
                    from_channel(params.get("route")),
                )
            ),
        )
        self._channel.on(
            "webSocketRoute",
            lambda params: self._loop.create_task(
                self._on_web_socket_route(
                    from_channel(params["webSocketRoute"]),
                )
            ),
        )
        self._channel.on(
            "backgroundPage",
            lambda params: self._on_background_page(from_channel(params["page"])),
        )
        self._channel.on(
            "serviceWorker",
            lambda params: self._on_service_worker(from_channel(params["worker"])),
        )
        self._channel.on(
            "console",
            lambda event: self._on_console_message(event),
        )
        self._channel.on(
            "dialog", lambda params: self._on_dialog(from_channel(params["dialog"]))
        )
        self._channel.on(
            "pageError",
            lambda params: self._on_page_error(
                parse_error(params["error"]["error"]),
                from_nullable_channel(params["page"]),
            ),
        )
        self._channel.on(
            "request",
            lambda params: self._on_request(
                from_channel(params["request"]),
                from_nullable_channel(params.get("page")),
            ),
        )
        self._channel.on(
            "response",
            lambda params: self._on_response(
                from_channel(params["response"]),
                from_nullable_channel(params.get("page")),
            ),
        )
        self._channel.on(
            "requestFailed",
            lambda params: self._on_request_failed(
                from_channel(params["request"]),
                params["responseEndTiming"],
                params.get("failureText"),
                from_nullable_channel(params.get("page")),
            ),
        )
        self._channel.on(
            "requestFinished",
            lambda params: self._on_request_finished(
                from_channel(params["request"]),
                from_nullable_channel(params.get("response")),
                params["responseEndTiming"],
                from_nullable_channel(params.get("page")),
            ),
        )

        # Resolved on the first Close event; close() awaits it.
        self._closed_future: asyncio.Future = asyncio.Future()
        self.once(
            self.Events.Close, lambda context: self._closed_future.set_result(True)
        )
        self._close_reason: Optional[str] = None
        self._har_routers: List[HarRouter] = []
        self._set_event_to_subscription_mapping(
            {
                BrowserContext.Events.Console: "console",
                BrowserContext.Events.Dialog: "dialog",
                BrowserContext.Events.Request: "request",
                BrowserContext.Events.Response: "response",
                BrowserContext.Events.RequestFinished: "requestFinished",
                BrowserContext.Events.RequestFailed: "requestFailed",
            }
        )
        self._close_was_called = False

    def __repr__(self) -> str:
        return f"<BrowserContext browser={self.browser}>"

    def _on_page(self, page: Page) -> None:
        """Track a new page, emit ``Page``, and notify a still-open opener."""
        self._pages.append(page)
        self.emit(BrowserContext.Events.Page, page)
        if page._opener and not page._opener.is_closed():
            page._opener.emit(Page.Events.Popup, page)

    async def _on_route(self, route: Route) -> None:
        """Offer ``route`` to the matching handlers until one handles it.

        Handlers are tried in list order; ``route()`` inserts at index 0, so
        the most recently registered handler is tried first. If no handler
        reports the route as handled, the request is continued.
        """
        route._context = self
        page = route.request._safe_page()
        # Iterate over a snapshot: handlers may be added or removed while a
        # handler coroutine is awaited.
        route_handlers = self._routes.copy()
        for route_handler in route_handlers:
            # If the page or the context was closed we stall all requests right away.
            if (page and page._close_was_called) or self._close_was_called:
                return
            if not route_handler.matches(route.request.url):
                continue
            # Skip handlers that were unregistered after the snapshot was taken.
            if route_handler not in self._routes:
                continue
            if route_handler.will_expire:
                self._routes.remove(route_handler)
            try:
                handled = await route_handler.handle(route)
            finally:
                # Once the last handler is gone, push the (now empty) pattern
                # list to the server as an internal call.
                if not self._routes:
                    asyncio.create_task(
                        self._connection.wrap_api_call(
                            lambda: self._update_interception_patterns(), True
                        )
                    )
            if handled:
                return
        try:
            # If the page is closed or unrouteAll() was called without waiting and interception disabled,
            # the method will throw an error - silence it.
            await route._inner_continue(True)
        except Exception:
            pass

    async def _on_web_socket_route(self, web_socket_route: WebSocketRoute) -> None:
        """Dispatch to the first matching WebSocket handler, else connect to the server."""
        route_handler = next(
            (
                candidate
                for candidate in self._web_socket_routes
                if candidate.matches(web_socket_route.url)
            ),
            None,
        )
        if route_handler:
            await route_handler.handle(web_socket_route)
        else:
            web_socket_route.connect_to_server()

    def _on_binding(self, binding_call: BindingCall) -> None:
        """Run the context-level binding named in ``binding_call``, if one is registered."""
        func = self._bindings.get(binding_call._initializer["name"])
        if func is None:
            return
        asyncio.create_task(binding_call.call(func))

    def set_default_navigation_timeout(self, timeout: float) -> None:
        return self._set_default_navigation_timeout_impl(timeout)

    def _set_default_navigation_timeout_impl(self, timeout: Optional[float]) -> None:
        """Store the navigation timeout locally and send it without awaiting a reply.

        ``None`` is sent as an empty parameter dict.
        """
        self._timeout_settings.set_default_navigation_timeout(timeout)
        self._channel.send_no_reply(
            "setDefaultNavigationTimeoutNoReply",
            {} if timeout is None else {"timeout": timeout},
        )

    def set_default_timeout(self, timeout: float) -> None:
        return self._set_default_timeout_impl(timeout)

    def _set_default_timeout_impl(self, timeout: Optional[float]) -> None:
        """Store the default timeout locally and send it without awaiting a reply.

        ``None`` is sent as an empty parameter dict.
        """
        self._timeout_settings.set_default_timeout(timeout)
        self._channel.send_no_reply(
            "setDefaultTimeoutNoReply", {} if timeout is None else {"timeout": timeout}
        )

    @property
    def pages(self) -> List[Page]:
        # Return a copy so callers cannot mutate the internal list.
        return self._pages.copy()

    @property
    def browser(self) -> Optional["Browser"]:
        return self._browser

    def _set_options(self, context_options: Dict, browser_options: Dict) -> None:
        """Record the context options and derive HAR/tracing settings from them."""
        self._options = context_options
        record_har = self._options.get("recordHar")
        if record_har:
            har_path = record_har["path"]
            self._har_recorders[""] = {
                # close() calls str methods (`.endswith`, `+`) on this value,
                # so a pathlib.Path is normalised to str here.
                "path": str(har_path) if isinstance(har_path, Path) else har_path,
                "content": record_har.get("content"),
            }
        self._tracing._traces_dir = browser_options.get("tracesDir")

    async def new_page(self) -> Page:
        """Open a new page in this context.

        Raises:
            Error: if the context has an owner page (``_owner_page`` is set).
        """
        if self._owner_page:
            raise Error("Please use browser.new_context()")
        return from_channel(await self._channel.send("newPage"))

    async def cookies(self, urls: Union[str, Sequence[str]] = None) -> List[Cookie]:
        """Return cookies; ``urls`` may be omitted, a single URL, or a sequence of URLs."""
        if urls is None:
            urls = []
        if isinstance(urls, str):
            urls = [urls]
        return await self._channel.send("cookies", dict(urls=urls))

    async def add_cookies(self, cookies: Sequence[SetCookieParam]) -> None:
        await self._channel.send("addCookies", dict(cookies=cookies))

    async def clear_cookies(
        self,
        name: Union[str, Pattern[str]] = None,
        domain: Union[str, Pattern[str]] = None,
        path: Union[str, Pattern[str]] = None,
    ) -> None:
        """Clear cookies, optionally filtered by name, domain and path.

        Each filter is sent either as a plain string (``<key>``) or as a regex
        (``<key>RegexSource`` / ``<key>RegexFlags``); the unused fields are
        sent as ``None``.
        """
        params: Dict[str, Any] = {}
        for key, matcher in (("name", name), ("domain", domain), ("path", path)):
            regex_source = None
            regex_flags = None
            if isinstance(matcher, Pattern):
                regex_source = matcher.pattern
                regex_flags = escape_regex_flags(matcher)
            params[key] = matcher if isinstance(matcher, str) else None
            params[f"{key}RegexSource"] = regex_source
            params[f"{key}RegexFlags"] = regex_flags
        await self._channel.send("clearCookies", params)

    async def grant_permissions(
        self, permissions: Sequence[str], origin: str = None
    ) -> None:
        # NOTE: locals() must be taken before any other local is defined,
        # since every local name is forwarded as a parameter.
        await self._channel.send("grantPermissions", locals_to_params(locals()))

    async def clear_permissions(self) -> None:
        await self._channel.send("clearPermissions")

    async def set_geolocation(self, geolocation: Geolocation = None) -> None:
        # NOTE: locals() must be taken before any other local is defined,
        # since every local name is forwarded as a parameter.
        await self._channel.send("setGeolocation", locals_to_params(locals()))

    async def set_extra_http_headers(self, headers: Dict[str, str]) -> None:
        await self._channel.send(
            "setExtraHTTPHeaders", dict(headers=serialize_headers(headers))
        )

    async def set_offline(self, offline: bool) -> None:
        await self._channel.send("setOffline", dict(offline=offline))

    async def add_init_script(
        self, script: str = None, path: Union[str, Path] = None
    ) -> None:
        """Register an init script given inline or as a file.

        When ``path`` is given, the file contents replace ``script``.

        Raises:
            Error: if neither argument yields a string script.
        """
        if path:
            script = (await async_readfile(path)).decode()
        if not isinstance(script, str):
            raise Error("Either path or script parameter must be specified")
        await self._channel.send("addInitScript", dict(source=script))

    async def expose_binding(
        self, name: str, callback: Callable, handle: bool = None
    ) -> None:
        """Register ``callback`` as the binding ``name`` for this context.

        Raises:
            Error: if ``name`` is already registered on one of the pages or
                on this context.
        """
        for page in self._pages:
            if name in page._bindings:
                raise Error(
                    f'Function "{name}" has been already registered in one of the pages'
                )
        if name in self._bindings:
            raise Error(f'Function "{name}" has been already registered')
        self._bindings[name] = callback
        await self._channel.send(
            "exposeBinding", dict(name=name, needsHandle=handle or False)
        )

    async def expose_function(self, name: str, callback: Callable) -> None:
        # Same as expose_binding, but the leading "source" argument is dropped
        # before calling the user's callback.
        await self.expose_binding(name, lambda source, *args: callback(*args))

    async def route(
        self, url: URLMatch, handler: RouteHandlerCallback, times: int = None
    ) -> None:
        """Register a route handler; newer handlers take precedence over older ones."""
        self._routes.insert(
            0,
            RouteHandler(
                self._options.get("baseURL"),
                url,
                handler,
                bool(self._dispatcher_fiber),
                times,
            ),
        )
        await self._update_interception_patterns()

    async def unroute(
        self, url: URLMatch, handler: Optional[RouteHandlerCallback] = None
    ) -> None:
        """Remove handlers registered for ``url``.

        When ``handler`` is given, only registrations with that exact handler
        are removed; otherwise every registration for ``url`` is removed.
        """
        removed = []
        remaining = []
        for route in self._routes:
            if route.url != url or (handler and route.handler != handler):
                remaining.append(route)
            else:
                removed.append(route)
        await self._unroute_internal(removed, remaining, "default")

    async def _unroute_internal(
        self,
        removed: List[RouteHandler],
        remaining: List[RouteHandler],
        behavior: Literal["default", "ignoreErrors", "wait"] = None,
    ) -> None:
        """Replace the handler list with ``remaining`` and update the server.

        For any behavior other than ``None``/``"default"``, each removed
        handler is additionally stopped with that behavior.
        """
        self._routes = remaining
        await self._update_interception_patterns()
        if behavior is None or behavior == "default":
            return
        await asyncio.gather(*(handler.stop(behavior) for handler in removed))  # type: ignore

    async def route_web_socket(
        self, url: URLMatch, handler: WebSocketRouteHandlerCallback
    ) -> None:
        """Register a WebSocket route handler; newer handlers take precedence."""
        self._web_socket_routes.insert(
            0,
            WebSocketRouteHandler(self._options.get("baseURL"), url, handler),
        )
        await self._update_web_socket_interception_patterns()

    def _dispose_har_routers(self) -> None:
        """Dispose every HAR router and forget them."""
        for router in self._har_routers:
            router.dispose()
        self._har_routers = []

    async def unroute_all(
        self, behavior: Literal["default", "ignoreErrors", "wait"] = None
    ) -> None:
        """Remove all route handlers and dispose the HAR routers."""
        await self._unroute_internal(self._routes, [], behavior)
        self._dispose_har_routers()

    async def _record_into_har(
        self,
        har: Union[Path, str],
        page: Optional[Page] = None,
        url: Union[Pattern[str], str] = None,
        update_content: HarContentPolicy = None,
        update_mode: HarMode = None,
    ) -> None:
        """Start a HAR recording and remember it so close() can export it.

        Content defaults to ``"attach"`` and mode to ``"minimal"``.
        """
        params: Dict[str, Any] = {
            "options": prepare_record_har_options(
                {
                    "recordHarPath": har,
                    "recordHarContent": update_content or "attach",
                    "recordHarMode": update_mode or "minimal",
                    "recordHarUrlFilter": url,
                }
            )
        }
        if page:
            params["page"] = page._channel
        har_id = await self._channel.send("harStart", params)
        self._har_recorders[har_id] = {
            "path": str(har),
            "content": update_content or "attach",
        }

    async def route_from_har(
        self,
        har: Union[Path, str],
        url: Union[Pattern[str], str] = None,
        notFound: RouteFromHarNotFoundPolicy = None,
        update: bool = None,
        updateContent: Literal["attach", "embed"] = None,
        updateMode: HarMode = None,
    ) -> None:
        """Serve requests from a HAR file, or record into it when ``update`` is truthy.

        In replay mode ``notFound`` defaults to ``"abort"``.
        """
        if update:
            await self._record_into_har(
                har=har,
                page=None,
                url=url,
                update_content=updateContent,
                update_mode=updateMode,
            )
            return
        router = await HarRouter.create(
            local_utils=self._connection.local_utils,
            file=str(har),
            not_found_action=notFound or "abort",
            url_matcher=url,
        )
        self._har_routers.append(router)
        await router.add_context_route(self)

    async def _update_interception_patterns(self) -> None:
        """Send the interception patterns for the current route handlers."""
        patterns = RouteHandler.prepare_interception_patterns(self._routes)
        await self._channel.send(
            "setNetworkInterceptionPatterns", {"patterns": patterns}
        )

    async def _update_web_socket_interception_patterns(self) -> None:
        """Send the interception patterns for the current WebSocket handlers."""
        patterns = WebSocketRouteHandler.prepare_interception_patterns(
            self._web_socket_routes
        )
        await self._channel.send(
            "setWebSocketInterceptionPatterns", {"patterns": patterns}
        )

    def expect_event(
        self,
        event: str,
        predicate: Callable = None,
        timeout: float = None,
    ) -> EventContextManagerImpl:
        """Return a context manager that waits for ``event``.

        The wait is rejected on timeout (defaulting to the context's timeout
        setting) and, unless the awaited event is ``Close`` itself, when the
        context emits ``Close``.
        """
        if timeout is None:
            timeout = self._timeout_settings.timeout()
        waiter = Waiter(self, f"browser_context.expect_event({event})")
        waiter.reject_on_timeout(
            timeout, f'Timeout {timeout}ms exceeded while waiting for event "{event}"'
        )
        if event != BrowserContext.Events.Close:
            waiter.reject_on_event(
                self, BrowserContext.Events.Close, lambda: TargetClosedError()
            )
        waiter.wait_for_event(self, event, predicate)
        return EventContextManagerImpl(waiter.result())

    def _on_close(self) -> None:
        """Unregister from the browser, release HAR routers and emit ``Close``."""
        # Membership check: list.remove() would raise ValueError if the context
        # is no longer listed, which would prevent the Close event below.
        if self._browser and self in self._browser._contexts:
            self._browser._contexts.remove(self)
        self._dispose_har_routers()
        self._tracing._reset_stack_counter()
        self.emit(BrowserContext.Events.Close, self)

    async def close(self, reason: str = None) -> None:
        """Close the context; calls after the first one return immediately.

        Disposes the request context, exports every HAR recording to its
        target path, sends ``close`` and waits for the ``Close`` event.
        """
        if self._close_was_called:
            return
        self._close_reason = reason
        self._close_was_called = True

        await self._channel._connection.wrap_api_call(
            lambda: self.request.dispose(reason=reason), True
        )

        async def _inner_close() -> None:
            for har_id, params in self._har_recorders.items():
                har = cast(
                    Artifact,
                    from_channel(
                        await self._channel.send("harExport", {"harId": har_id})
                    ),
                )
                # Server side will compress artifact if content is attach or if file is .zip.
                is_compressed = params.get("content") == "attach" or params[
                    "path"
                ].endswith(".zip")
                need_compressed = params["path"].endswith(".zip")
                if is_compressed and not need_compressed:
                    # Save the archive next to the target, then unzip it into
                    # the requested plain HAR path.
                    tmp_path = params["path"] + ".tmp"
                    await har.save_as(tmp_path)
                    await self._connection.local_utils.har_unzip(
                        zipFile=tmp_path, harFile=params["path"]
                    )
                else:
                    await har.save_as(params["path"])
                await har.delete()

        await self._channel._connection.wrap_api_call(_inner_close, True)
        await self._channel.send("close", {"reason": reason})
        await self._closed_future

    async def storage_state(self, path: Union[str, Path] = None) -> StorageState:
        """Return the storage state; also write it as JSON to ``path`` when given."""
        result = await self._channel.send_return_as_dict("storageState")
        if path:
            await async_writefile(path, json.dumps(result))
        return result

    def _effective_close_reason(self) -> Optional[str]:
        """Return this context's close reason, else the browser's, else ``None``."""
        if self._close_reason:
            return self._close_reason
        if self._browser:
            return self._browser._close_reason
        return None

    async def wait_for_event(
        self, event: str, predicate: Callable = None, timeout: float = None
    ) -> Any:
        """Wait for ``event`` and return its value (see ``expect_event``)."""
        async with self.expect_event(event, predicate, timeout) as event_info:
            pass
        return await event_info

    def expect_console_message(
        self,
        predicate: Callable[[ConsoleMessage], bool] = None,
        timeout: float = None,
    ) -> EventContextManagerImpl[ConsoleMessage]:
        # Use this class's own event name, like every other waiter here.
        return self.expect_event(BrowserContext.Events.Console, predicate, timeout)

    def expect_page(
        self,
        predicate: Callable[[Page], bool] = None,
        timeout: float = None,
    ) -> EventContextManagerImpl[Page]:
        return self.expect_event(BrowserContext.Events.Page, predicate, timeout)

    def _on_background_page(self, page: Page) -> None:
        """Track a background page and emit ``BackgroundPage``."""
        self._background_pages.add(page)
        self.emit(BrowserContext.Events.BackgroundPage, page)

    def _on_service_worker(self, worker: Worker) -> None:
        """Attach the worker to this context, track it and emit ``ServiceWorker``."""
        worker._context = self
        self._service_workers.add(worker)
        self.emit(BrowserContext.Events.ServiceWorker, worker)

    def _on_request_failed(
        self,
        request: Request,
        response_end_timing: float,
        failure_text: Optional[str],
        page: Optional[Page],
    ) -> None:
        """Record failure details on ``request`` and emit ``RequestFailed``."""
        request._failure_text = failure_text
        request._set_response_end_timing(response_end_timing)
        self.emit(BrowserContext.Events.RequestFailed, request)
        if page:
            page.emit(Page.Events.RequestFailed, request)

    def _on_request_finished(
        self,
        request: Request,
        response: Optional[Response],
        response_end_timing: float,
        page: Optional[Page],
    ) -> None:
        """Record end timing, emit ``RequestFinished`` and resolve the response's finished future."""
        request._set_response_end_timing(response_end_timing)
        self.emit(BrowserContext.Events.RequestFinished, request)
        if page:
            page.emit(Page.Events.RequestFinished, request)
        if response:
            response._finished_future.set_result(True)

    def _on_console_message(self, event: Dict) -> None:
        """Wrap the raw event in a ``ConsoleMessage`` and emit it on context and page."""
        message = ConsoleMessage(event, self._loop, self._dispatcher_fiber)
        self.emit(BrowserContext.Events.Console, message)
        page = message.page
        if page:
            page.emit(Page.Events.Console, message)

    def _on_dialog(self, dialog: Dialog) -> None:
        """Emit the dialog on context and page; auto-resolve it if nobody listens."""
        has_listeners = self.emit(BrowserContext.Events.Dialog, dialog)
        page = dialog.page
        if page:
            has_listeners = page.emit(Page.Events.Dialog, dialog) or has_listeners
        if not has_listeners:
            # Although we do similar handling on the server side, we still need this logic
            # on the client side due to a possible race condition between two async calls:
            # a) removing "dialog" listener subscription (client->server)
            # b) actual "dialog" event (server->client)
            if dialog.type == "beforeunload":
                asyncio.create_task(dialog.accept())
            else:
                asyncio.create_task(dialog.dismiss())

    def _on_page_error(self, error: Error, page: Optional[Page]) -> None:
        """Emit ``WebError`` on the context and ``PageError`` on the page, if any."""
        self.emit(BrowserContext.Events.WebError, WebError(self._loop, page, error))
        if page:
            page.emit(Page.Events.PageError, error)

    def _on_request(self, request: Request, page: Optional[Page]) -> None:
        """Emit ``Request`` on the context and on the page, if any."""
        self.emit(BrowserContext.Events.Request, request)
        if page:
            page.emit(Page.Events.Request, request)

    def _on_response(self, response: Response, page: Optional[Page]) -> None:
        """Emit ``Response`` on the context and on the page, if any."""
        self.emit(BrowserContext.Events.Response, response)
        if page:
            page.emit(Page.Events.Response, response)

    @property
    def background_pages(self) -> List[Page]:
        return list(self._background_pages)

    @property
    def service_workers(self) -> List[Worker]:
        return list(self._service_workers)

    async def new_cdp_session(self, page: Union[Page, Frame]) -> CDPSession:
        """Create a CDP session attached to the given page or frame.

        Raises:
            Error: if ``page`` is neither a ``Page`` nor a ``Frame``.
        """
        page = to_impl(page)
        params = {}
        if isinstance(page, Page):
            params["page"] = page._channel
        elif isinstance(page, Frame):
            params["frame"] = page._channel
        else:
            raise Error("page: expected Page or Frame")
        return from_channel(await self._channel.send("newCDPSession", params))

    @property
    def tracing(self) -> Tracing:
        return self._tracing

    @property
    def request(self) -> "APIRequestContext":
        return self._request

    @property
    def clock(self) -> Clock:
        return self._clock