# Copyright (c) Microsoft Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import collections.abc
import contextvars
import datetime
import inspect
import sys
import traceback
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
TypedDict,
Union,
cast,
)
from pyee import EventEmitter
from pyee.asyncio import AsyncIOEventEmitter
import playwright
import playwright._impl._impl_to_api_mapping
from playwright._impl._errors import TargetClosedError, rewrite_error
from playwright._impl._greenlets import EventGreenlet
from playwright._impl._helper import Error, ParsedMessagePayload, parse_error
from playwright._impl._transport import Transport
if TYPE_CHECKING:
from playwright._impl._local_utils import LocalUtils
from playwright._impl._playwright import Playwright
class Channel(AsyncIOEventEmitter):
    """Sends protocol messages on behalf of one ``ChannelOwner``.

    A channel is bound to its owner's guid and to the ``Connection`` that
    carries the messages. Replies are delivered through the future of the
    ``ProtocolCallback`` returned by the connection.
    """

    def __init__(self, connection: "Connection", object: "ChannelOwner") -> None:
        super().__init__()
        self._connection = connection
        self._guid = object._guid
        self._object = object
        # Anything emitted on this channel's "error" event is handed to the
        # connection, which stores it for the next API call.
        self.on("error", lambda exc: self._connection._on_event_listener_error(exc))
        self._is_internal_type = False

    async def send(self, method: str, params: Optional[Dict] = None) -> Any:
        """Send ``method`` and return the single named value of the reply.

        Args:
            method: Protocol method name.
            params: Call parameters; ``None`` values are dropped before sending.

        Returns:
            The only value of the reply dict, or ``None`` for an empty reply.
        """
        return await self._connection.wrap_api_call(
            lambda: self._inner_send(method, params, False),
            self._is_internal_type,
        )

    async def send_return_as_dict(
        self, method: str, params: Optional[Dict] = None
    ) -> Any:
        """Send ``method`` and return the whole reply dict (``None`` if empty)."""
        return await self._connection.wrap_api_call(
            lambda: self._inner_send(method, params, True),
            self._is_internal_type,
        )

    def send_no_reply(self, method: str, params: Optional[Dict] = None) -> None:
        """Send ``method`` without waiting for, or reporting, its reply."""
        # No reply messages are used to e.g. waitForEventInfo(after).
        self._connection.wrap_api_call_sync(
            lambda: self._connection._send_message_to_server(
                self._object, method, {} if params is None else params, True
            )
        )

    async def _inner_send(
        self, method: str, params: Optional[Dict], return_as_dict: bool
    ) -> Any:
        """Send the message and wait for either its reply or a transport error.

        Raises:
            BaseException: An error previously stored on the connection by an
                event listener is re-raised (and cleared) before sending.
                Otherwise whatever the completed future carries.
        """
        if params is None:
            params = {}
        if self._connection._error:
            error = self._connection._error
            self._connection._error = None
            raise error
        callback = self._connection._send_message_to_server(
            self._object, method, _filter_none(params)
        )
        done, _ = await asyncio.wait(
            {
                self._connection._transport.on_error_future,
                callback.future,
            },
            return_when=asyncio.FIRST_COMPLETED,
        )
        if callback.future.done():
            # ``done`` is an unordered set: when the reply and the transport
            # error complete in the same wake-up, always prefer the reply
            # instead of whichever future the set happens to yield first.
            completed = callback.future
        else:
            callback.future.cancel()
            completed = next(iter(done))
        result = completed.result()
        # Protocol now has named return values, assume result is one level deeper unless
        # there is explicit ambiguity.
        if not result:
            return None
        assert isinstance(result, dict)
        if return_as_dict:
            return result
        assert len(result) == 1
        key = next(iter(result))
        return result[key]

    def mark_as_internal_type(self) -> None:
        """Flag every call made through this channel as internal."""
        self._is_internal_type = True
class ChannelOwner(AsyncIOEventEmitter):
    """Client-side object identified by a guid and backed by a ``Channel``.

    Each instance registers itself with the connection and, when it has one,
    with its parent owner. It also tracks its own children so that disposing
    it disposes the whole subtree.
    """

    def __init__(
        self,
        parent: Union["ChannelOwner", "Connection"],
        type: str,
        guid: str,
        initializer: Dict,
    ) -> None:
        super().__init__(loop=parent._loop)
        self._loop: asyncio.AbstractEventLoop = parent._loop
        self._dispatcher_fiber: Any = parent._dispatcher_fiber
        self._type = type
        self._guid: str = guid
        # A parent is either another owner or, for the root, the connection.
        if isinstance(parent, ChannelOwner):
            self._connection: Connection = parent._connection
            self._parent: Optional[ChannelOwner] = parent
        else:
            self._connection = parent
            self._parent = None
        self._objects: Dict[str, "ChannelOwner"] = {}
        self._channel: Channel = Channel(self._connection, self)
        self._initializer = initializer
        self._was_collected = False
        self._connection._objects[guid] = self
        if self._parent:
            self._parent._objects[guid] = self
        self._event_to_subscription_mapping: Dict[str, str] = {}

    def _dispose(self, reason: Optional[str]) -> None:
        """Unregister this object and, recursively, all of its children."""
        own_guid = self._guid
        if self._parent:
            del self._parent._objects[own_guid]
        del self._connection._objects[own_guid]
        self._was_collected = reason == "gc"
        # Iterate over a snapshot: every child removes itself from
        # ``self._objects`` while it is being disposed.
        for child in list(self._objects.values()):
            child._dispose(reason)
        self._objects.clear()

    def _adopt(self, child: "ChannelOwner") -> None:
        """Move ``child`` from its current parent to this object."""
        previous_parent = cast("ChannelOwner", child._parent)
        del previous_parent._objects[child._guid]
        self._objects[child._guid] = child
        child._parent = self

    def _set_event_to_subscription_mapping(self, mapping: Dict[str, str]) -> None:
        """Install the map from event names to protocol subscription names."""
        self._event_to_subscription_mapping = mapping

    def _update_subscription(self, event: str, enabled: bool) -> None:
        """Send ``updateSubscription`` for ``event`` if it has a mapping entry."""
        protocol_event = self._event_to_subscription_mapping.get(event)
        if not protocol_event:
            return

        def notify_server() -> None:
            return self._channel.send_no_reply(
                "updateSubscription", {"event": protocol_event, "enabled": enabled}
            )

        self._connection.wrap_api_call_sync(notify_server, True)

    def _add_event_handler(self, event: str, k: Any, v: Any) -> None:
        # The first listener for an event turns its subscription on.
        had_listeners = bool(self.listeners(event))
        if not had_listeners:
            self._update_subscription(event, True)
        super()._add_event_handler(event, k, v)

    def remove_listener(self, event: str, f: Any) -> None:
        super().remove_listener(event, f)
        # Removing the last listener for an event turns its subscription off.
        if self.listeners(event):
            return
        self._update_subscription(event, False)
class ProtocolCallback:
    """State of one in-flight protocol call.

    ``future`` is created on the given loop. ``stack_trace`` and ``no_reply``
    are only declared here; the creator assigns them afterwards.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self.stack_trace: traceback.StackSummary
        self.no_reply: bool
        self.future = loop.create_future()

        owner_task = asyncio.current_task()
        if not owner_task:
            # Nothing to link the future to.
            return

        # The outer task can get cancelled by the user, this forwards the
        # cancellation to the inner future.
        def forward_cancellation(finished: asyncio.Task) -> None:
            owner_task.remove_done_callback(forward_cancellation)
            if finished.cancelled():
                self.future.cancel()

        # Once the future settles the task no longer needs the hook above.
        def detach(_: asyncio.Future) -> None:
            owner_task.remove_done_callback(forward_cancellation)

        owner_task.add_done_callback(forward_cancellation)
        self.future.add_done_callback(detach)
class RootChannelOwner(ChannelOwner):
    """Owner with type ``"Root"`` and an empty guid, parented by the connection."""

    def __init__(self, connection: "Connection") -> None:
        super().__init__(connection, "Root", "", {})

    async def initialize(self) -> "Playwright":
        """Send ``initialize`` and return the object behind the reply channel."""
        handshake = {"sdkLanguage": "python"}
        playwright_channel = await self._channel.send("initialize", handshake)
        return from_channel(playwright_channel)
class Connection(EventEmitter):
    """Client side of the Playwright protocol connection.

    Keeps the registry of live ``ChannelOwner`` objects keyed by guid and the
    table of in-flight calls keyed by message id, sends requests through the
    transport and routes every incoming message in :meth:`dispatch`. Emits
    ``"close"`` at the end of :meth:`cleanup`.
    """

    def __init__(
        self,
        dispatcher_fiber: Any,
        object_factory: Callable[[ChannelOwner, str, str, Dict], ChannelOwner],
        transport: Transport,
        loop: asyncio.AbstractEventLoop,
        local_utils: Optional["LocalUtils"] = None,
    ) -> None:
        super().__init__()
        self._dispatcher_fiber = dispatcher_fiber
        self._transport = transport
        self._transport.on_message = lambda msg: self.dispatch(msg)
        # guid -> callback to run once the object with that guid is created.
        self._waiting_for_object: Dict[str, Callable[[ChannelOwner], None]] = {}
        self._last_id = 0
        self._objects: Dict[str, ChannelOwner] = {}
        # message id -> bookkeeping of the call still waiting for its reply.
        self._callbacks: Dict[int, ProtocolCallback] = {}
        self._object_factory = object_factory
        self._is_sync = False
        self._child_ws_connections: List["Connection"] = []
        self._loop = loop
        self.playwright_future: asyncio.Future["Playwright"] = loop.create_future()
        # Last error raised by an event listener; re-raised by the next API call.
        self._error: Optional[BaseException] = None
        self.is_remote = False
        self._init_task: Optional[asyncio.Task] = None
        # Holds the parsed stack of the API call currently being executed.
        self._api_zone: contextvars.ContextVar[Optional[ParsedStackTrace]] = (
            contextvars.ContextVar("ApiZone", default=None)
        )
        self._local_utils: Optional["LocalUtils"] = local_utils
        self._tracing_count = 0
        # Set by cleanup(); once set, sending raises it and dispatch is a no-op.
        self._closed_error: Optional[Exception] = None

    @property
    def local_utils(self) -> "LocalUtils":
        """Return the ``LocalUtils`` instance; asserts that one was provided."""
        assert self._local_utils
        return self._local_utils

    def mark_as_remote(self) -> None:
        self.is_remote = True

    async def run_as_sync(self) -> None:
        """Run the connection with listeners dispatched through greenlets."""
        self._is_sync = True
        await self.run()

    async def run(self) -> None:
        """Connect the transport, start initialization and run the transport."""
        self._loop = asyncio.get_running_loop()
        self._root_object = RootChannelOwner(self)

        async def init() -> None:
            self.playwright_future.set_result(await self._root_object.initialize())

        await self._transport.connect()
        self._init_task = self._loop.create_task(init())
        await self._transport.run()

    def stop_sync(self) -> None:
        """Stop the transport from synchronous code, then clean up."""
        self._transport.request_stop()
        self._dispatcher_fiber.switch()
        self._loop.run_until_complete(self._transport.wait_until_stopped())
        self.cleanup()

    async def stop_async(self) -> None:
        """Stop the transport, wait until it has stopped, then clean up."""
        self._transport.request_stop()
        await self._transport.wait_until_stopped()
        self.cleanup()

    def cleanup(self, cause: Optional[str] = None) -> None:
        """Mark the connection closed and fail every call still in flight.

        Args:
            cause: Optional message for the ``TargetClosedError`` that pending
                calls receive.
        """
        self._closed_error = TargetClosedError(cause) if cause else TargetClosedError()
        if self._init_task and not self._init_task.done():
            self._init_task.cancel()
        for ws_connection in self._child_ws_connections:
            ws_connection._transport.dispose()
        for callback in self._callbacks.values():
            # To prevent 'Future exception was never retrieved' we ignore all callbacks that are no_reply.
            if callback.no_reply:
                continue
            # set_exception() raises InvalidStateError on a future that is
            # already cancelled or resolved, which would abort the cleanup.
            if callback.future.done():
                continue
            callback.future.set_exception(self._closed_error)
        self._callbacks.clear()
        self.emit("close")

    def call_on_object_with_known_name(
        self, guid: str, callback: Callable[[ChannelOwner], None]
    ) -> None:
        """Register ``callback`` to run when the object ``guid`` is created."""
        self._waiting_for_object[guid] = callback

    def set_is_tracing(self, is_tracing: bool) -> None:
        """Increment or decrement the number of active tracing sessions."""
        if is_tracing:
            self._tracing_count += 1
        else:
            self._tracing_count -= 1

    def _send_message_to_server(
        self, object: ChannelOwner, method: str, params: Dict, no_reply: bool = False
    ) -> ProtocolCallback:
        """Serialize and send one call, returning the callback tracking it.

        Must be called inside ``wrap_api_call``/``wrap_api_call_sync`` so that
        the API zone holds the parsed stack used for the message metadata.

        Raises:
            Exception: The stored closed error, if the connection is closed.
            Error: If ``object`` was garbage-collected on the server side.
        """
        if self._closed_error:
            raise self._closed_error
        if object._was_collected:
            raise Error(
                "The object has been collected to prevent unbounded heap growth."
            )
        self._last_id += 1
        id = self._last_id
        callback = ProtocolCallback(self._loop)
        task = asyncio.current_task(self._loop)
        # Only walk the stack when the task does not already carry one; a
        # getattr() default would be computed on every call regardless.
        stack_trace = getattr(task, "__pw_stack_trace__", None)
        if stack_trace is None:
            stack_trace = traceback.extract_stack()
        callback.stack_trace = cast(traceback.StackSummary, stack_trace)
        callback.no_reply = no_reply
        # Registered once, before sending. Registering again after the send
        # could resurrect an entry that was already removed in the meantime.
        self._callbacks[id] = callback
        stack_trace_information = cast(ParsedStackTrace, self._api_zone.get())
        frames = stack_trace_information.get("frames", [])
        location = (
            {
                "file": frames[0]["file"],
                "line": frames[0]["line"],
                "column": frames[0]["column"],
            }
            if frames
            else None
        )
        metadata = {
            "wallTime": int(datetime.datetime.now().timestamp() * 1000),
            "apiName": stack_trace_information["apiName"],
            "internal": not stack_trace_information["apiName"],
        }
        if location:
            metadata["location"] = location  # type: ignore
        message = {
            "id": id,
            "guid": object._guid,
            "method": method,
            "params": self._replace_channels_with_guids(params),
            "metadata": metadata,
        }
        if self._tracing_count > 0 and frames and object._guid != "localUtils":
            self.local_utils.add_stack_to_tracing_no_reply(id, frames)
        self._transport.send(message)
        return callback

    def dispatch(self, msg: ParsedMessagePayload) -> None:
        """Route one incoming message.

        Messages with an ``id`` settle the matching pending call. Messages
        without one are either object lifecycle commands (``__create__``,
        ``__adopt__``, ``__dispose__``) or events emitted on the target
        object's channel.
        """
        if self._closed_error:
            return
        id = msg.get("id")
        if id:
            callback = self._callbacks.pop(id)
            if callback.future.cancelled():
                return
            # No reply messages are used to e.g. waitForEventInfo(after) which returns exceptions on page close.
            # To prevent 'Future exception was never retrieved' we just ignore such messages.
            if callback.no_reply:
                return
            error = msg.get("error")
            if error and not msg.get("result"):
                parsed_error = parse_error(
                    error["error"], format_call_log(msg.get("log"))  # type: ignore
                )
                parsed_error._stack = "".join(
                    traceback.format_list(callback.stack_trace)[-10:]
                )
                callback.future.set_exception(parsed_error)
            else:
                result = self._replace_guids_with_channels(msg.get("result"))
                callback.future.set_result(result)
            return
        guid = msg["guid"]
        method = msg["method"]
        params = msg.get("params")
        if method == "__create__":
            assert params
            parent = self._objects[guid]
            self._create_remote_object(
                parent, params["type"], params["guid"], params["initializer"]
            )
            return
        object = self._objects.get(guid)
        if not object:
            raise Exception(f'Cannot find object to "{method}": {guid}')
        if method == "__adopt__":
            child_guid = cast(Dict[str, str], params)["guid"]
            child = self._objects.get(child_guid)
            if not child:
                raise Exception(f"Unknown new child: {child_guid}")
            object._adopt(child)
            return
        if method == "__dispose__":
            assert isinstance(params, dict)
            object._dispose(cast(Optional[str], params.get("reason")))
            return
        should_replace_guids_with_channels = "jsonPipe@" not in guid
        try:
            if self._is_sync:
                # Event handlers like route/locatorHandlerTriggered require us to perform async work.
                # In order to report their potential errors to the user, we need to catch it and store it in the connection
                def _done_callback(future: asyncio.Future) -> None:
                    # exception() raises CancelledError on a cancelled
                    # future, which would escape from this done-callback.
                    if future.cancelled():
                        return
                    exc = future.exception()
                    if exc:
                        self._on_event_listener_error(exc)

                for listener in object._channel.listeners(method):

                    def _listener_with_error_handler_attached(
                        params: Any, listener: Any = listener
                    ) -> None:
                        potential_future = listener(params)
                        if asyncio.isfuture(potential_future):
                            potential_future.add_done_callback(_done_callback)

                    # Each event handler is a potentially blocking context, create a fiber for each
                    # and switch to them in order, until they block inside and pass control to each
                    # other and then eventually back to dispatcher as listener functions return.
                    g = EventGreenlet(_listener_with_error_handler_attached)
                    if should_replace_guids_with_channels:
                        g.switch(self._replace_guids_with_channels(params))
                    else:
                        g.switch(params)
            else:
                if should_replace_guids_with_channels:
                    object._channel.emit(
                        method, self._replace_guids_with_channels(params)
                    )
                else:
                    object._channel.emit(method, params)
        except BaseException as exc:
            self._on_event_listener_error(exc)

    def _on_event_listener_error(self, exc: BaseException) -> None:
        """Print ``exc`` to stderr and store it for the next API call."""
        print("Error occurred in event listener", file=sys.stderr)
        traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
        # Save the error to throw at the next API call. This "replicates" unhandled rejection in Node.js.
        self._error = exc

    def _create_remote_object(
        self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
    ) -> ChannelOwner:
        """Build an object through the factory and notify a pending waiter."""
        initializer = self._replace_guids_with_channels(initializer)
        result = self._object_factory(parent, type, guid, initializer)
        if guid in self._waiting_for_object:
            self._waiting_for_object.pop(guid)(result)
        return result

    def _replace_channels_with_guids(
        self,
        payload: Any,
    ) -> Any:
        """Return a copy of ``payload`` that is ready to be sent.

        Channels become ``{"guid": ...}`` dicts, ``Path`` objects become
        strings, and non-string sequences and dicts are converted recursively.
        """
        if payload is None:
            return payload
        if isinstance(payload, Path):
            return str(payload)
        if isinstance(payload, collections.abc.Sequence) and not isinstance(
            payload, str
        ):
            return list(map(self._replace_channels_with_guids, payload))
        if isinstance(payload, Channel):
            return dict(guid=payload._guid)
        if isinstance(payload, dict):
            return {
                key: self._replace_channels_with_guids(value)
                for key, value in payload.items()
            }
        return payload

    def _replace_guids_with_channels(self, payload: Any) -> Any:
        """Return a copy of ``payload`` with known guid dicts replaced.

        A dict whose ``"guid"`` names a registered object is replaced by that
        object's channel; lists and other dicts are converted recursively.
        """
        if payload is None:
            return payload
        if isinstance(payload, list):
            return list(map(self._replace_guids_with_channels, payload))
        if isinstance(payload, dict):
            if payload.get("guid") in self._objects:
                return self._objects[payload["guid"]]._channel
            return {
                key: self._replace_guids_with_channels(value)
                for key, value in payload.items()
            }
        return payload

    async def wrap_api_call(
        self, cb: Callable[[], Any], is_internal: bool = False
    ) -> Any:
        """Await ``cb()`` inside an API zone.

        If a zone is already active, ``cb`` runs in it unchanged. Otherwise
        the call stack is parsed and stored for the duration of the call, and
        any ``Exception`` is rewritten to carry the API name.
        """
        if self._api_zone.get():
            return await cb()
        task = asyncio.current_task(self._loop)
        # inspect.stack() is expensive; only capture it when the task does
        # not already carry a stack.
        st: Optional[List[inspect.FrameInfo]] = getattr(task, "__pw_stack__", None)
        if st is None:
            st = inspect.stack()
        parsed_st = _extract_stack_trace_information_from_stack(st, is_internal)
        self._api_zone.set(parsed_st)
        try:
            return await cb()
        except Exception as error:
            raise rewrite_error(error, f"{parsed_st['apiName']}: {error}") from None
        finally:
            self._api_zone.set(None)

    def wrap_api_call_sync(
        self, cb: Callable[[], Any], is_internal: bool = False
    ) -> Any:
        """Synchronous counterpart of :meth:`wrap_api_call`."""
        if self._api_zone.get():
            return cb()
        task = asyncio.current_task(self._loop)
        # inspect.stack() is expensive; only capture it when the task does
        # not already carry a stack.
        st: Optional[List[inspect.FrameInfo]] = getattr(task, "__pw_stack__", None)
        if st is None:
            st = inspect.stack()
        parsed_st = _extract_stack_trace_information_from_stack(st, is_internal)
        self._api_zone.set(parsed_st)
        try:
            return cb()
        except Exception as error:
            raise rewrite_error(error, f"{parsed_st['apiName']}: {error}") from None
        finally:
            self._api_zone.set(None)
def from_channel(channel: Channel) -> Any:
    """Return the object that ``channel`` was created for."""
    owner = channel._object
    return owner
def from_nullable_channel(channel: Optional[Channel]) -> Optional[Any]:
    """Return the object behind ``channel``, or ``None`` for a falsy channel."""
    if not channel:
        return None
    return channel._object
class StackFrame(TypedDict):
    """One call-stack frame in the form attached to outgoing messages."""

    # File name of the frame.
    file: str
    # Line number of the frame.
    line: int
    # Column; _extract_stack_trace_information_from_stack always stores 0.
    column: int
    # Function name, prefixed with "ClassName." when the frame has a `self`.
    function: Optional[str]
class ParsedStackTrace(TypedDict):
    """Result of _extract_stack_trace_information_from_stack."""

    # Frames located outside of the playwright package.
    frames: List[StackFrame]
    # Name of the API method being called; "" for internal calls.
    apiName: Optional[str]
def _extract_stack_trace_information_from_stack(
    st: List[inspect.FrameInfo], is_internal: bool
) -> ParsedStackTrace:
    """Split a captured stack into user frames and the API method name.

    Frames whose file lies under the playwright package are internal. The API
    name is the internal frame seen immediately before the first user frame
    that follows internal frames; if no such transition sets a name, the last
    internal frame seen is used. Internal calls report an empty API name.
    """
    package_root = str(Path(playwright.__file__).parents[0])
    glue_file = playwright._impl._impl_to_api_mapping.__file__
    user_frames: List[StackFrame] = []
    pending_internal_name = ""
    api_name = ""

    for frame_info in st:
        # Sync and Async implementations can have event handlers. When these are sync, they
        # get evaluated in the context of the event loop, so they contain the stack trace of when
        # the message was received. _impl_to_api_mapping is glue between the user-code and internal
        # code to translate impl classes to api classes. We want to ignore these frames.
        if glue_file == frame_info.filename:
            continue

        frame_object = frame_info[0]
        frame_locals = frame_object.f_locals
        method_name = ""
        if "self" in frame_locals:
            method_name = frame_locals["self"].__class__.__name__ + "."
        method_name += frame_object.f_code.co_name

        if frame_info.filename.startswith(package_root):
            # Remember the innermost-so-far internal frame and keep walking.
            pending_internal_name = method_name
            continue

        user_frames.append(
            {
                "file": frame_info.filename,
                "line": frame_info.lineno,
                "column": 0,
                "function": method_name,
            }
        )
        if pending_internal_name:
            api_name = pending_internal_name
            pending_internal_name = ""

    if not api_name:
        api_name = pending_internal_name
    if is_internal:
        api_name = ""
    return {"frames": user_frames, "apiName": api_name}
def _filter_none(d: Mapping) -> Dict:
return {k: v for k, v in d.items() if v is not None}
def format_call_log(log: Optional[List[str]]) -> str:
    """Render protocol call-log lines as a ``Call log:`` section.

    Returns an empty string when there is no log or every entry is blank.
    """
    if not log or not any(entry.strip() for entry in log):
        return ""
    joined = "\n - ".join(log)
    return "\nCall log:\n" + joined + "\n"