import asyncio
from collections.abc import Coroutine
import logging
from logging import Logger
import os
import platform
from typing import TYPE_CHECKING, Any, cast

from osmium_protos import (
    PB_Authorization,
    PB_Authorize,
    PB_FileMetadata,
    PB_FileMetadataMetadataCustomEmoji,
    PB_FileMetadataMetadataFile,
    PB_FileMetadataMetadataImage,
    PB_Initialize,
    PB_MediaRef,
    PB_MediaRefUploadedFile,
    PB_RpcResult,
    PB_ServerMessage,
    PB_StickerPackRef,
    PB_UpdateMessageCreated,
    PB_UploadFilePart,
    PB_UploadedFileRef,
    unwrap,
    wrap,
)
from osmium_protos.osmium.client.auth import Authorization
from websockets.asyncio.client import ClientConnection, connect
from websockets.exceptions import ConnectionClosed

from osmium_chat import __version__
from osmium_chat.errors import RequestError
from osmium_chat.user.user import User

if TYPE_CHECKING:
    from osmium_chat.bot import Bot
[docs]
class Client:
    """Low-level WebSocket transport between a :class:`~osmium_chat.bot.Bot`
    and the Osmium gateway.

    Handles connecting, the initialize/authorize handshake, sending protobuf
    messages, and reading the inbound message stream.
    """

    __slots__: tuple[str, ...] = (
        "bot",
        "id",
        "_connection",
        "__session_id",
        "__token",
        "_logger",
        "_pending",
        "_req_counter",
        "_tasks",
    )

    WS_URL: str = "wss://ws-0.osmium.chat"

    # 512 KiB per part — small enough for reliable delivery over WebSocket.
    _UPLOAD_CHUNK_SIZE: int = 512 * 1024

    def __init__(
        self,
        client_id: int,
        bot: "Bot",
        *,
        logger: Logger | None = None,
    ) -> None:
        """Create a client bound to ``bot``.

        :param client_id: The Osmium client id to identify as.
        :param bot: The owning bot, used to dispatch events and store the user.
        :param logger: Optional logger; defaults to
            ``logging.getLogger(__name__)``.
        """
        self.bot: "Bot" = bot
        self.id: int = client_id
        self._connection: ClientConnection | None = None
        self.__session_id: int | None = None
        self.__token: str | None = None
        # Obtain the default through ``getLogger`` so it joins the logging
        # hierarchy. A directly constructed ``Logger`` has no parent, so its
        # records never reach handlers the application configures on the root.
        self._logger: Logger = logger if logger is not None else logging.getLogger(__name__)
        # Outstanding requests awaiting an ``RpcResult``, keyed by request id.
        self._pending: dict[int, asyncio.Future[PB_RpcResult]] = {}
        # Monotonic source of request ids. The first id handed out is 1, so it
        # never collides with the ``id=0`` used by fire-and-forget
        # :meth:`send_pb` calls.
        self._req_counter: int = 0
        # Strong references to in-flight dispatch tasks so they aren't GC'd.
        self._tasks: set[asyncio.Task[None]] = set()

    def _require_connection(self) -> ClientConnection:
        """Return the current connection.

        :raises ConnectionError: If :meth:`connect` has not opened one yet.
        """
        # An explicit check instead of ``assert``: asserts vanish under ``-O``,
        # which would turn this into an opaque ``AttributeError`` on ``None``.
        if self._connection is None:
            raise ConnectionError("Client is not connected")
        return self._connection

    async def _handle_msg(self, message: Any) -> None:
        """Process a single decoded inbound message.

        Routes ``message_created`` updates into the bot's command pipeline; all
        other messages are only logged for now.

        :param message: The unwrapped protobuf message.
        """
        # Lazy %-formatting: the message is only rendered if DEBUG is enabled.
        self._logger.debug("Received message: %s", message)
        if isinstance(message, PB_UpdateMessageCreated):
            await self.bot.process_commands(message)

    async def _handle_ws(self, **kwargs: Any) -> None:
        """Read and dispatch inbound messages until the connection closes.

        Responses to outstanding :meth:`request` calls are resolved inline so the
        correlation stays ordered, while everything else is dispatched on its own
        task. Dispatching concurrently is what lets a command await a
        :meth:`request` response: the read loop keeps draining frames (including
        that very response) instead of blocking on the handler.

        However the loop ends, every still-pending :meth:`request` is failed
        with :class:`ConnectionError` instead of being left to time out.

        :param kwargs: Accepted for backward compatibility; unused.
        :raises ConnectionError: If the connection closes abnormally.
        """
        connection = self._require_connection()
        try:
            # Closure is reported by the iterator itself. A clean close just
            # ends the loop; an abnormal one raises ``ConnectionClosedError``
            # here, not inside the per-frame parsing below.
            async for data in connection:
                frame = cast(bytes, data)
                try:
                    server = PB_ServerMessage.parse(frame)
                except Exception:
                    self._logger.exception("Failed to parse inbound frame")
                    continue
                if not self._resolve_result(server):
                    self._spawn(self._dispatch_frame(frame))
        except ConnectionClosed as e:
            self._logger.error("WebSocket connection closed: %s", e)
            raise ConnectionError("Connection closed unexpectedly") from e
        finally:
            self._fail_pending("Connection closed")

    def _fail_pending(self, reason: str) -> None:
        """Fail every unresolved :meth:`request` future with :class:`ConnectionError`.

        Entries are not removed here; each :meth:`request` call pops its own id
        in its ``finally`` block.

        :param reason: Message for the :class:`ConnectionError` set on each
            future (a fresh instance per future).
        """
        for future in list(self._pending.values()):
            if not future.done():
                future.set_exception(ConnectionError(reason))

    def _resolve_result(self, server: PB_ServerMessage) -> bool:
        """Hand a server frame to the request that's waiting on it, if any.

        :param server: The parsed top-level ``ServerMessage``.
        :returns: ``True`` if the frame was a result that matched a pending
            request (and was consumed here), ``False`` if it should fall through
            to normal dispatch.
        """
        result = server.result
        if result is None:
            return False
        future = self._pending.get(result.req_id)
        # A done future was already resolved, or cancelled by a timeout.
        if future is None or future.done():
            return False
        if result.error is not None:
            future.set_exception(RequestError(result.error.error_code, result.error.error_message))
        else:
            future.set_result(result)
        return True

    async def _dispatch_frame(self, data: bytes) -> None:
        """Decode a non-result frame to its leaf payload and handle it.

        Runs as a background task, so failures are logged rather than raised.

        :param data: The raw frame bytes.
        """
        try:
            _, message = unwrap(data)
            await self._handle_msg(message)
        except Exception:
            self._logger.exception("Error dispatching inbound frame")

    def _spawn(self, coro: Coroutine[Any, Any, None]) -> None:
        """Schedule ``coro`` as a tracked background task.

        :param coro: The coroutine to run. A strong reference to its task is kept
            until it completes so it isn't garbage collected mid-flight.
        """
        task = asyncio.ensure_future(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def send_pb(self, message: Any) -> None:
        """Wrap, serialize, and send a protobuf message over the connection.

        This is fire-and-forget: it sends with request id ``0`` and does not wait
        for a reply. Use :meth:`request` when you need the server's response.

        :param message: The protobuf message to send.
        :raises ConnectionError: If the client is not connected.
        """
        connection = self._require_connection()
        await connection.send(wrap(message).SerializeToString())

    async def request(self, payload: Any, *, timeout: float = 60.0) -> PB_RpcResult:
        """Send a request and wait for the gateway's matching ``RpcResult``.

        Tags the outbound frame with a unique request id, registers a future for
        it, and resolves once the server replies with a result carrying the same
        id. This requires the read loop (:meth:`_handle_ws`) to be running, which
        it is for the whole lifetime of a connected bot.

        :param payload: The request protobuf to send.
        :param timeout: How long to wait for the response, in seconds.
        :returns: The :class:`~osmium_protos.PB_RpcResult` for this request.
        :raises RequestError: If the gateway answers with an error.
        :raises asyncio.TimeoutError: If no response arrives within ``timeout``.
            On Python 3.11+ this is the builtin :class:`TimeoutError`.
        :raises ConnectionError: If the client is not connected, or the read
            loop stops before the response arrives.
        """
        connection = self._require_connection()
        self._req_counter += 1
        req_id = self._req_counter
        future: asyncio.Future[PB_RpcResult] = asyncio.get_running_loop().create_future()
        self._pending[req_id] = future
        try:
            await connection.send(wrap(payload, id=req_id).SerializeToString())
            return await asyncio.wait_for(future, timeout)
        finally:
            self._pending.pop(req_id, None)

    def _handle_authorization(self, message: Authorization) -> None:
        """Store the session id/token and the authenticated user from an
        authorization response.

        :param message: The authorization payload from the gateway.
        """
        self.__session_id = message.session_id
        self.__token = message.token
        if message.user:
            self.bot.user = User(message.user, self)

    async def _upload_parts(self, data: bytes) -> tuple[int, int]:
        """Upload ``data`` sequentially in parts under a fresh random upload id.

        Parts are up to :attr:`_UPLOAD_CHUNK_SIZE` bytes and are sliced one at a
        time rather than all up front. At least one part is always sent, so
        empty ``data`` becomes a single empty part.

        :param data: The raw bytes to upload.
        :returns: ``(upload_id, part_count)``.
        :raises RequestError: If the gateway rejects any part.
        """
        # NOTE(review): this draws an unsigned 64-bit id. Confirm the proto field
        # is uint64/fixed64; a signed int64 cannot hold values >= 2**63.
        upload_id = int.from_bytes(os.urandom(8), "big")
        chunk_size = self._UPLOAD_CHUNK_SIZE
        offsets = range(0, max(len(data), 1), chunk_size)
        for index, offset in enumerate(offsets):
            await self.request(PB_UploadFilePart(
                upload_id=upload_id,
                part=index,
                data=data[offset:offset + chunk_size],
            ))
        return upload_id, len(offsets)

    async def upload_file(
        self,
        data: bytes,
        filename: str,
        mimetype: str = "application/octet-stream",
        send_as_file: bool = False,
    ) -> tuple[PB_UploadedFileRef, PB_MediaRef]:
        """Upload ``data`` to the Osmium media service in chunks.

        Splits ``data`` into parts of up to :attr:`_UPLOAD_CHUNK_SIZE` bytes,
        sends each part sequentially, and returns the server-side ref together
        with a ready-to-use :class:`~osmium_protos.PB_MediaRef` for attaching
        to a :class:`~osmium_protos.PB_SendMessage`.

        :param data: The raw file bytes to upload.
        :param filename: The file name that will be shown to recipients.
        :param mimetype: The MIME type of the file; defaults to
            ``application/octet-stream``.
        :param send_as_file: If true, attach file metadata so the upload is sent
            as a downloadable file rather than embedded. Otherwise, attach
            image metadata.
        :returns: A ``(PB_UploadedFileRef, PB_MediaRef)`` pair.
        :raises RequestError: If the gateway rejects any part.
        """
        upload_id, part_count = await self._upload_parts(data)
        file_ref = PB_UploadedFileRef(
            id=upload_id,
            name=filename,
            part_count=part_count,
        )
        # send_as_file=True sends a downloadable file (no embedding), mimicking
        # the official client. Otherwise image metadata is attached, and the
        # server infers the correct metadata from the content; even a text file
        # sent with image metadata ends up as a text file.
        if send_as_file:
            metadata = PB_FileMetadata(file=PB_FileMetadataMetadataFile())
        else:
            metadata = PB_FileMetadata(image=PB_FileMetadataMetadataImage())
        media_ref = PB_MediaRef(uploaded=PB_MediaRefUploadedFile(
            file=file_ref,
            filename=filename,
            mimetype=mimetype,
            metadata=metadata,
        ))
        return file_ref, media_ref

    async def upload_emoji_image(
        self,
        data: bytes,
        name: str,
        mimetype: str,
        community_id: int,
    ) -> PB_MediaRefUploadedFile:
        """Upload emoji image bytes and return a ref ready for :class:`~osmium_protos.PB_AddStickerToPack`.

        Splits ``data`` into chunks, uploads them, and wraps the result in a
        :class:`~osmium_protos.PB_MediaRefUploadedFile` carrying
        custom-emoji metadata pointing at ``community_id``'s sticker pack.

        :param data: Raw image bytes (PNG or WebP recommended).
        :param name: The emoji short name stored in the file metadata.
        :param mimetype: The MIME type of the image.
        :param community_id: The community whose emoji pack this image will
            be added to; used to populate the ``pack`` metadata field.
        :returns: A :class:`~osmium_protos.PB_MediaRefUploadedFile` ready to
            pass as the ``sticker`` argument of
            :class:`~osmium_protos.PB_AddStickerToPack`.
        :raises RequestError: If the gateway rejects any part.
        """
        upload_id, part_count = await self._upload_parts(data)
        file_ref = PB_UploadedFileRef(
            id=upload_id,
            name=name,
            part_count=part_count,
        )
        return PB_MediaRefUploadedFile(
            file=file_ref,
            filename=name,
            mimetype=mimetype,
            metadata=PB_FileMetadata(
                custom_emoji=PB_FileMetadataMetadataCustomEmoji(
                    emoji=name,
                    pack=PB_StickerPackRef(id=community_id),
                )
            ),
        )

    async def _await_authorization(self, connection: ClientConnection) -> bool:
        """Read frames until the ``Authorization`` reply arrives and store it.

        It is usually the first or second message. Anything received before it
        is passed to :meth:`_handle_msg`.

        :param connection: The connection the handshake is running on.
        :returns: ``True`` once authorization was received and stored, or
            ``False`` if the iterator ended (clean close) first.
        :raises ConnectionClosed: If the connection closes abnormally.
        """
        async for data in connection:
            _, message = unwrap(cast(bytes, data))
            if isinstance(message, PB_Authorization):
                self._handle_authorization(message)
                return True
            # NOTE(review): these are handled inline, and results are only routed
            # through _resolve_result once _handle_ws runs. A handler awaiting
            # request() here would therefore stall until its timeout.
            await self._handle_msg(message)
        return False

    async def connect(self, token: str) -> None:
        """Open the connection, run the handshake, and process messages.

        Performs the initialize/authorize exchange, dispatches the bot's
        ``connect`` event once authorized, then blocks reading messages until
        the connection closes.

        :param token: The authorization token for the bot.
        :raises ConnectionError: If the connection closes before authorization
            (cleanly or not), or closes abnormally afterwards.
        """
        self._logger.info("Connecting to WebSocket server...")
        self._connection = await connect(
            uri=self.WS_URL,
        )
        connection = self._connection
        self._logger.info("Connected to WebSocket server, initializing...")
        closed_reason = "Connection closed while waiting for authorization"
        try:
            await self.send_pb(PB_Initialize(
                client_id=self.id,
                device_type="Library[Python/OsmiumChat]",
                device_version=__version__,
                app_version=f"OsmiumChat Python API Wrapper (Python {platform.python_version()}) (OsmiumChat {__version__})",
                no_subscribe=False,
            ))
            self._logger.info("Initialization sent, waiting for entry points...")
            # The reply carries entry points and vapidPublicKey; unused for now.
            await connection.recv()
            self._logger.info("Received initialization response, sending authorization...")
            await self.send_pb(PB_Authorize(
                token=token,
            ))
            self._logger.info("Waiting for authorization...")
            authorized = await self._await_authorization(connection)
        except ConnectionClosed as e:
            self._logger.error(closed_reason)
            raise ConnectionError(closed_reason) from e
        if not authorized:
            # A clean close ends the frame iterator without raising; without
            # this check the bot would carry on as if it were authorized.
            self._logger.error(closed_reason)
            raise ConnectionError(closed_reason)
        self._logger.info("Authorized successfully, dispatching connect event...")
        await self.bot.dispatch("connect")
        self._logger.info("Starting message handler...")
        await self._handle_ws()