import asyncio
import traceback
import typing
from urllib.parse import parse_qs
from .headers import Headers, make_headers
from .json import dumps
from .types import HeadersType, Receive, Scope, Send, StrOrBytes
from .utils import to_str
from .websocketexceptions import CloseServerError, WebsocketClose
if typing.TYPE_CHECKING:
from .app import Baguette
class Websocket:
    """Base websocket class.

    You usually only need to overwrite the :meth:`on_connect`,
    :meth:`on_message`, :meth:`on_disconnect` and :meth:`on_close` when
    subclassing.
    """

    def __init__(
        self, app: "Baguette", scope: Scope, receive: Receive, send: Send
    ):
        """Builds the websocket from an ASGI connection.

        Parameters
        ----------
        app : :class:`Baguette`
            The application handling the connection.
        scope : :class:`dict`
            The ASGI websocket connection scope.
        receive : Asynchronous callable
            The ASGI callable used to receive events.
        send : Asynchronous callable
            The ASGI callable used to send events.
        """
        self.app: "Baguette" = app
        self._scope: Scope = scope
        self._receive: Receive = receive
        self._send: Send = send
        self._message_queue: asyncio.Queue = asyncio.Queue()
        # Strong references to the tasks created by _schedule_coro: the
        # event loop only keeps weak references to running tasks.
        self._tasks: typing.Set[asyncio.Task] = set()
        self.accepted: asyncio.Event = asyncio.Event()
        self.closed: asyncio.Event = asyncio.Event()

        # Keys that the ASGI specification marks as optional are read with
        # the default value the specification gives them.
        self.asgi_version: str = scope["asgi"]["version"]
        self.headers: Headers = Headers(*(scope.get("headers") or ()))
        self.scheme: str = scope.get("scheme", "ws")
        self.root_path: str = scope.get("root_path", "")
        self.path: str = scope["path"].rstrip("/") or "/"
        self.querystring: typing.Dict[str, typing.List[str]] = parse_qs(
            (scope.get("query_string") or b"").decode("ascii")
        )
        self.server: typing.Optional[typing.Tuple[str, int]] = scope.get(
            "server"
        )
        self.client: typing.Optional[typing.Tuple[str, int]] = scope.get(
            "client"
        )
        self.subprotocols: typing.List[str] = scope.get("subprotocols", [])

    # --------------------------------------------------------------------------
    # Websocket methods

    async def connect(self) -> bool:
        """Connects to the websocket.

        Calls :meth:`on_connect`, then accepts the connection unless the
        handler raised or already closed it. When :meth:`main` is
        overridden, it is scheduled to run in a loop until the websocket
        is closed.

        .. warning:: This method must be called before :meth:`handle_messages`.

        Returns
        -------
        :class:`bool`
            Whether the websocket is connected.

        Raises
        ------
        :exc:`RuntimeError`
            Websocket is already connected
        """
        message = await self._receive()
        if message["type"] != "websocket.connect":
            raise RuntimeError("Websocket already connected.")

        # No ``finally`` with ``return`` here: it would silently swallow
        # any exception in flight, including task cancellation.
        try:
            await self.on_connect()
        except WebsocketClose as close:
            await self._close_if_open(403, str(close))
            return False
        except Exception as error:
            await self.on_error("on_connect", error)
            reason = "Error in connection."
            if self.app.debug:
                reason = (
                    "Error in connection: "
                    + str(error)
                    + "\nTraceback (most recent call last):\n"
                    + "".join(traceback.format_tb(error.__traceback__))
                )
            await self._close_if_open(403, reason)
            return False

        if self.closed.is_set():
            # on_connect closed the connection itself: accepting now would
            # send an accept event after the close event.
            return False

        await self.accept()  # no-op if on_connect already accepted

        # The default main() does nothing, so looping over it would only
        # keep the event loop busy; schedule the loop only when overridden.
        if getattr(self.main, "__func__", None) is not Websocket.main:

            async def main_loop():
                while not self.closed.is_set():
                    await self.main()
                    await asyncio.sleep(0)  # let other tasks run

            self._schedule_coro(main_loop, "main")

        return True

    async def accept(
        self,
        headers: typing.Optional[HeadersType] = None,
        subprotocol: typing.Optional[str] = None,
    ):
        """Accepts the websocket connection.

        If you want to accept with your own headers or subprotocol, call this in
        :meth:`on_connect`. If you don't, it will be called in :meth:`connect`
        if :meth:`on_connect` doesn't error. Calling it again once the
        connection is accepted does nothing.

        Parameters
        ----------
        headers : :class:`list` of ``(str, str)`` tuples, \
        :class:`dict` or :class:`Headers`
            The headers to include in the accept message.
            Default: No headers.
        subprotocol : Optional :class:`str`
            The subprotocol to use in the websocket.
            Default: ``None``
        """
        if self.accepted.is_set():
            return

        accept_headers: Headers = make_headers(headers)
        await self._send(
            {
                "type": "websocket.accept",
                "headers": accept_headers.raw(),
                "subprotocol": subprotocol,
            }
        )
        self.accepted.set()

    async def receive(self) -> str:
        """Receives a message from the websocket.

        Waits for the next message queued by :meth:`handle_messages`.

        .. note:: You don't need to call this method in :meth:`on_message`.

        Returns
        -------
        :class:`str`
            The received message
        """
        return to_str(await self._message_queue.get())

    async def send(self, message: StrOrBytes):
        """Sends a message to the websocket.

        Parameters
        ----------
        message : :class:`str` or :class:`bytes`
            The message to send to the websocket.

        Raises
        ------
        TypeError
            The message isn't of type :class:`str` or :class:`bytes`.
        """
        if isinstance(message, str):
            event = {"text": message}
        elif isinstance(message, bytes):
            event = {"bytes": message}
        else:
            raise TypeError(
                "message must be of type str or bytes. Got: "
                + message.__class__.__name__
            )

        event["type"] = "websocket.send"
        await self._send(event)

    async def send_json(self, data):
        """Serializes ``data`` to JSON and sends it to the websocket.

        Parameters
        ----------
        data : Anything JSON serializable
            The data to send to the websocket.
        """
        await self.send(dumps(data))

    async def close(self, code: int = 1000, reason: str = ""):
        """Closes the websocket connection.

        Schedules :meth:`on_close` once the close event is sent.

        Parameters
        ----------
        code : Optional :class:`int`
            The status code to close the websocket connection with.
            Default: ``1000``.
        reason : Optional :class:`str`
            The reason to close the websocket connection with.
            Default: ``""``.

        Raises
        ------
        :exc:`RuntimeError`
            The connection is already closed.
        """
        if self.closed.is_set():
            raise RuntimeError("Websocket already closed.")

        await self._send(
            {"type": "websocket.close", "code": code, "reason": reason}
        )
        self.closed.set()
        self.dispatch("close", code, reason)

    async def handle_messages(self):
        """Handles the received messages, calls :meth:`on_message` and puts them
        in queue.

        Runs until the websocket is closed, either by the server or by a
        disconnect event, in which case :meth:`on_disconnect` is scheduled.

        Raises
        ------
        :exc:`RuntimeError`
            The websocket isn't connected.
        """
        while not self.closed.is_set():
            if not self.accepted.is_set():
                raise RuntimeError("Websocket not connected.")

            message = await self._receive()
            message_type = message["type"]

            if message_type == "websocket.receive":
                # Compare against None instead of relying on truthiness:
                # an empty frame (b"" or "") is still a valid message.
                payload = message.get("bytes")
                if payload is None:
                    payload = message.get("text")
                await self._message_queue.put(payload)
                self.dispatch("message", payload)

            elif message_type == "websocket.disconnect":
                self.closed.set()
                # 1005 ("no status received") is the code the ASGI
                # specification gives for a disconnect without a code.
                self.dispatch("disconnect", message.get("code", 1005))

    async def _close_if_open(self, code: int, reason: str):
        """Closes the connection unless it is already closed."""
        if not self.closed.is_set():
            await self.close(code, reason)

    async def _run_coro(
        self,
        coro: typing.Callable[..., typing.Awaitable],
        event_name: str,
        *args,
        **kwargs
    ):
        """Runs a coroutine function and catches errors.

        A :exc:`WebsocketClose` closes the connection with its close code,
        any other exception is forwarded to :meth:`on_error`.
        """
        try:
            await coro(*args, **kwargs)
        except WebsocketClose as close:
            # The handler may run after the connection was closed, and
            # close() raises when called twice.
            await self._close_if_open(close.close_code, str(close))
        except Exception as error:
            await self.on_error(event_name, error, *args, **kwargs)

    def _schedule_coro(
        self,
        coro: typing.Callable[..., typing.Awaitable],
        event_name: str,
        *args,
        **kwargs
    ) -> asyncio.Task:
        """Schedules a coroutine in an :class:`asyncio task <asyncio.Task>`.

        The task is kept in ``self._tasks`` until it is done so that it
        can't be garbage collected while it is still running.
        """
        task = asyncio.create_task(
            self._run_coro(coro, event_name, *args, **kwargs)
        )
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    def dispatch(self, event: str, *args, **kwargs):
        """Dispatches an event to the correct handler.

        The handler is the ``on_<event>`` method, it is scheduled in its
        own task. Events without a handler are ignored.
        """
        method = "on_" + event
        handler = getattr(self, method, None)
        if handler is not None:
            self._schedule_coro(handler, method, *args, **kwargs)

    # --------------------------------------------------------------------------
    # Websocket events

    async def main(self):
        """Runs in a loop until the websocket is closed.

        Does nothing by default, the loop is only started by
        :meth:`connect` when this method is overridden.
        """

    async def on_connect(self):
        """Called on websocket connection.

        If this function raises an exception, the websocket connection wont be
        accepted.
        """

    async def on_message(self, message: str):
        """Called on every websocket message.

        Parameters
        ----------
        message : :class:`str`
            The websocket message
        """

    async def on_disconnect(self, code: int):
        """Called on websocket disconnection.

        If the server closed the connection, this is called before
        :meth:`on_close`.

        Parameters
        ----------
        code : :class:`int`
            The websocket close status code.
        """

    async def on_close(self, code: int, reason: str):
        """Called when the websocket is closed by the server.

        Parameters
        ----------
        code : :class:`int`
            The websocket close status code.
        reason : :class:`str`
            The reason why the websocket is closed.
        """

    async def on_error(
        self, event_name: str, error: Exception, *args, **kwargs
    ):
        """Called when an event errors.

        Closes the connection with a server error, unless the error comes
        from ``on_connect``, ``on_disconnect`` or ``on_close`` or the
        connection is already closed, then prints the traceback.

        Parameters
        ----------
        event_name : :class:`str`
            The event that raised the error.
        error : :exc:`Exception`
            The error that was raised.
        *args : :class:`tuple`
            The arguments of the event.
        **kwargs : :class:`dict`
            The keyword arguments of the event.
        """
        closing_events = ("on_connect", "on_disconnect", "on_close")
        # close() raises when the connection is already closed, which
        # would hide the original error: only close an open connection.
        if event_name not in closing_events and not self.closed.is_set():
            close = CloseServerError(
                description=(
                    (f"Exception in {event_name}: " + str(error))
                    if self.app.debug
                    else "Internal server error while operating."
                )
            )
            await self.close(close.close_code, str(close))

        # Print the error that was passed in rather than whatever
        # exception happens to be handled by the caller.
        traceback.print_exception(type(error), error, error.__traceback__)