bring type completeness to 100%

This commit is contained in:
Zomatree
2023-05-20 03:04:52 +01:00
parent dfb45494ba
commit 60ff8d81c9
33 changed files with 398 additions and 332 deletions
+8 -3
View File
@@ -2,15 +2,20 @@ on: [push, pull_request]
name: pyright
jobs:
pyright:
strategy:
matrix:
version: ["3.9", "3.10", "3.11"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.9'
python-version: ${{ matrix.version }}
cache: pip
- run: pip install .[speedups,docs]
- uses: jakebailey/pyright-action@v1
with:
lib: true
python-version: 3.9
python-version: ${{ matrix.version }}
working-directory: revolt
verify-types: true
+2 -2
View File
@@ -11,10 +11,10 @@ upload:
python -m twine upload dist/*
lint:
pyright --lib
pyright .
coverage:
pyright --lib --ignoreexternal --verifytypes revolt
pyright --ignoreexternal --verifytypes revolt
docs:
cd docs && make html
+20 -21
View File
@@ -42,14 +42,16 @@ class Asset(Ulid):
__slots__ = ("state", "id", "tag", "size", "filename", "content_type", "width", "height", "type", "url")
def __init__(self, data: FilePayload, state: State):
self.state = state
self.state: State = state
self.id = data['_id']
self.tag = data['tag']
self.size = data['size']
self.filename = data['filename']
self.id: str = data['_id']
self.tag: str = data['tag']
self.size: int = data['size']
self.filename: str = data['filename']
metadata = data['metadata']
self.height: int | None
self.width: int | None
if metadata["type"] == "Image" or metadata["type"] == "Video": # cannot use `in` because type narrowing will not happen
self.height = metadata["height"]
@@ -58,17 +60,17 @@ class Asset(Ulid):
self.height = None
self.width = None
self.content_type = data["content_type"]
self.type = AssetType(metadata["type"])
self.content_type: str | None = data["content_type"]
self.type: AssetType = AssetType(metadata["type"])
base_url = self.state.api_info["features"]["autumn"]["url"]
self.url = f"{base_url}/{self.tag}/{self.id}"
self.url: str = f"{base_url}/{self.tag}/{self.id}"
async def read(self) -> bytes:
"""Reads the files content into bytes"""
return await self.state.http.request_file(self.url)
async def save(self, fp: IOBase):
async def save(self, fp: IOBase) -> None:
"""Reads the files content and saves it to a file
Parameters
@@ -85,8 +87,6 @@ class PartialAsset(Asset):
-----------
id: :class:`str`
The id of the asset, this will always be ``"0"``
tag: Optional[:class:`str`]
The tag of the asset, this corrasponds to where the asset is used, this will always be ``None``
size: :class:`int`
Amount of bytes in the file, this will always be ``0``
filename: :class:`str`
@@ -102,13 +102,12 @@ class PartialAsset(Asset):
"""
def __init__(self, url: str, state: State):
self.state = state
self.id = "0"
self.tag = None
self.size = 0
self.filename = ""
self.height = None
self.width = None
self.content_type = mimetypes.guess_extension(url)
self.type = AssetType.file
self.url = url
self.state: State = state
self.id: str = "0"
self.size: int = 0
self.filename: str = ""
self.height: int | None = None
self.width: int | None = None
self.content_type: str | None = mimetypes.guess_extension(url)
self.type: AssetType = AssetType.file
self.url: str = url
+4 -4
View File
@@ -25,10 +25,10 @@ class Category(Ulid):
"""
def __init__(self, data: CategoryPayload, state: State):
self.state = state
self.name = data["title"]
self.id = data["id"]
self.channel_ids = data["channels"]
self.state: State = state
self.name: str = data["title"]
self.id: str = data["id"]
self.channel_ids: list[str] = data["channels"]
@property
def channels(self) -> list[Channel]:
+26 -21
View File
@@ -31,7 +31,7 @@ class EditableChannel:
state: State
id: str
async def edit(self, **kwargs: Any):
async def edit(self, **kwargs: Any) -> None:
"""Edits the channel
Passing ``None`` to the parameters that accept it will remove them.
@@ -80,18 +80,18 @@ class Channel(Ulid):
__slots__ = ("state", "id", "channel_type", "server_id")
def __init__(self, data: ChannelPayload, state: State):
self.state = state
self.id = data["_id"]
self.channel_type = ChannelType(data["channel_type"])
self.state: State = state
self.id: str = data["_id"]
self.channel_type: ChannelType = ChannelType(data["channel_type"])
self.server_id: Optional[str] = None
async def _get_channel_id(self) -> str:
return self.id
def _update(self, **_: Any):
def _update(self, **_: Any) -> None:
pass
async def delete(self):
async def delete(self) -> None:
"""Deletes or closes the channel"""
await self.state.http.close_channel(self.id)
@@ -134,7 +134,7 @@ class DMChannel(Channel, Messageable):
def __init__(self, data: DMChannelPayload, state: State):
super().__init__(data, state)
self.recipient_ids: tuple[str, str] = tuple(data["recipients"])
self.last_message_id = data.get("last_message_id")
self.last_message_id: str | None = data.get("last_message_id")
@property
def recipients(self) -> tuple[User, User]:
@@ -190,20 +190,22 @@ class GroupDMChannel(Channel, Messageable, EditableChannel):
def __init__(self, data: GroupDMChannelPayload, state: State):
super().__init__(data, state)
self.recipient_ids = data["recipients"]
self.name = data["name"]
self.owner_id = data["owner"]
self.description: Optional[str] = data.get("description")
self.last_message_id = data.get("last_message_id")
self.recipient_ids: list[str] = data["recipients"]
self.name: str = data["name"]
self.owner_id: str = data["owner"]
self.description: str | None = data.get("description")
self.last_message_id: str | None = data.get("last_message_id")
self.icon: Asset | None
if icon := data.get("icon"):
self.icon = Asset(icon, state)
else:
self.icon = None
self.permissions = Permissions(data.get("permissions", 0))
self.permissions: Permissions = Permissions(data.get("permissions", 0))
def _update(self, *, name: Optional[str] = None, recipients: Optional[list[str]] = None, description: Optional[str] = None):
def _update(self, *, name: Optional[str] = None, recipients: Optional[list[str]] = None, description: Optional[str] = None) -> None:
if name is not None:
self.name = name
@@ -263,12 +265,12 @@ class ServerChannel(Channel):
def __init__(self, data: ServerChannelPayload, state: State):
super().__init__(data, state)
self.server_id = data["server"]
self.name = data["name"]
self.server_id: str = data["server"]
self.name: str = data["name"]
self.description: Optional[str] = data.get("description")
self.nsfw = data.get("nsfw", False)
self.active = False
self.default_permissions = PermissionsOverwrite._from_overwrite(data.get("default_permissions", {"a": 0, "d": 0}))
self.nsfw: bool = data.get("nsfw", False)
self.active: bool = False
self.default_permissions: PermissionsOverwrite = PermissionsOverwrite._from_overwrite(data.get("default_permissions", {"a": 0, "d": 0}))
permissions: dict[str, PermissionsOverwrite] = {}
@@ -276,7 +278,10 @@ class ServerChannel(Channel):
overwrite = PermissionsOverwrite._from_overwrite(overwrite_data)
permissions[role_name] = overwrite
self.permissions = permissions
self.permissions: dict[str, PermissionsOverwrite] = permissions
self.icon: Asset | None
if icon := data.get("icon"):
self.icon = Asset(icon, state)
else:
@@ -359,7 +364,7 @@ class TextChannel(ServerChannel, Messageable, EditableChannel):
def __init__(self, data: TextChannelPayload, state: State):
super().__init__(data, state)
self.last_message_id = data.get("last_message_id")
self.last_message_id: str | None = data.get("last_message_id")
async def _get_channel_id(self) -> str:
return self.id
+12 -12
View File
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
__all__ = ("Client",)
logger = logging.getLogger("revolt")
logger: logging.Logger = logging.getLogger("revolt")
class Client:
"""The client for interacting with revolt
@@ -48,11 +48,11 @@ class Client:
"""
def __init__(self, session: aiohttp.ClientSession, token: str, *, api_url: str = "https://api.revolt.chat", max_messages: int = 5000, bot: bool = True):
self.session = session
self.token = token
self.api_url = api_url
self.max_messages = max_messages
self.bot = bot
self.session: aiohttp.ClientSession = session
self.token: str = token
self.api_url: str = api_url
self.max_messages: int = max_messages
self.bot: bool = bot
self.api_info: ApiInfo
self.http: HttpClient
@@ -63,7 +63,7 @@ class Client:
super().__init__()
def dispatch(self, event: str, *args: Any):
def dispatch(self, event: str, *args: Any) -> None:
"""Dispatch an event, this is typically used for testing and internals.
Parameters
@@ -88,7 +88,7 @@ class Client:
async with self.session.get(self.api_url) as resp:
return json.loads(await resp.text())
async def start(self):
async def start(self) -> None:
"""Starts the client"""
api_info = await self.get_api_info()
@@ -98,7 +98,7 @@ class Client:
self.websocket = WebsocketHandler(self.session, self.token, api_info["ws"], self.dispatch, self.state)
await self.websocket.start()
async def stop(self):
async def stop(self) -> None:
await self.websocket.websocket.close()
def get_user(self, id: str) -> User:
@@ -300,7 +300,7 @@ class Client:
raise LookupError
async def edit_self(self, **kwargs: Any):
async def edit_self(self, **kwargs: Any) -> None:
"""Edits the client's own user
Parameters
@@ -316,7 +316,7 @@ class Client:
await self.state.http.edit_self(remove, kwargs)
async def edit_status(self, **kwargs: Any):
async def edit_status(self, **kwargs: Any) -> None:
"""Edits the client's own status
Parameters
@@ -337,7 +337,7 @@ class Client:
await self.state.http.edit_self(remove, {"status": kwargs})
async def edit_profile(self, **kwargs: Any):
async def edit_profile(self, **kwargs: Any) -> None:
"""Edits the client's own profile
Parameters
+26 -21
View File
@@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Optional, TypedDict, Union
from typing_extensions import NotRequired, Unpack
from revolt.types.embed import WebsiteSpecial
from .asset import Asset
from .enums import EmbedType
@@ -14,6 +16,7 @@ if TYPE_CHECKING:
from .types import SendableEmbed as SendableEmbedPayload
from .types import TextEmbed as TextEmbedPayload
from .types import WebsiteEmbed as WebsiteEmbedPayload
from .types import JanuaryImage, JanuaryVideo
__all__ = ("Embed", "WebsiteEmbed", "ImageEmbed", "TextEmbed", "NoneEmbed", "to_embed", "SendableEmbed")
@@ -21,43 +24,45 @@ class WebsiteEmbed:
type = EmbedType.website
def __init__(self, embed: WebsiteEmbedPayload):
self.url = embed.get("url")
self.special = embed.get("special")
self.title = embed.get("title")
self.description = embed.get("description")
self.image = embed.get("image")
self.video = embed.get("video")
self.site_name = embed.get("site_name")
self.icon_url = embed.get("icon_url")
self.colour = embed.get("colour")
self.url: str | None = embed.get("url")
self.special: WebsiteSpecial | None = embed.get("special")
self.title: str | None = embed.get("title")
self.description: str | None = embed.get("description")
self.image: JanuaryImage | None = embed.get("image")
self.video: JanuaryVideo | None = embed.get("video")
self.site_name: str | None = embed.get("site_name")
self.icon_url: str | None = embed.get("icon_url")
self.colour: str | None = embed.get("colour")
class ImageEmbed:
type = EmbedType.image
type: EmbedType = EmbedType.image
def __init__(self, image: ImageEmbedPayload):
self.url = image.get("url")
self.width = image.get("width")
self.height = image.get("height")
self.size = image.get("size")
self.url: str = image.get("url")
self.width: int = image.get("width")
self.height: int = image.get("height")
self.size: str = image.get("size")
class TextEmbed:
type = EmbedType.text
type: EmbedType = EmbedType.text
def __init__(self, embed: TextEmbedPayload, state: State):
self.icon_url = embed.get("icon_url")
self.url = embed.get("url")
self.title = embed.get("title")
self.description = embed.get("description")
self.icon_url: str | None = embed.get("icon_url")
self.url: str | None = embed.get("url")
self.title: str | None = embed.get("title")
self.description: str | None = embed.get("description")
self.media: Asset | None
if media := embed.get("media"):
self.media = Asset(media, state)
else:
self.media = None
self.colour = embed.get("colour")
self.colour: str | None = embed.get("colour")
class NoneEmbed:
type = EmbedType.none
type: EmbedType = EmbedType.none
Embed = Union[WebsiteEmbed, ImageEmbed, TextEmbed, NoneEmbed]
+9 -9
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from .utils import Ulid
@@ -30,16 +30,16 @@ class Emoji(Ulid):
The server id this emoji belongs to, if any
"""
def __init__(self, payload: EmojiPayload, state: State):
self.state = state
self.state: State = state
self.id = payload["_id"]
self.author_id = payload["creator_id"]
self.name = payload["name"]
self.animated = payload.get("animated", False)
self.nsfw = payload.get("nsfw", False)
self.server_id: Optional[str] = payload["parent"].get("id")
self.id: str = payload["_id"]
self.author_id: str = payload["creator_id"]
self.name: str = payload["name"]
self.animated: bool = payload.get("animated", False)
self.nsfw: bool = payload.get("nsfw", False)
self.server_id: str | None = payload["parent"].get("id")
async def delete(self):
async def delete(self) -> None:
"""Deletes the emoji."""
await self.state.http.delete_emoji(self.id)
+15 -10
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Any, Callable, Coroutine, TypeVar, Union, cast
from typing import Any, Callable, Coroutine, Union, cast
from typing_extensions import TypeVar
import revolt
@@ -12,11 +13,11 @@ from .utils import ClientT
__all__ = ("check", "Check", "is_bot_owner", "is_server_owner", "has_permissions", "has_channel_permissions")
T = TypeVar("T", Callable[..., Any], Command)
T = TypeVar("T", Callable[..., Any], Command, default=Command)
Check = Callable[[Context[ClientT]], Union[Any, Coroutine[Any, Any, Any]]]
def check(check: Check[ClientT]):
def check(check: Check[ClientT]) -> Callable[[T], T]:
"""A decorator for adding command checks
Parameters
@@ -37,7 +38,7 @@ def check(check: Check[ClientT]):
return inner
def is_bot_owner():
def is_bot_owner() -> Callable[[T], T]:
"""A command check for limiting the command to only the bot's owner"""
@check
def inner(context: Context[ClientT]):
@@ -48,10 +49,10 @@ def is_bot_owner():
return inner
def is_server_owner():
def is_server_owner() -> Callable[[T], T]:
"""A command check for limiting the command to only a server's owner"""
@check
def inner(context: Context[ClientT]):
def inner(context: Context[ClientT]) -> bool:
if not context.server_id:
raise ServerOnly
@@ -62,19 +63,21 @@ def is_server_owner():
return inner
def has_permissions(**permissions: bool):
def has_permissions(**permissions: bool) -> Callable[[T], T]:
@check
def inner(context: Context[ClientT]):
def inner(context: Context[ClientT]) -> bool:
author = context.author
if not author.has_permissions(**permissions):
raise MissingPermissionsError(permissions)
return True
return inner
def has_channel_permissions(**permissions: bool):
def has_channel_permissions(**permissions: bool) -> Callable[[T], T]:
@check
def inner(context: Context[ClientT]):
def inner(context: Context[ClientT]) -> bool:
author = context.author
if not isinstance(author, revolt.Member):
@@ -83,4 +86,6 @@ def has_channel_permissions(**permissions: bool):
if not author.has_channel_permissions(context.channel, **permissions):
raise MissingPermissionsError(permissions)
return True
return inner
+7 -10
View File
@@ -36,7 +36,7 @@ class ExtensionProtocol(Protocol):
class CommandsMeta(type):
_commands: list[Command[Any]]
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]):
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Self:
commands: list[Command[Any]] = []
self = super().__new__(cls, name, bases, attrs)
for base in reversed(self.__mro__):
@@ -139,7 +139,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
"""
return self.all_commands[name]
def add_command(self, command: Command[Self]):
def add_command(self, command: Command[Self]) -> None:
"""Adds a command, this is typically only used for dynamic commands, you should use the `commands.command` decorator for most usecases.
Parameters
@@ -196,9 +196,6 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
"""
content = message.content
if not isinstance(content, str):
return
prefixes = await self.get_prefix(message)
if isinstance(prefixes, str):
@@ -248,7 +245,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
await command._error_handler(command.cog or self, context, e)
self.dispatch("command_error", context, e)
async def on_command_error(self, ctx: Context[Self], error: Exception, /):
async def on_command_error(self, ctx: Context[Self], error: Exception, /) -> None:
traceback.print_exception(type(error), error, error.__traceback__)
on_message = process_commands
@@ -268,7 +265,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
return True
def add_cog(self, cog: Cog[Self]):
def add_cog(self, cog: Cog[Self]) -> None:
"""Adds a cog to the bot, this cog must subclass `Cog`.
Parameters
@@ -296,7 +293,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
return cog
def load_extension(self, name: str):
def load_extension(self, name: str) -> None:
"""Loads an extension, this takes a module name and runs the setup function inside of it.
Parameters
@@ -312,7 +309,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
self.extensions[name] = extension
extension.setup(self)
def unload_extension(self, name: str):
def unload_extension(self, name: str) -> None:
"""Unloads an extension, this takes a module name and runs the teardown function inside of it.
Parameters
@@ -327,7 +324,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
if teardown := getattr(extension, "teardown", None):
teardown(self)
def reload_extension(self, name: str):
def reload_extension(self, name: str) -> None:
"""Reloads an extension, this will unload and reload the extension.
Parameters
+6 -5
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Any, Generic, Optional, cast
from typing_extensions import Self
from .command import Command
from .utils import ClientT
@@ -11,7 +12,7 @@ class CogMeta(type, Generic[ClientT]):
_commands: list[Command[ClientT]]
qualified_name: str
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any], *, qualified_name: Optional[str] = None):
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any], *, qualified_name: Optional[str] = None) -> Self:
commands: list[Command[ClientT]] = []
self = super().__new__(cls, name, bases, attrs)
@@ -29,15 +30,15 @@ class Cog(Generic[ClientT], metaclass=CogMeta):
_commands: list[Command[ClientT]]
qualified_name: str
def cog_load(self):
def cog_load(self) -> None:
"""A special method that is called when the cog gets loaded."""
pass
def cog_unload(self):
def cog_unload(self) -> None:
"""A special method that is called when the cog gets removed."""
pass
def _inject(self, client: ClientT):
def _inject(self, client: ClientT) -> None:
client.cogs[self.qualified_name] = self
for command in self._commands:
@@ -46,7 +47,7 @@ class Cog(Generic[ClientT], metaclass=CogMeta):
self.cog_load()
def _uninject(self, client: ClientT):
def _uninject(self, client: ClientT) -> None:
for name, command in client.all_commands.copy().items():
if command in self._commands:
del client.all_commands[name]
+15 -14
View File
@@ -5,6 +5,7 @@ import traceback
from contextlib import suppress
from typing import (TYPE_CHECKING, Annotated, Any, Callable, Coroutine,
Generic, Literal, Optional, Union, get_args, get_origin)
from typing_extensions import ParamSpec
from revolt.utils import copy_doc, maybe_coroutine
@@ -17,13 +18,13 @@ if TYPE_CHECKING:
from .context import Context
from .group import Group
__all__ = (
__all__: tuple[str, ...] = (
"Command",
"command"
)
NoneType = type(None)
NoneType: type[None] = type(None)
P = ParamSpec("P")
class Command(Generic[ClientCoT]):
"""Class for holding info about a command.
@@ -46,17 +47,17 @@ class Command(Generic[ClientCoT]):
__slots__ = ("callback", "name", "aliases", "signature", "checks", "parent", "_error_handler", "cog", "description", "usage", "parameters")
def __init__(self, callback: Callable[..., Coroutine[Any, Any, Any]], name: str, aliases: list[str], usage: Optional[str] = None):
self.callback = callback
self.name = name
self.aliases = aliases
self.usage = usage
self.signature = inspect.signature(self.callback)
self.parameters = evaluate_parameters(self.signature.parameters.values(), getattr(callback, "__globals__", {}))
self.callback: Callable[..., Coroutine[Any, Any, Any]] = callback
self.name: str = name
self.aliases: list[str] = aliases
self.usage: str | None = usage
self.signature: inspect.Signature = inspect.signature(self.callback)
self.parameters: list[inspect.Parameter] = evaluate_parameters(self.signature.parameters.values(), getattr(callback, "__globals__", {}))
self.checks: list[Check[ClientCoT]] = getattr(callback, "_checks", [])
self.parent: Optional[Group[ClientCoT]] = None
self.cog: Optional[Cog[ClientCoT]] = None
self._error_handler: Callable[[Any, Context[ClientCoT], Exception], Coroutine[Any, Any, Any]] = type(self)._default_error_handler
self.description = callback.__doc__
self.description: str | None = callback.__doc__
async def invoke(self, context: Context[ClientCoT], *args: Any, **kwargs: Any) -> Any:
"""Runs the command and calls the error handler if the command errors.
@@ -77,7 +78,7 @@ class Command(Generic[ClientCoT]):
def __call__(self, context: Context[ClientCoT], *args: Any, **kwargs: Any) -> Any:
return self.invoke(context, *args, **kwargs)
def error(self, func: Callable[..., Coroutine[Any, Any, Any]]):
def error(self, func: Callable[..., Coroutine[Any, Any, Any]]) -> Callable[..., Coroutine[Any, Any, Any]]:
"""Sets the error handler for the command.
Parameters
@@ -141,7 +142,7 @@ class Command(Generic[ClientCoT]):
else:
return arg
async def parse_arguments(self, context: Context[ClientCoT]):
async def parse_arguments(self, context: Context[ClientCoT]) -> None:
# please pr if you can think of a better way to do this
for parameter in self.parameters[2:]:
@@ -214,8 +215,8 @@ class Command(Generic[ClientCoT]):
return f"{' '.join(parents[::-1])} {self.name} {' '.join(parameters)}"
def command(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientCoT]] = Command, usage: Optional[str] = None):
"""A decorator that turns a function into a :class:`Command`.
def command(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientCoT]] = Command, usage: Optional[str] = None) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Command[ClientCoT]]:
"""A decorator that turns a function into a :class:`Command`.n
Parameters
-----------
+11 -10
View File
@@ -11,6 +11,7 @@ from .utils import ClientCoT
if TYPE_CHECKING:
from .view import StringView
from revolt.state import State
__all__ = (
"Context",
@@ -46,17 +47,17 @@ class Context(revolt.Messageable, Generic[ClientCoT]):
return self.channel.id
def __init__(self, command: Optional[Command[ClientCoT]], invoked_with: str, view: StringView, message: revolt.Message, client: ClientCoT):
self.command = command
self.invoked_with = invoked_with
self.view = view
self.message = message
self.client = client
self.command: Command[ClientCoT] | None = command
self.invoked_with: str = invoked_with
self.view: StringView = view
self.message: revolt.Message = message
self.client: ClientCoT = client
self.args: list[Any] = []
self.kwargs: dict[str, Any] = {}
self.server_id = message.server_id
self.channel = message.channel
self.author = message.author
self.state = message.state
self.server_id: str | None = message.server_id
self.channel: revolt.TextChannel | revolt.GroupDMChannel | revolt.DMChannel = message.channel
self.author: revolt.Member | revolt.User = message.author
self.state: State = message.state
@property
def server(self) -> revolt.Server:
@@ -105,7 +106,7 @@ class Context(revolt.Messageable, Generic[ClientCoT]):
return all([await maybe_coroutine(check, self) for check in (command.checks if command else [])])
async def send_help(self, argument: Command[Any] | Group[Any] | ClientCoT | None = None):
async def send_help(self, argument: Command[Any] | Group[Any] | ClientCoT | None = None) -> None:
argument = argument or self.client
command = self.client.get_command("help")
+4 -4
View File
@@ -13,14 +13,14 @@ from .errors import (BadBoolArgument, CategoryConverterError,
if TYPE_CHECKING:
from .client import CommandsClient
__all__ = ("bool_converter", "category_converter", "channel_converter", "user_converter", "member_converter", "IntConverter", "BoolConverter", "CategoryConverter", "UserConverter", "MemberConverter", "ChannelConverter")
__all__: tuple[str, ...] = ("bool_converter", "category_converter", "channel_converter", "user_converter", "member_converter", "IntConverter", "BoolConverter", "CategoryConverter", "UserConverter", "MemberConverter", "ChannelConverter")
channel_regex = re.compile("<#([A-z0-9]{26})>")
user_regex = re.compile("<@([A-z0-9]{26})>")
channel_regex: re.Pattern[str] = re.compile("<#([A-z0-9]{26})>")
user_regex: re.Pattern[str] = re.compile("<@([A-z0-9]{26})>")
ClientT = TypeVar("ClientT", bound="CommandsClient")
def bool_converter(arg: str, _):
def bool_converter(arg: str, _: Context[ClientT]) -> bool:
lowered = arg.lower()
if lowered in ("yes", "true", "ye", "y", "1", "on", "enable"):
return True
+1 -1
View File
@@ -32,7 +32,7 @@ class CommandNotFound(CommandError):
__slots__ = ("command_name",)
def __init__(self, command_name: str):
self.command_name = command_name
self.command_name: str = command_name
class NoClosingQuote(CommandError):
"""Raised when there is no closing quote for a command argument"""
+4 -4
View File
@@ -26,13 +26,13 @@ class Group(Command[ClientCoT]):
The group's subcommands.
"""
__slots__ = ("subcommands",)
__slots__: tuple[str, ...] = ("subcommands",)
def __init__(self, callback: Callable[..., Coroutine[Any, Any, Any]], name: str, aliases: list[str]):
self.subcommands: dict[str, Command[ClientCoT]] = {}
super().__init__(callback, name, aliases)
def command(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientCoT]] = Command[ClientCoT]):
def command(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientCoT]] = Command[ClientCoT]) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Command[ClientCoT]]:
"""A decorator that turns a function into a :class:`Command` and registers the command as a subcommand.
Parameters
@@ -57,7 +57,7 @@ class Group(Command[ClientCoT]):
return inner
def group(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: Optional[type[Group[ClientCoT]]] = None):
def group(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: Optional[type[Group[ClientCoT]]] = None) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Group[ClientCoT]]:
"""A decorator that turns a function into a :class:`Group` and registers the command as a subcommand
Parameters
@@ -91,7 +91,7 @@ class Group(Command[ClientCoT]):
def commands(self) -> list[Command[ClientCoT]]:
return list(self.subcommands.values())
def group(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Group[ClientT]] = Group):
def group(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Group[ClientT]] = Group) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Group[ClientT]]:
"""A decorator that turns a function into a :class:`Group`
Parameters
+8 -7
View File
@@ -66,7 +66,7 @@ class HelpCommand(ABC, Generic[ClientCoT]):
return cogs
async def handle_message(self, context: Context[ClientCoT], message: Message):
async def handle_message(self, context: Context[ClientCoT], message: Message) -> None:
pass
async def get_channel(self, context: Context) -> Messageable:
@@ -145,11 +145,11 @@ class DefaultHelpCommand(HelpCommand[ClientCoT]):
lines.append("```")
return "\n".join(lines)
async def handle_no_command_found(self, context: Context[ClientCoT], name: str):
async def handle_no_command_found(self, context: Context[ClientCoT], name: str) -> None:
channel = await self.get_channel(context)
await channel.send(f"Command `{name}` not found.")
async def handle_no_cog_found(self, context: Context[ClientCoT], name: str):
async def handle_no_cog_found(self, context: Context[ClientCoT], name: str) -> None:
channel = await self.get_channel(context)
await channel.send(f"Cog `{name}` not found.")
@@ -158,14 +158,14 @@ class HelpCommandImpl(Command[ClientCoT]):
def __init__(self, client: ClientCoT):
self.client = client
async def callback(_: Union[ClientCoT, Cog[ClientCoT]], context: Context[ClientCoT], *args: str):
async def callback(_: Union[ClientCoT, Cog[ClientCoT]], context: Context[ClientCoT], *args: str) -> None:
await help_command_impl(context.client, context, *args)
super().__init__(callback=callback, name="help", aliases=[])
self.description = "Shows help for a command, cog or the entire bot"
self.description: str | None = "Shows help for a command, cog or the entire bot"
async def help_command_impl(self: ClientT, context: Context[ClientT], *arguments: str):
async def help_command_impl(self: ClientT, context: Context[ClientT], *arguments: str) -> None:
help_command = self.help_command
if not help_command:
@@ -202,4 +202,5 @@ async def help_command_impl(self: ClientT, context: Context[ClientT], *arguments
else:
msg_payload = payload
await help_command.send_help_command(context, msg_payload)
message = await help_command.send_help_command(context, msg_payload)
await help_command.handle_message(context, message)
+5 -4
View File
@@ -1,13 +1,14 @@
from typing import Iterator
from .errors import NoClosingQuote
class StringView:
def __init__(self, string: str):
self.value = iter(string)
self.temp = ""
self.should_undo = False
self.value: Iterator[str] = iter(string)
self.temp: str = ""
self.should_undo: bool = False
def undo(self):
def undo(self) -> None:
self.should_undo = True
def next_char(self) -> str:
+8 -4
View File
@@ -1,5 +1,7 @@
from __future__ import annotations
import io
from typing import Optional, Union
from typing import Optional, Union, cast
__all__ = ("File",)
@@ -18,17 +20,19 @@ class File:
__slots__ = ("f", "spoiler", "filename")
def __init__(self, file: Union[str, bytes], *, filename: Optional[str] = None, spoiler: bool = False):
self.f: io.BufferedIOBase
if isinstance(file, str):
self.f = open(file, "rb")
else:
self.f = io.BytesIO(file)
if filename is None and isinstance(file, str):
filename = self.f.name
filename = cast(Optional[str], self.f.name)
self.spoiler = spoiler or (filename and filename.startswith("SPOILER_"))
self.spoiler: bool = spoiler or (bool(filename) and filename.startswith("SPOILER_"))
if self.spoiler and (filename and not filename.startswith("SPOILER_")):
filename = f"SPOILER_{filename}"
self.filename = filename
self.filename: str | None = filename
+5 -5
View File
@@ -11,8 +11,8 @@ class Flag:
__slots__ = ("flag", "__doc__")
def __init__(self, func: Callable[[], int]):
self.flag = func()
self.__doc__ = func.__doc__
self.flag: int = func()
self.__doc__: str | None = func.__doc__
@overload
def __get__(self: Self, instance: None, owner: type[Flags]) -> Self:
@@ -28,7 +28,7 @@ class Flag:
return instance._check_flag(self.flag)
def __set__(self, instance: Flags, value: bool):
def __set__(self, instance: Flags, value: bool) -> None:
instance._set_flag(self.flag, value)
class Flags:
@@ -52,7 +52,7 @@ class Flags:
def _check_flag(self, flag: int) -> bool:
return (self.value & flag) == flag
def _set_flag(self, flag: int, value: bool):
def _set_flag(self, flag: int, value: bool) -> None:
if value:
self.value |= flag
else:
@@ -85,7 +85,7 @@ class Flags:
def __gt__(self, other: Self) -> bool:
return self.value > other.value
def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__} value={self.value}>"
def __iter__(self) -> Iterator[tuple[str, bool]]:
+16 -16
View File
@@ -46,11 +46,11 @@ class HttpClient:
__slots__ = ("session", "token", "api_url", "api_info", "auth_header")
def __init__(self, session: aiohttp.ClientSession, token: str, api_url: str, api_info: ApiInfo, bot: bool = True):
self.session = session
self.token = token
self.api_url = api_url
self.api_info = api_info
self.auth_header = "x-bot-token" if bot else "x-session-token"
self.session: aiohttp.ClientSession = session
self.token: str = token
self.api_url: str = api_url
self.api_info: ApiInfo = api_info
self.auth_header: str = "x-bot-token" if bot else "x-session-token"
async def request(self, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"], route: str, *, json: Optional[dict[str, Any]] = None, nonce: bool = True, params: Optional[dict[str, Any]] = None) -> Any:
url = f"{self.api_url}{route}"
@@ -359,19 +359,19 @@ class HttpClient:
def delete_invite(self, code: str) -> Request[None]:
return self.request("DELETE", f"/invites/{code}")
def edit_channel(self, channel_id: str, remove: list[str] | None, values: dict[str, Any]):
def edit_channel(self, channel_id: str, remove: list[str] | None, values: dict[str, Any]) -> Request[None]:
if remove:
values["remove"] = remove
return self.request("PATCH", f"/channels/{channel_id}", json=values)
def edit_role(self, server_id: str, role_id: str, remove: list[str] | None, values: dict[str, Any]):
def edit_role(self, server_id: str, role_id: str, remove: list[str] | None, values: dict[str, Any]) -> Request[None]:
if remove:
values["remove"] = remove
return self.request("PATCH", f"/servers/{server_id}/roles/{role_id}", json=values)
async def edit_self(self, remove: list[str] | None, values: dict[str, Any]):
async def edit_self(self, remove: list[str] | None, values: dict[str, Any]) -> Request[None]:
if remove:
values["remove"] = remove
@@ -392,25 +392,25 @@ class HttpClient:
def set_guild_channel_role_permissions(self, channel_id: str, role_id: str, allow: int, deny: int) -> Request[None]:
return self.request("PUT", f"/channels/{channel_id}/permissions/{role_id}", json={"permissions": {"allow": allow, "deny": deny}})
def set_group_channel_default_permissions(self, channel_id: str, value: int):
def set_group_channel_default_permissions(self, channel_id: str, value: int) -> Request[None]:
return self.request("PUT", f"/channels/{channel_id}/permissions/default", json={"permissions": value})
def set_server_role_permissions(self, server_id: str, role_id: str, allow: int, deny: int):
def set_server_role_permissions(self, server_id: str, role_id: str, allow: int, deny: int) -> Request[None]:
return self.request("PUT", f"/servers/{server_id}/permissions/{role_id}", json={"permissions": {"allow": allow, "deny": deny}})
def set_server_default_permissions(self, server_id: str, value: int):
def set_server_default_permissions(self, server_id: str, value: int) -> Request[None]:
return self.request("PUT", f"/servers/{server_id}/permissions/default", json={"permissions": value})
def add_reaction(self, channel_id: str, message_id: str, emoji: str):
def add_reaction(self, channel_id: str, message_id: str, emoji: str) -> Request[None]:
return self.request("PUT", f"/channels/{channel_id}/messages/{message_id}/reactions/{emoji}")
def remove_reaction(self, channel_id: str, message_id: str, emoji: str, user_id: Optional[str], remove_all: bool):
def remove_reaction(self, channel_id: str, message_id: str, emoji: str, user_id: Optional[str], remove_all: bool) -> Request[None]:
return self.request("PUT", f"/channels/{channel_id}/messages/{message_id}/reactions/{emoji}")
def remove_all_reactions(self, channel_id: str, message_id: str):
def remove_all_reactions(self, channel_id: str, message_id: str) -> Request[None]:
return self.request("DELETE", f"/channels/{channel_id}/messages/{message_id}/reactions")
def delete_emoji(self, emoji_id: str):
def delete_emoji(self, emoji_id: str) -> Request[None]:
return self.request("DELETE", f"/custom/emoji/{emoji_id}")
def fetch_emoji(self, emoji_id: str) -> Request[EmojiPayload]:
@@ -424,5 +424,5 @@ class HttpClient:
def edit_member(self, server_id: str, member_id: str, remove: list[str] | None, values: dict[str, Any]) -> Request[MemberPayload]:
return self.request("PATCH", f"/servers/{server_id}/members/{member_id}", json={"remove": remove, **values})
def delete_messages(self, channel_id: str, messages: list[str]):
def delete_messages(self, channel_id: str, messages: list[str]) -> Request[None]:
return self.request("DELETE", f"/channels/{channel_id}/messages/bulk", json={"ids": messages})
+16 -9
View File
@@ -2,12 +2,17 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from .asset import Asset
from .utils import Ulid
if TYPE_CHECKING:
from .state import State
from .channel import Channel
from .server import Server
from .types import Invite as InvitePayload
from .user import User
__all__ = ("Invite",)
@@ -37,22 +42,24 @@ class Invite(Ulid):
__slots__ = ("state", "code", "id", "server", "channel", "user_name", "user_avatar", "user", "member_count")
def __init__(self, data: InvitePayload, code: str, state: State):
self.state = state
self.state: State = state
self.code = code
self.id = code
self.server = state.get_server(data["server_id"])
self.channel = self.server.get_channel(data["channel_id"])
self.code: str = code
self.id: str = code
self.server: Server = state.get_server(data["server_id"])
self.channel: Channel = self.server.get_channel(data["channel_id"])
self.user_name = data["user_name"]
self.user = None
self.user_name: str = data["user_name"]
self.user: User | None = None
self.user_avatar: Asset | None
if avatar := data.get("user_avatar"):
self.user_avatar = Asset(avatar, state)
else:
self.user_avatar = None
self.member_count = data["member_count"]
self.member_count: int = data["member_count"]
@staticmethod
def _from_partial(code: str, server: str, creator: str, channel: str, state: State) -> Invite:
@@ -69,6 +76,6 @@ class Invite(Ulid):
return invite
async def delete(self):
async def delete(self) -> None:
"""Deletes the invite"""
await self.state.http.delete_invite(self.code)
+19 -14
View File
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
__all__ = ("Member",)
def flattern_user(member: Member, user: User):
def flattern_user(member: Member, user: User) -> None:
for attr in user.__flattern_attributes__:
setattr(member, attr, getattr(user, attr))
@@ -49,7 +49,9 @@ class Member(User):
flattern_user(self, user)
user._members[server.id] = self
self.state = state
self.state: State = state
self.guild_avatar: Asset | None
if avatar := data.get("avatar"):
self.guild_avatar = Asset(avatar, state)
@@ -57,20 +59,23 @@ class Member(User):
self.guild_avatar = None
roles = [server.get_role(role_id) for role_id in data.get("roles", [])]
self.roles = sorted(roles, key=lambda role: role.rank, reverse=True)
self.roles: list[Role] = sorted(roles, key=lambda role: role.rank, reverse=True)
self.server = server
self.nickname = data.get("nickname")
self.server: Server = server
self.nickname: str | None = data.get("nickname")
joined_at = data["joined_at"]
if isinstance(joined_at, int):
self.joined_at = datetime.datetime.fromtimestamp(joined_at / 1000)
self.joined_at: datetime.datetime = datetime.datetime.fromtimestamp(joined_at / 1000)
else:
self.joined_at = datetime.datetime.strptime(joined_at, "%Y-%m-%dT%H:%M:%S.%f%z")
self.current_timeout = None
self.joined_at: datetime.datetime = datetime.datetime.strptime(joined_at, "%Y-%m-%dT%H:%M:%S.%f%z")
self.current_timeout: datetime.datetime | None
if current_timeout := data.get("timeout"):
self.current_timeout = datetime.datetime.strptime(current_timeout, "%Y-%m-%dT%H:%M:%S.%f%z")
else:
self.current_timeout = None
@property
def avatar(self) -> Optional[Asset]:
@@ -93,11 +98,11 @@ class Member(User):
member_roles = [self.server.get_role(role_id) for role_id in roles]
self.roles = sorted(member_roles, key=lambda role: role.rank, reverse=True)
async def kick(self):
async def kick(self) -> None:
"""Kicks the member from the server"""
await self.state.http.kick_member(self.server.id, self.id)
async def ban(self, *, reason: Optional[str] = None):
async def ban(self, *, reason: Optional[str] = None) -> None:
"""Bans the member from the server
Parameters
@@ -107,7 +112,7 @@ class Member(User):
"""
await self.state.http.ban_member(self.server.id, self.id, reason)
async def unban(self):
async def unban(self) -> None:
"""Unbans the member from the server"""
await self.state.http.unban_member(self.server.id, self.id)
@@ -118,7 +123,7 @@ class Member(User):
roles: list[Role] | None | _Missing = Missing,
avatar: File | None | _Missing = Missing,
timeout: datetime.timedelta | None | _Missing = Missing
):
) -> None:
remove: list[str] = []
data: dict[str, Any] = {}
@@ -148,7 +153,7 @@ class Member(User):
await self.state.http.edit_member(self.server.id, self.id, remove, data)
async def timeout(self, length: datetime.timedelta):
async def timeout(self, length: datetime.timedelta) -> None:
"""Timeouts the member
Parameters
@@ -170,7 +175,7 @@ class Member(User):
"""
return calculate_permissions(self, self.server)
def get_channel_permissions(self, channel: Channel):
def get_channel_permissions(self, channel: Channel) -> Permissions:
"""Gets the permissions for the member in the server taking into account the channel as well
Parameters
+39 -24
View File
@@ -1,11 +1,13 @@
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Union
from revolt.types.message import SystemMessageContent
from .asset import Asset, PartialAsset
from .channel import Messageable
from .embed import SendableEmbed, to_embed
from .channel import DMChannel, GroupDMChannel, TextChannel
from .embed import Embed, SendableEmbed, to_embed
from .utils import Ulid
if TYPE_CHECKING:
@@ -17,6 +19,7 @@ if TYPE_CHECKING:
from .types import Message as MessagePayload
from .types import MessageReplyPayload
from .user import User
from .member import Member
__all__ = (
"Message",
@@ -58,18 +61,28 @@ class Message(Ulid):
__slots__ = ("state", "id", "content", "attachments", "embeds", "channel", "author", "edited_at", "mentions", "replies", "reply_ids", "reactions", "interactions")
def __init__(self, data: MessagePayload, state: State):
self.state = state
self.state: State = state
self.id = data["_id"]
self.content = data.get("content", "")
self.attachments = [Asset(attachment, state) for attachment in data.get("attachments", [])]
self.embeds = [to_embed(embed, state) for embed in data.get("embeds", [])]
self.id: str = data["_id"]
content = data.get("content", "")
if not isinstance(content, str):
self.system_content: SystemMessageContent = content
self.content: str = ""
else:
self.content = content
self.attachments: list[Asset] = [Asset(attachment, state) for attachment in data.get("attachments", [])]
self.embeds: list[Embed] = [to_embed(embed, state) for embed in data.get("embeds", [])]
channel = state.get_channel(data["channel"])
assert isinstance(channel, Messageable)
self.channel = channel
assert isinstance(channel, Union[TextChannel, GroupDMChannel, DMChannel])
self.channel: TextChannel | GroupDMChannel | DMChannel = channel
self.server_id = self.channel.server_id
self.server_id: str | None = self.channel.server_id
self.mentions: list[Member | User]
if self.server_id:
author = state.get_member(self.server_id, data["author"])
@@ -78,7 +91,7 @@ class Message(Ulid):
author = state.get_user(data["author"])
self.mentions = [state.get_user(member_id) for member_id in data.get("mentions", [])]
self.author = author
self.author: Member | User = author
if masquerade := data.get("masquerade"):
if name := masquerade.get("name"):
@@ -109,6 +122,8 @@ class Message(Ulid):
for emoji, users in reactions.items():
self.reactions[emoji] = [self.state.get_user(user_id) for user_id in users]
self.interactions: MessageInteractions | None
if interactions := data.get("interactions"):
self.interactions = MessageInteractions(reactions=interactions.get("reactions"), restrict_reactions=interactions.get("restrict_reactions", False))
else:
@@ -143,7 +158,7 @@ class Message(Ulid):
"""Deletes the message. The bot can only delete its own messages and messages it has permission to delete """
await self.state.http.delete_message(self.channel.id, self.id)
def reply(self, *args: Any, mention: bool = False, **kwargs: Any):
def reply(self, *args: Any, mention: bool = False, **kwargs: Any) -> Coroutine[Any, Any, Message]:
"""Replies to this message, equivilant to:
.. code-block:: python
@@ -153,13 +168,13 @@ class Message(Ulid):
"""
return self.channel.send(*args, **kwargs, replies=[MessageReply(self, mention)])
async def add_reaction(self, emoji: str):
async def add_reaction(self, emoji: str) -> None:
await self.state.http.add_reaction(self.channel.id, self.id, emoji)
async def remove_reaction(self, emoji: str, user: Optional[User] = None, remove_all: bool = False):
async def remove_reaction(self, emoji: str, user: Optional[User] = None, remove_all: bool = False) -> None:
await self.state.http.remove_reaction(self.channel.id, self.id, emoji, user.id if user else None, remove_all)
async def remove_all_reactions(self):
async def remove_all_reactions(self) -> None:
await self.state.http.remove_all_reactions(self.channel.id, self.id)
@@ -187,8 +202,8 @@ class MessageReply:
__slots__ = ("message", "mention")
def __init__(self, message: Message, mention: bool = False):
self.message = message
self.mention = mention
self.message: Message = message
self.mention: bool = mention
def to_dict(self) -> MessageReplyPayload:
return { "id": self.message.id, "mention": self.mention }
@@ -208,9 +223,9 @@ class Masquerade:
__slots__ = ("name", "avatar", "colour")
def __init__(self, name: Optional[str] = None, avatar: Optional[str] = None, colour: Optional[str] = None):
self.name = name
self.avatar = avatar
self.colour = colour
self.name: str | None = name
self.avatar: str | None = avatar
self.colour: str | None = colour
def to_dict(self) -> MasqueradePayload:
output: MasqueradePayload = {}
@@ -239,10 +254,10 @@ class MessageInteractions:
__slots__ = ("reactions", "restrict_reactions")
def __init__(self, *, reactions: Optional[list[str]] = None, restrict_reactions: bool = False):
self.reactions = reactions
self.restrict_reactions = restrict_reactions
self.reactions: list[str] | None = reactions
self.restrict_reactions: bool = restrict_reactions
def to_dict(self):
def to_dict(self) -> InteractionsPayload:
output: InteractionsPayload = {}
if reactions := self.reactions:
+1 -1
View File
@@ -138,7 +138,7 @@ class Messageable:
payloads = await self.state.http.search_messages(await self._get_channel_id(), query, sort=sort, limit=limit, before=before, after=after)
return [Message(payload, self.state) for payload in payloads]
async def delete_messages(self, messages: list[Message]):
async def delete_messages(self, messages: list[Message]) -> None:
"""Bulk deletes messages from the channel
.. note:: The messages must have been sent in the last 7 days.
+1 -1
View File
@@ -176,7 +176,7 @@ class PermissionsOverwrite:
super().__setattr__(perm, value)
def __setattr__(self, key: str, value: Any):
def __setattr__(self, key: str, value: Any) -> None:
if key in Permissions.FLAG_NAMES:
if key is True:
setattr(self._allow, key, True)
+13 -13
View File
@@ -35,20 +35,20 @@ class Role(Ulid):
channel_permissions: :class:`ChannelPermissions`
The channel permissions for the role
"""
__slots__ = ("id", "name", "colour", "hoist", "rank", "state", "server", "permissions")
__slots__: tuple[str, ...] = ("id", "name", "colour", "hoist", "rank", "state", "server", "permissions")
def __init__(self, data: RolePayload, role_id: str, server: Server, state: State):
self.state = state
self.id = role_id
self.name = data["name"]
self.colour = data.get("colour", None)
self.hoist = False
self.rank = 0
self.server = server
self.permissions = PermissionsOverwrite._from_overwrite(data.get("permissions", {"a": 0, "d": 0}))
self.state: State = state
self.id: str = role_id
self.name: str = data["name"]
self.colour: str | None = data.get("colour", None)
self.hoist: bool = data.get("hoist", False)
self.rank: int = data["rank"]
self.server: Server = server
self.permissions: PermissionsOverwrite = PermissionsOverwrite._from_overwrite(data.get("permissions", {"a": 0, "d": 0}))
@property
def color(self):
def color(self) -> str | None:
return self.colour
async def set_permissions_overwrite(self, *, permissions: PermissionsOverwrite) -> None:
@@ -63,7 +63,7 @@ class Role(Ulid):
allow, deny = permissions.to_pair()
await self.state.http.set_server_role_permissions(self.server.id, self.id, allow.value, deny.value)
def _update(self, *, name: Optional[str] = None, colour: Optional[str] = None, hoist: Optional[bool] = None, rank: Optional[int] = None, permissions: Optional[Overwrite] = None):
def _update(self, *, name: Optional[str] = None, colour: Optional[str] = None, hoist: Optional[bool] = None, rank: Optional[int] = None, permissions: Optional[Overwrite] = None) -> None:
if name is not None:
self.name = name
@@ -79,11 +79,11 @@ class Role(Ulid):
if permissions is not None:
self.permissions = PermissionsOverwrite._from_overwrite(permissions)
async def delete(self):
async def delete(self) -> None:
"""Deletes the role"""
await self.state.http.delete_role(self.server.id, self.id)
async def edit(self, **kwargs: Any):
async def edit(self, **kwargs: Any) -> None:
"""Edits the role
Parameters
+26 -22
View File
@@ -25,11 +25,11 @@ __all__ = ("Server", "SystemMessages", "ServerBan")
class SystemMessages:
def __init__(self, data: SystemMessagesConfig, state: State):
self.state = state
self.user_joined_id = data.get("user_joined")
self.user_left_id = data.get("user_left")
self.user_kicked_id = data.get("user_kicked")
self.user_banned_id = data.get("user_banned")
self.state: State = state
self.user_joined_id: str | None = data.get("user_joined")
self.user_left_id: str | None = data.get("user_left")
self.user_kicked_id: str | None = data.get("user_kicked")
self.user_banned_id: str | None = data.get("user_banned")
@property
def user_joined(self) -> Optional[TextChannel]:
@@ -94,21 +94,25 @@ class Server(Ulid):
__slots__ = ("state", "id", "name", "owner_id", "default_permissions", "_members", "_roles", "_channels", "description", "icon", "banner", "nsfw", "system_messages", "_categories", "_emojis")
def __init__(self, data: ServerPayload, state: State):
self.state = state
self.id = data["_id"]
self.name = data["name"]
self.owner_id = data["owner"]
self.description = data.get("description") or None
self.nsfw = data.get("nsfw", False)
self.system_messages = SystemMessages(data.get("system_messages", cast("SystemMessagesConfig", {})), state)
self._categories = {data["id"]: Category(data, state) for data in data.get("categories", [])}
self.default_permissions = Permissions(data["default_permissions"])
self.state: State = state
self.id: str = data["_id"]
self.name: str = data["name"]
self.owner_id: str = data["owner"]
self.description: str | None = data.get("description") or None
self.nsfw: bool = data.get("nsfw", False)
self.system_messages: SystemMessages = SystemMessages(data.get("system_messages", cast("SystemMessagesConfig", {})), state)
self._categories: dict[str, Category] = {data["id"]: Category(data, state) for data in data.get("categories", [])}
self.default_permissions: Permissions = Permissions(data["default_permissions"])
self.icon: Asset | None
if icon := data.get("icon"):
self.icon = Asset(icon, state)
else:
self.icon = None
self.banner: Asset | None
if banner := data.get("banner"):
self.banner = Asset(banner, state)
else:
@@ -279,11 +283,11 @@ class Server(Ulid):
await self.state.http.set_server_default_permissions(self.id, permissions.value)
async def leave_server(self):
async def leave_server(self) -> None:
"""Leaves or deletes the server"""
await self.state.http.delete_leave_server(self.id)
async def delete_server(self):
async def delete_server(self) -> None:
"""Leaves or deletes a server, alias to :meth`Server.leave_server`"""
await self.leave_server()
@@ -388,7 +392,7 @@ class Server(Ulid):
return Role(payload, name, self, self.state)
async def create_emoji(self, name: str, file: File, *, nsfw: bool = False):
async def create_emoji(self, name: str, file: File, *, nsfw: bool = False) -> Emoji:
"""Creates an emoji
Parameters
@@ -421,11 +425,11 @@ class ServerBan:
__slots__ = ("reason", "server", "user_id", "state")
def __init__(self, ban: Ban, state: State):
self.reason = ban.get("reason")
self.server = state.get_server(ban["_id"]["server"])
self.user_id = ban["_id"]["user"]
self.state = state
self.reason: str | None = ban.get("reason")
self.server: Server = state.get_server(ban["_id"]["server"])
self.user_id: str = ban["_id"]["user"]
self.state: State = state
async def unban(self):
async def unban(self) -> None:
"""Unbans the user"""
await self.state.http.unban_member(self.server.id, self.user_id)
+5 -5
View File
@@ -26,9 +26,9 @@ class State:
__slots__ = ("http", "api_info", "max_messages", "users", "channels", "servers", "messages", "global_emojis", "user_id", "me")
def __init__(self, http: HttpClient, api_info: ApiInfo, max_messages: int):
self.http = http
self.api_info = api_info
self.max_messages = max_messages
self.http: HttpClient = http
self.api_info: ApiInfo = api_info
self.max_messages: int = max_messages
self.me: User
@@ -114,7 +114,7 @@ class State:
raise LookupError
async def fetch_server_members(self, server_id: str):
async def fetch_server_members(self, server_id: str) -> None:
data = await self.http.fetch_members(server_id)
for user in data["users"]:
@@ -123,6 +123,6 @@ class State:
for member in data["members"]:
self.add_member(server_id, member)
async def fetch_all_server_members(self):
async def fetch_all_server_members(self) -> None:
for server_id in self.servers:
await self.fetch_server_members(server_id)
+3 -1
View File
@@ -65,11 +65,13 @@ class Interactions(TypedDict):
reactions: NotRequired[list[str]]
restrict_reactions: NotRequired[bool]
SystemMessageContent = Union[UserAddContent, UserRemoveContent, UserJoinedContent, UserLeftContent, UserKickedContent, UserBannedContent, ChannelRenameContent, ChannelDescriptionChangeContent, ChannelIconChangeContent]
class Message(TypedDict):
_id: str
channel: str
author: str
content: Union[str, UserAddContent, UserRemoveContent, UserJoinedContent, UserLeftContent, UserKickedContent, UserBannedContent, ChannelRenameContent, ChannelDescriptionChangeContent, ChannelIconChangeContent]
content: Union[str, SystemMessageContent]
attachments: NotRequired[list[File]]
embeds: NotRequired[list[Embed]]
mentions: NotRequired[list[str]]
+18 -12
View File
@@ -65,17 +65,21 @@ class User(Messageable, Ulid):
privileged: :class:`bool`
Whether the user is privileged
"""
__flattern_attributes__ = ("id", "bot", "owner_id", "badges", "online", "flags", "relations", "relationship", "status", "masquerade_avatar", "masquerade_name", "original_name", "original_avatar", "profile", "dm_channel", "privileged")
__slots__ = (*__flattern_attributes__, "state", "_members")
__flattern_attributes__: tuple[str, ...] = ("id", "bot", "owner_id", "badges", "online", "flags", "relations", "relationship", "status", "masquerade_avatar", "masquerade_name", "original_name", "original_avatar", "profile", "dm_channel", "privileged")
__slots__: tuple[str, ...] = (*__flattern_attributes__, "state", "_members")
def __init__(self, data: UserPayload, state: State):
self.state = state
self._members: WeakValueDictionary[str, Member] = WeakValueDictionary() # we store all member versions of this user to avoid having to check every guild when needing to update.
self.id = data["_id"]
self.original_name = data["username"]
self.dm_channel = None
self.id: str = data["_id"]
self.original_name: str = data["username"]
self.dm_channel: DMChannel | None = None
bot = data.get("bot")
self.bot: bool
self.owner_id: str | None
if bot:
self.bot = True
self.owner_id = bot["owner"]
@@ -83,13 +87,13 @@ class User(Messageable, Ulid):
self.bot = False
self.owner_id = None
self.badges = UserBadges._from_value(data.get("badges", 0))
self.online = data.get("online", False)
self.flags = data.get("flags", 0)
self.privileged = data.get("privileged", False)
self.badges: UserBadges = UserBadges._from_value(data.get("badges", 0))
self.online: bool = data.get("online", False)
self.flags: int = data.get("flags", 0)
self.privileged: bool = data.get("privileged", False)
avatar = data.get("avatar")
self.original_avatar = Asset(avatar, state) if avatar else None
self.original_avatar: Asset | None = Asset(avatar, state) if avatar else None
relations: list[Relation] = []
@@ -97,12 +101,14 @@ class User(Messageable, Ulid):
user = state.get_user(relation["_id"])
if user:
relations.append(Relation(RelationshipType(relation["status"]), user))
self.relations = relations
self.relations: list[Relation] = relations
relationship = data.get("relationship")
self.relationship = RelationshipType(relationship) if relationship else None
self.relationship: RelationshipType | None = RelationshipType(relationship) if relationship else None
status = data.get("status")
self.status: Status | None
if status:
presence = status.get("presence")
self.status = Status(status.get("text"), PresenceType(presence) if presence else None) if status else None
+2 -2
View File
@@ -11,13 +11,13 @@ from typing_extensions import ParamSpec
__all__ = ("Missing", "copy_doc", "maybe_coroutine", "get", "client_session")
class _Missing:
def __repr__(self):
def __repr__(self) -> str:
return "<Missing>"
def __bool__(self) -> Literal[False]:
return False
Missing = _Missing()
Missing: _Missing = _Missing()
T = TypeVar("T")
+43 -41
View File
@@ -27,7 +27,7 @@ from .types import (ServerCreateEventPayload, ServerDeleteEventPayload,
ServerRoleDeleteEventPayload, ServerRoleUpdateEventPayload,
ServerUpdateEventPayload, UserRelationshipEventPayload,
UserUpdateEventPayload)
from .user import Status, UserProfile
from .user import Status, User, UserProfile
import aiohttp
@@ -36,6 +36,8 @@ try:
except ImportError:
import json
use_msgpack: bool
try:
import msgpack
use_msgpack = True
@@ -54,42 +56,42 @@ class WSMessage(NamedTuple):
type: aiohttp.WSMsgType
data: str | bytes | aiohttp.WSCloseCode
__all__ = ("WebsocketHandler",)
__all__: tuple[str, ...] = ("WebsocketHandler",)
logger = logging.getLogger("revolt")
logger: logging.Logger = logging.getLogger("revolt")
class WebsocketHandler:
__slots__ = ("session", "token", "ws_url", "dispatch", "state", "websocket", "loop", "user", "ready", "server_events")
def __init__(self, session: aiohttp.ClientSession, token: str, ws_url: str, dispatch: Callable[..., None], state: State):
self.session = session
self.token = token
self.ws_url = ws_url
self.dispatch = dispatch
self.state = state
self.session: aiohttp.ClientSession = session
self.token: str = token
self.ws_url: str = ws_url
self.dispatch: Callable[..., None] = dispatch
self.state: State = state
self.websocket: aiohttp.ClientWebSocketResponse
self.loop = asyncio.get_running_loop()
self.user = None
self.ready = asyncio.Event()
self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self.user: User | None = None
self.ready: asyncio.Event = asyncio.Event()
self.server_events: dict[str, asyncio.Event] = {}
async def _wait_for_server_ready(self, server_id: str):
async def _wait_for_server_ready(self, server_id: str) -> None:
if event := self.server_events.get(server_id):
await event.wait()
async def send_payload(self, payload: BasePayload):
async def send_payload(self, payload: BasePayload) -> None:
if use_msgpack:
await self.websocket.send_bytes(msgpack.packb(payload))
else:
await self.websocket.send_str(json.dumps(payload))
async def heartbeat(self):
async def heartbeat(self) -> None:
while not self.websocket.closed:
logger.info("Sending hearbeat")
await self.websocket.ping()
await asyncio.sleep(15)
async def send_authenticate(self):
async def send_authenticate(self) -> None:
payload: AuthenticatePayload = {
"type": "Authenticate",
"token": self.token
@@ -97,7 +99,7 @@ class WebsocketHandler:
await self.send_payload(payload)
async def handle_event(self, payload: BasePayload):
async def handle_event(self, payload: BasePayload) -> None:
event_type = payload["type"].lower()
logger.debug("Recieved event %s %s", event_type, payload)
try:
@@ -110,10 +112,10 @@ class WebsocketHandler:
await func(payload)
async def handle_authenticated(self, _):
async def handle_authenticated(self, _: BasePayload) -> None:
logger.info("Successfully authenticated")
async def handle_ready(self, payload: ReadyEventPayload):
async def handle_ready(self, payload: ReadyEventPayload) -> None:
for user_payload in payload["users"]:
user = self.state.add_user(user_payload)
@@ -138,7 +140,7 @@ class WebsocketHandler:
self.ready.set()
self.dispatch("ready")
async def handle_message(self, payload: MessageEventPayload):
async def handle_message(self, payload: MessageEventPayload) -> None:
if server := self.state.get_channel(payload["channel"]).server_id:
await self._wait_for_server_ready(server)
@@ -147,7 +149,7 @@ class WebsocketHandler:
self.dispatch("message", message)
async def handle_messageupdate(self, payload: MessageUpdateEventPayload):
async def handle_messageupdate(self, payload: MessageUpdateEventPayload) -> None:
self.dispatch("raw_message_update", payload)
try:
@@ -162,7 +164,7 @@ class WebsocketHandler:
self.dispatch("message_update", message)
async def handle_messagedelete(self, payload: MessageDeleteEventPayload):
async def handle_messagedelete(self, payload: MessageDeleteEventPayload) -> None:
self.dispatch("raw_message_delete", payload)
try:
@@ -178,7 +180,7 @@ class WebsocketHandler:
self.dispatch("message_delete", message)
async def handle_channelcreate(self, payload: ChannelCreateEventPayload):
async def handle_channelcreate(self, payload: ChannelCreateEventPayload) -> None:
channel = self.state.add_channel(payload)
if server_id := channel.server_id:
@@ -186,7 +188,7 @@ class WebsocketHandler:
self.dispatch("channel_create", channel)
async def handle_channelupdate(self, payload: ChannelUpdateEventPayload):
async def handle_channelupdate(self, payload: ChannelUpdateEventPayload) -> None:
# Revolt sends channel updates for channels we dont have permissions to see, a bug, but still can cause issues as its not in the cache
if not (channel := self.state.channels.get(payload["id"], None)):
@@ -211,7 +213,7 @@ class WebsocketHandler:
self.dispatch("channel_update", old_channel, channel)
async def handle_channeldelete(self, payload: ChannelDeleteEventPayload):
async def handle_channeldelete(self, payload: ChannelDeleteEventPayload) -> None:
channel = self.state.channels.pop(payload["id"])
if server_id := channel.server_id:
@@ -219,7 +221,7 @@ class WebsocketHandler:
self.dispatch("channel_delete", channel)
async def handle_channelstarttyping(self, payload: ChannelStartTypingEventPayload):
async def handle_channelstarttyping(self, payload: ChannelStartTypingEventPayload) -> None:
channel = self.state.get_channel(payload["id"])
if server_id := channel.server_id:
@@ -229,7 +231,7 @@ class WebsocketHandler:
self.dispatch("typing_start", channel, user)
async def handle_channelstoptyping(self, payload: ChannelDeleteTypingEventPayload):
async def handle_channelstoptyping(self, payload: ChannelDeleteTypingEventPayload) -> None:
channel = self.state.get_channel(payload["id"])
if server_id := channel.server_id:
@@ -239,7 +241,7 @@ class WebsocketHandler:
self.dispatch("typing_stop", channel, user)
async def handle_serverupdate(self, payload: ServerUpdateEventPayload):
async def handle_serverupdate(self, payload: ServerUpdateEventPayload) -> None:
await self._wait_for_server_ready(payload["id"])
server = self.state.get_server(payload["id"])
@@ -261,7 +263,7 @@ class WebsocketHandler:
self.dispatch("server_update", old_server, server)
async def handle_serverdelete(self, payload: ServerDeleteEventPayload):
async def handle_serverdelete(self, payload: ServerDeleteEventPayload) -> None:
server = self.state.servers.pop(payload["id"])
for channel in server.channels:
@@ -271,7 +273,7 @@ class WebsocketHandler:
self.dispatch("server_delete", server)
async def handle_servercreate(self, payload: ServerCreateEventPayload):
async def handle_servercreate(self, payload: ServerCreateEventPayload) -> None:
for channel in payload["channels"]:
self.state.add_channel(channel)
@@ -284,7 +286,7 @@ class WebsocketHandler:
self.dispatch("server_join", server)
async def handle_servermemberupdate(self, payload: ServerMemberUpdateEventPayload):
async def handle_servermemberupdate(self, payload: ServerMemberUpdateEventPayload) -> None:
await self._wait_for_server_ready(payload["id"]["server"])
member = self.state.get_member(payload["id"]["server"], payload["id"]["user"])
@@ -300,7 +302,7 @@ class WebsocketHandler:
self.dispatch("member_update", old_member, member)
async def handle_servermemberjoin(self, payload: ServerMemberJoinEventPayload):
async def handle_servermemberjoin(self, payload: ServerMemberJoinEventPayload) -> None:
# avoid an api request if possible
if payload["user"] not in self.state.users:
user = await self.state.http.fetch_user(payload["user"])
@@ -310,7 +312,7 @@ class WebsocketHandler:
self.dispatch("member_join", member)
async def handle_memberleave(self, payload: ServerMemberLeaveEventPayload):
async def handle_memberleave(self, payload: ServerMemberLeaveEventPayload) -> None:
await self._wait_for_server_ready(payload["id"])
server = self.state.get_server(payload["id"])
@@ -323,7 +325,7 @@ class WebsocketHandler:
self.dispatch("member_leave", member)
async def handle_serverroleupdate(self, payload: ServerRoleUpdateEventPayload):
async def handle_serverroleupdate(self, payload: ServerRoleUpdateEventPayload) -> None:
server = self.state.get_server(payload["id"])
await self._wait_for_server_ready(server.id)
@@ -346,7 +348,7 @@ class WebsocketHandler:
self.dispatch("role_update", old_role, role)
async def handle_serverroledelete(self, payload: ServerRoleDeleteEventPayload):
async def handle_serverroledelete(self, payload: ServerRoleDeleteEventPayload) -> None:
server = self.state.get_server(payload["id"])
role = server._roles.pop(payload["role_id"])
@@ -354,7 +356,7 @@ class WebsocketHandler:
self.dispatch("role_delete", role)
async def handle_userupdate(self, payload: UserUpdateEventPayload):
async def handle_userupdate(self, payload: UserUpdateEventPayload) -> None:
user = self.state.get_user(payload["id"])
old_user = copy(user)
@@ -377,14 +379,14 @@ class WebsocketHandler:
self.dispatch("user_update", old_user, user)
async def handle_userrelationship(self, payload: UserRelationshipEventPayload):
async def handle_userrelationship(self, payload: UserRelationshipEventPayload) -> None:
user = self.state.get_user(payload["user"])
old_relationship = user.relationship
user.relationship = RelationshipType(payload["status"])
self.dispatch("user_relationship_update", user, old_relationship, user.relationship)
async def handle_messagereact(self, payload: MessageReactEventPayload):
async def handle_messagereact(self, payload: MessageReactEventPayload) -> None:
if server := self.state.get_channel(payload["channel_id"]).server_id:
await self._wait_for_server_ready(server)
@@ -401,7 +403,7 @@ class WebsocketHandler:
self.dispatch("reaction_add", message, user, emoji_id)
async def handle_messageunreact(self, payload: MessageUnreactEventPayload):
async def handle_messageunreact(self, payload: MessageUnreactEventPayload) -> None:
if server := self.state.get_channel(payload["channel_id"]).server_id:
await self._wait_for_server_ready(server)
@@ -417,7 +419,7 @@ class WebsocketHandler:
self.dispatch("reaction_remove", message, user, payload["emoji_id"])
async def handle_messageremovereaction(self, payload: MessageRemoveReactionEventPayload):
async def handle_messageremovereaction(self, payload: MessageRemoveReactionEventPayload) -> None:
if server := self.state.get_channel(payload["channel_id"]).server_id:
await self._wait_for_server_ready(server)
@@ -432,7 +434,7 @@ class WebsocketHandler:
self.dispatch("reaction_clear", message, users, payload["emoji_id"])
async def handle_bulkmessagedelete(self, payload: BulkMessageDeleteEventPayload):
async def handle_bulkmessagedelete(self, payload: BulkMessageDeleteEventPayload) -> None:
channel = self.state.get_channel(payload["channel"])
self.dispatch("raw_bulk_message_delete", payload)
@@ -457,7 +459,7 @@ class WebsocketHandler:
self.dispatch("bulk_message_delete", messages)
async def start(self):
async def start(self) -> None:
if use_msgpack:
url = f"{self.ws_url}?format=msgpack"
else: