mirror of
https://github.com/stoatchat/python-client-sdk.git
synced 2026-07-01 20:44:04 -04:00
bring type completeness to 100%
This commit is contained in:
@@ -2,15 +2,20 @@ on: [push, pull_request]
|
||||
name: pyright
|
||||
jobs:
|
||||
pyright:
|
||||
strategy:
|
||||
matrix:
|
||||
version: ["3.9", "3.10", "3.11"]
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.9'
|
||||
python-version: ${{ matrix.version }}
|
||||
cache: pip
|
||||
- run: pip install .[speedups,docs]
|
||||
- uses: jakebailey/pyright-action@v1
|
||||
with:
|
||||
lib: true
|
||||
python-version: 3.9
|
||||
python-version: ${{ matrix.version }}
|
||||
working-directory: revolt
|
||||
verify-types: true
|
||||
|
||||
@@ -11,10 +11,10 @@ upload:
|
||||
python -m twine upload dist/*
|
||||
|
||||
lint:
|
||||
pyright --lib
|
||||
pyright .
|
||||
|
||||
coverage:
|
||||
pyright --lib --ignoreexternal --verifytypes revolt
|
||||
pyright --ignoreexternal --verifytypes revolt
|
||||
|
||||
docs:
|
||||
cd docs && make html
|
||||
|
||||
+20
-21
@@ -42,14 +42,16 @@ class Asset(Ulid):
|
||||
__slots__ = ("state", "id", "tag", "size", "filename", "content_type", "width", "height", "type", "url")
|
||||
|
||||
def __init__(self, data: FilePayload, state: State):
|
||||
self.state = state
|
||||
self.state: State = state
|
||||
|
||||
self.id = data['_id']
|
||||
self.tag = data['tag']
|
||||
self.size = data['size']
|
||||
self.filename = data['filename']
|
||||
self.id: str = data['_id']
|
||||
self.tag: str = data['tag']
|
||||
self.size: int = data['size']
|
||||
self.filename: str = data['filename']
|
||||
|
||||
metadata = data['metadata']
|
||||
self.height: int | None
|
||||
self.width: int | None
|
||||
|
||||
if metadata["type"] == "Image" or metadata["type"] == "Video": # cannot use `in` because type narrowing will not happen
|
||||
self.height = metadata["height"]
|
||||
@@ -58,17 +60,17 @@ class Asset(Ulid):
|
||||
self.height = None
|
||||
self.width = None
|
||||
|
||||
self.content_type = data["content_type"]
|
||||
self.type = AssetType(metadata["type"])
|
||||
self.content_type: str | None = data["content_type"]
|
||||
self.type: AssetType = AssetType(metadata["type"])
|
||||
|
||||
base_url = self.state.api_info["features"]["autumn"]["url"]
|
||||
self.url = f"{base_url}/{self.tag}/{self.id}"
|
||||
self.url: str = f"{base_url}/{self.tag}/{self.id}"
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""Reads the files content into bytes"""
|
||||
return await self.state.http.request_file(self.url)
|
||||
|
||||
async def save(self, fp: IOBase):
|
||||
async def save(self, fp: IOBase) -> None:
|
||||
"""Reads the files content and saves it to a file
|
||||
|
||||
Parameters
|
||||
@@ -85,8 +87,6 @@ class PartialAsset(Asset):
|
||||
-----------
|
||||
id: :class:`str`
|
||||
The id of the asset, this will always be ``"0"``
|
||||
tag: Optional[:class:`str`]
|
||||
The tag of the asset, this corrasponds to where the asset is used, this will always be ``None``
|
||||
size: :class:`int`
|
||||
Amount of bytes in the file, this will always be ``0``
|
||||
filename: :class:`str`
|
||||
@@ -102,13 +102,12 @@ class PartialAsset(Asset):
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, state: State):
|
||||
self.state = state
|
||||
self.id = "0"
|
||||
self.tag = None
|
||||
self.size = 0
|
||||
self.filename = ""
|
||||
self.height = None
|
||||
self.width = None
|
||||
self.content_type = mimetypes.guess_extension(url)
|
||||
self.type = AssetType.file
|
||||
self.url = url
|
||||
self.state: State = state
|
||||
self.id: str = "0"
|
||||
self.size: int = 0
|
||||
self.filename: str = ""
|
||||
self.height: int | None = None
|
||||
self.width: int | None = None
|
||||
self.content_type: str | None = mimetypes.guess_extension(url)
|
||||
self.type: AssetType = AssetType.file
|
||||
self.url: str = url
|
||||
|
||||
+4
-4
@@ -25,10 +25,10 @@ class Category(Ulid):
|
||||
"""
|
||||
|
||||
def __init__(self, data: CategoryPayload, state: State):
|
||||
self.state = state
|
||||
self.name = data["title"]
|
||||
self.id = data["id"]
|
||||
self.channel_ids = data["channels"]
|
||||
self.state: State = state
|
||||
self.name: str = data["title"]
|
||||
self.id: str = data["id"]
|
||||
self.channel_ids: list[str] = data["channels"]
|
||||
|
||||
@property
|
||||
def channels(self) -> list[Channel]:
|
||||
|
||||
+26
-21
@@ -31,7 +31,7 @@ class EditableChannel:
|
||||
state: State
|
||||
id: str
|
||||
|
||||
async def edit(self, **kwargs: Any):
|
||||
async def edit(self, **kwargs: Any) -> None:
|
||||
"""Edits the channel
|
||||
|
||||
Passing ``None`` to the parameters that accept it will remove them.
|
||||
@@ -80,18 +80,18 @@ class Channel(Ulid):
|
||||
__slots__ = ("state", "id", "channel_type", "server_id")
|
||||
|
||||
def __init__(self, data: ChannelPayload, state: State):
|
||||
self.state = state
|
||||
self.id = data["_id"]
|
||||
self.channel_type = ChannelType(data["channel_type"])
|
||||
self.state: State = state
|
||||
self.id: str = data["_id"]
|
||||
self.channel_type: ChannelType = ChannelType(data["channel_type"])
|
||||
self.server_id: Optional[str] = None
|
||||
|
||||
async def _get_channel_id(self) -> str:
|
||||
return self.id
|
||||
|
||||
def _update(self, **_: Any):
|
||||
def _update(self, **_: Any) -> None:
|
||||
pass
|
||||
|
||||
async def delete(self):
|
||||
async def delete(self) -> None:
|
||||
"""Deletes or closes the channel"""
|
||||
await self.state.http.close_channel(self.id)
|
||||
|
||||
@@ -134,7 +134,7 @@ class DMChannel(Channel, Messageable):
|
||||
def __init__(self, data: DMChannelPayload, state: State):
|
||||
super().__init__(data, state)
|
||||
self.recipient_ids: tuple[str, str] = tuple(data["recipients"])
|
||||
self.last_message_id = data.get("last_message_id")
|
||||
self.last_message_id: str | None = data.get("last_message_id")
|
||||
|
||||
@property
|
||||
def recipients(self) -> tuple[User, User]:
|
||||
@@ -190,20 +190,22 @@ class GroupDMChannel(Channel, Messageable, EditableChannel):
|
||||
|
||||
def __init__(self, data: GroupDMChannelPayload, state: State):
|
||||
super().__init__(data, state)
|
||||
self.recipient_ids = data["recipients"]
|
||||
self.name = data["name"]
|
||||
self.owner_id = data["owner"]
|
||||
self.description: Optional[str] = data.get("description")
|
||||
self.last_message_id = data.get("last_message_id")
|
||||
self.recipient_ids: list[str] = data["recipients"]
|
||||
self.name: str = data["name"]
|
||||
self.owner_id: str = data["owner"]
|
||||
self.description: str | None = data.get("description")
|
||||
self.last_message_id: str | None = data.get("last_message_id")
|
||||
|
||||
self.icon: Asset | None
|
||||
|
||||
if icon := data.get("icon"):
|
||||
self.icon = Asset(icon, state)
|
||||
else:
|
||||
self.icon = None
|
||||
|
||||
self.permissions = Permissions(data.get("permissions", 0))
|
||||
self.permissions: Permissions = Permissions(data.get("permissions", 0))
|
||||
|
||||
def _update(self, *, name: Optional[str] = None, recipients: Optional[list[str]] = None, description: Optional[str] = None):
|
||||
def _update(self, *, name: Optional[str] = None, recipients: Optional[list[str]] = None, description: Optional[str] = None) -> None:
|
||||
if name is not None:
|
||||
self.name = name
|
||||
|
||||
@@ -263,12 +265,12 @@ class ServerChannel(Channel):
|
||||
def __init__(self, data: ServerChannelPayload, state: State):
|
||||
super().__init__(data, state)
|
||||
|
||||
self.server_id = data["server"]
|
||||
self.name = data["name"]
|
||||
self.server_id: str = data["server"]
|
||||
self.name: str = data["name"]
|
||||
self.description: Optional[str] = data.get("description")
|
||||
self.nsfw = data.get("nsfw", False)
|
||||
self.active = False
|
||||
self.default_permissions = PermissionsOverwrite._from_overwrite(data.get("default_permissions", {"a": 0, "d": 0}))
|
||||
self.nsfw: bool = data.get("nsfw", False)
|
||||
self.active: bool = False
|
||||
self.default_permissions: PermissionsOverwrite = PermissionsOverwrite._from_overwrite(data.get("default_permissions", {"a": 0, "d": 0}))
|
||||
|
||||
permissions: dict[str, PermissionsOverwrite] = {}
|
||||
|
||||
@@ -276,7 +278,10 @@ class ServerChannel(Channel):
|
||||
overwrite = PermissionsOverwrite._from_overwrite(overwrite_data)
|
||||
permissions[role_name] = overwrite
|
||||
|
||||
self.permissions = permissions
|
||||
self.permissions: dict[str, PermissionsOverwrite] = permissions
|
||||
|
||||
self.icon: Asset | None
|
||||
|
||||
if icon := data.get("icon"):
|
||||
self.icon = Asset(icon, state)
|
||||
else:
|
||||
@@ -359,7 +364,7 @@ class TextChannel(ServerChannel, Messageable, EditableChannel):
|
||||
def __init__(self, data: TextChannelPayload, state: State):
|
||||
super().__init__(data, state)
|
||||
|
||||
self.last_message_id = data.get("last_message_id")
|
||||
self.last_message_id: str | None = data.get("last_message_id")
|
||||
|
||||
async def _get_channel_id(self) -> str:
|
||||
return self.id
|
||||
|
||||
+12
-12
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
|
||||
__all__ = ("Client",)
|
||||
|
||||
logger = logging.getLogger("revolt")
|
||||
logger: logging.Logger = logging.getLogger("revolt")
|
||||
|
||||
class Client:
|
||||
"""The client for interacting with revolt
|
||||
@@ -48,11 +48,11 @@ class Client:
|
||||
"""
|
||||
|
||||
def __init__(self, session: aiohttp.ClientSession, token: str, *, api_url: str = "https://api.revolt.chat", max_messages: int = 5000, bot: bool = True):
|
||||
self.session = session
|
||||
self.token = token
|
||||
self.api_url = api_url
|
||||
self.max_messages = max_messages
|
||||
self.bot = bot
|
||||
self.session: aiohttp.ClientSession = session
|
||||
self.token: str = token
|
||||
self.api_url: str = api_url
|
||||
self.max_messages: int = max_messages
|
||||
self.bot: bool = bot
|
||||
|
||||
self.api_info: ApiInfo
|
||||
self.http: HttpClient
|
||||
@@ -63,7 +63,7 @@ class Client:
|
||||
|
||||
super().__init__()
|
||||
|
||||
def dispatch(self, event: str, *args: Any):
|
||||
def dispatch(self, event: str, *args: Any) -> None:
|
||||
"""Dispatch an event, this is typically used for testing and internals.
|
||||
|
||||
Parameters
|
||||
@@ -88,7 +88,7 @@ class Client:
|
||||
async with self.session.get(self.api_url) as resp:
|
||||
return json.loads(await resp.text())
|
||||
|
||||
async def start(self):
|
||||
async def start(self) -> None:
|
||||
"""Starts the client"""
|
||||
api_info = await self.get_api_info()
|
||||
|
||||
@@ -98,7 +98,7 @@ class Client:
|
||||
self.websocket = WebsocketHandler(self.session, self.token, api_info["ws"], self.dispatch, self.state)
|
||||
await self.websocket.start()
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self) -> None:
|
||||
await self.websocket.websocket.close()
|
||||
|
||||
def get_user(self, id: str) -> User:
|
||||
@@ -300,7 +300,7 @@ class Client:
|
||||
|
||||
raise LookupError
|
||||
|
||||
async def edit_self(self, **kwargs: Any):
|
||||
async def edit_self(self, **kwargs: Any) -> None:
|
||||
"""Edits the client's own user
|
||||
|
||||
Parameters
|
||||
@@ -316,7 +316,7 @@ class Client:
|
||||
|
||||
await self.state.http.edit_self(remove, kwargs)
|
||||
|
||||
async def edit_status(self, **kwargs: Any):
|
||||
async def edit_status(self, **kwargs: Any) -> None:
|
||||
"""Edits the client's own status
|
||||
|
||||
Parameters
|
||||
@@ -337,7 +337,7 @@ class Client:
|
||||
|
||||
await self.state.http.edit_self(remove, {"status": kwargs})
|
||||
|
||||
async def edit_profile(self, **kwargs: Any):
|
||||
async def edit_profile(self, **kwargs: Any) -> None:
|
||||
"""Edits the client's own profile
|
||||
|
||||
Parameters
|
||||
|
||||
+26
-21
@@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Optional, TypedDict, Union
|
||||
|
||||
from typing_extensions import NotRequired, Unpack
|
||||
|
||||
from revolt.types.embed import WebsiteSpecial
|
||||
|
||||
from .asset import Asset
|
||||
from .enums import EmbedType
|
||||
|
||||
@@ -14,6 +16,7 @@ if TYPE_CHECKING:
|
||||
from .types import SendableEmbed as SendableEmbedPayload
|
||||
from .types import TextEmbed as TextEmbedPayload
|
||||
from .types import WebsiteEmbed as WebsiteEmbedPayload
|
||||
from .types import JanuaryImage, JanuaryVideo
|
||||
|
||||
__all__ = ("Embed", "WebsiteEmbed", "ImageEmbed", "TextEmbed", "NoneEmbed", "to_embed", "SendableEmbed")
|
||||
|
||||
@@ -21,43 +24,45 @@ class WebsiteEmbed:
|
||||
type = EmbedType.website
|
||||
|
||||
def __init__(self, embed: WebsiteEmbedPayload):
|
||||
self.url = embed.get("url")
|
||||
self.special = embed.get("special")
|
||||
self.title = embed.get("title")
|
||||
self.description = embed.get("description")
|
||||
self.image = embed.get("image")
|
||||
self.video = embed.get("video")
|
||||
self.site_name = embed.get("site_name")
|
||||
self.icon_url = embed.get("icon_url")
|
||||
self.colour = embed.get("colour")
|
||||
self.url: str | None = embed.get("url")
|
||||
self.special: WebsiteSpecial | None = embed.get("special")
|
||||
self.title: str | None = embed.get("title")
|
||||
self.description: str | None = embed.get("description")
|
||||
self.image: JanuaryImage | None = embed.get("image")
|
||||
self.video: JanuaryVideo | None = embed.get("video")
|
||||
self.site_name: str | None = embed.get("site_name")
|
||||
self.icon_url: str | None = embed.get("icon_url")
|
||||
self.colour: str | None = embed.get("colour")
|
||||
|
||||
class ImageEmbed:
|
||||
type = EmbedType.image
|
||||
type: EmbedType = EmbedType.image
|
||||
|
||||
def __init__(self, image: ImageEmbedPayload):
|
||||
self.url = image.get("url")
|
||||
self.width = image.get("width")
|
||||
self.height = image.get("height")
|
||||
self.size = image.get("size")
|
||||
self.url: str = image.get("url")
|
||||
self.width: int = image.get("width")
|
||||
self.height: int = image.get("height")
|
||||
self.size: str = image.get("size")
|
||||
|
||||
class TextEmbed:
|
||||
type = EmbedType.text
|
||||
type: EmbedType = EmbedType.text
|
||||
|
||||
def __init__(self, embed: TextEmbedPayload, state: State):
|
||||
self.icon_url = embed.get("icon_url")
|
||||
self.url = embed.get("url")
|
||||
self.title = embed.get("title")
|
||||
self.description = embed.get("description")
|
||||
self.icon_url: str | None = embed.get("icon_url")
|
||||
self.url: str | None = embed.get("url")
|
||||
self.title: str | None = embed.get("title")
|
||||
self.description: str | None = embed.get("description")
|
||||
|
||||
self.media: Asset | None
|
||||
|
||||
if media := embed.get("media"):
|
||||
self.media = Asset(media, state)
|
||||
else:
|
||||
self.media = None
|
||||
|
||||
self.colour = embed.get("colour")
|
||||
self.colour: str | None = embed.get("colour")
|
||||
|
||||
class NoneEmbed:
|
||||
type = EmbedType.none
|
||||
type: EmbedType = EmbedType.none
|
||||
|
||||
Embed = Union[WebsiteEmbed, ImageEmbed, TextEmbed, NoneEmbed]
|
||||
|
||||
|
||||
+9
-9
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .utils import Ulid
|
||||
|
||||
@@ -30,16 +30,16 @@ class Emoji(Ulid):
|
||||
The server id this emoji belongs to, if any
|
||||
"""
|
||||
def __init__(self, payload: EmojiPayload, state: State):
|
||||
self.state = state
|
||||
self.state: State = state
|
||||
|
||||
self.id = payload["_id"]
|
||||
self.author_id = payload["creator_id"]
|
||||
self.name = payload["name"]
|
||||
self.animated = payload.get("animated", False)
|
||||
self.nsfw = payload.get("nsfw", False)
|
||||
self.server_id: Optional[str] = payload["parent"].get("id")
|
||||
self.id: str = payload["_id"]
|
||||
self.author_id: str = payload["creator_id"]
|
||||
self.name: str = payload["name"]
|
||||
self.animated: bool = payload.get("animated", False)
|
||||
self.nsfw: bool = payload.get("nsfw", False)
|
||||
self.server_id: str | None = payload["parent"].get("id")
|
||||
|
||||
async def delete(self):
|
||||
async def delete(self) -> None:
|
||||
"""Deletes the emoji."""
|
||||
await self.state.http.delete_emoji(self.id)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Coroutine, TypeVar, Union, cast
|
||||
from typing import Any, Callable, Coroutine, Union, cast
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import revolt
|
||||
|
||||
@@ -12,11 +13,11 @@ from .utils import ClientT
|
||||
|
||||
__all__ = ("check", "Check", "is_bot_owner", "is_server_owner", "has_permissions", "has_channel_permissions")
|
||||
|
||||
T = TypeVar("T", Callable[..., Any], Command)
|
||||
T = TypeVar("T", Callable[..., Any], Command, default=Command)
|
||||
|
||||
Check = Callable[[Context[ClientT]], Union[Any, Coroutine[Any, Any, Any]]]
|
||||
|
||||
def check(check: Check[ClientT]):
|
||||
def check(check: Check[ClientT]) -> Callable[[T], T]:
|
||||
"""A decorator for adding command checks
|
||||
|
||||
Parameters
|
||||
@@ -37,7 +38,7 @@ def check(check: Check[ClientT]):
|
||||
|
||||
return inner
|
||||
|
||||
def is_bot_owner():
|
||||
def is_bot_owner() -> Callable[[T], T]:
|
||||
"""A command check for limiting the command to only the bot's owner"""
|
||||
@check
|
||||
def inner(context: Context[ClientT]):
|
||||
@@ -48,10 +49,10 @@ def is_bot_owner():
|
||||
|
||||
return inner
|
||||
|
||||
def is_server_owner():
|
||||
def is_server_owner() -> Callable[[T], T]:
|
||||
"""A command check for limiting the command to only a server's owner"""
|
||||
@check
|
||||
def inner(context: Context[ClientT]):
|
||||
def inner(context: Context[ClientT]) -> bool:
|
||||
if not context.server_id:
|
||||
raise ServerOnly
|
||||
|
||||
@@ -62,19 +63,21 @@ def is_server_owner():
|
||||
|
||||
return inner
|
||||
|
||||
def has_permissions(**permissions: bool):
|
||||
def has_permissions(**permissions: bool) -> Callable[[T], T]:
|
||||
@check
|
||||
def inner(context: Context[ClientT]):
|
||||
def inner(context: Context[ClientT]) -> bool:
|
||||
author = context.author
|
||||
|
||||
if not author.has_permissions(**permissions):
|
||||
raise MissingPermissionsError(permissions)
|
||||
|
||||
return True
|
||||
|
||||
return inner
|
||||
|
||||
def has_channel_permissions(**permissions: bool):
|
||||
def has_channel_permissions(**permissions: bool) -> Callable[[T], T]:
|
||||
@check
|
||||
def inner(context: Context[ClientT]):
|
||||
def inner(context: Context[ClientT]) -> bool:
|
||||
author = context.author
|
||||
|
||||
if not isinstance(author, revolt.Member):
|
||||
@@ -83,4 +86,6 @@ def has_channel_permissions(**permissions: bool):
|
||||
if not author.has_channel_permissions(context.channel, **permissions):
|
||||
raise MissingPermissionsError(permissions)
|
||||
|
||||
return True
|
||||
|
||||
return inner
|
||||
|
||||
@@ -36,7 +36,7 @@ class ExtensionProtocol(Protocol):
|
||||
class CommandsMeta(type):
|
||||
_commands: list[Command[Any]]
|
||||
|
||||
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]):
|
||||
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Self:
|
||||
commands: list[Command[Any]] = []
|
||||
self = super().__new__(cls, name, bases, attrs)
|
||||
for base in reversed(self.__mro__):
|
||||
@@ -139,7 +139,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
|
||||
"""
|
||||
return self.all_commands[name]
|
||||
|
||||
def add_command(self, command: Command[Self]):
|
||||
def add_command(self, command: Command[Self]) -> None:
|
||||
"""Adds a command, this is typically only used for dynamic commands, you should use the `commands.command` decorator for most usecases.
|
||||
|
||||
Parameters
|
||||
@@ -196,9 +196,6 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
|
||||
"""
|
||||
content = message.content
|
||||
|
||||
if not isinstance(content, str):
|
||||
return
|
||||
|
||||
prefixes = await self.get_prefix(message)
|
||||
|
||||
if isinstance(prefixes, str):
|
||||
@@ -248,7 +245,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
|
||||
await command._error_handler(command.cog or self, context, e)
|
||||
self.dispatch("command_error", context, e)
|
||||
|
||||
async def on_command_error(self, ctx: Context[Self], error: Exception, /):
|
||||
async def on_command_error(self, ctx: Context[Self], error: Exception, /) -> None:
|
||||
traceback.print_exception(type(error), error, error.__traceback__)
|
||||
|
||||
on_message = process_commands
|
||||
@@ -268,7 +265,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
|
||||
|
||||
return True
|
||||
|
||||
def add_cog(self, cog: Cog[Self]):
|
||||
def add_cog(self, cog: Cog[Self]) -> None:
|
||||
"""Adds a cog to the bot, this cog must subclass `Cog`.
|
||||
|
||||
Parameters
|
||||
@@ -296,7 +293,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
|
||||
|
||||
return cog
|
||||
|
||||
def load_extension(self, name: str):
|
||||
def load_extension(self, name: str) -> None:
|
||||
"""Loads an extension, this takes a module name and runs the setup function inside of it.
|
||||
|
||||
Parameters
|
||||
@@ -312,7 +309,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
|
||||
self.extensions[name] = extension
|
||||
extension.setup(self)
|
||||
|
||||
def unload_extension(self, name: str):
|
||||
def unload_extension(self, name: str) -> None:
|
||||
"""Unloads an extension, this takes a module name and runs the teardown function inside of it.
|
||||
|
||||
Parameters
|
||||
@@ -327,7 +324,7 @@ class CommandsClient(revolt.Client, metaclass=CommandsMeta):
|
||||
if teardown := getattr(extension, "teardown", None):
|
||||
teardown(self)
|
||||
|
||||
def reload_extension(self, name: str):
|
||||
def reload_extension(self, name: str) -> None:
|
||||
"""Reloads an extension, this will unload and reload the extension.
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Generic, Optional, cast
|
||||
from typing_extensions import Self
|
||||
|
||||
from .command import Command
|
||||
from .utils import ClientT
|
||||
@@ -11,7 +12,7 @@ class CogMeta(type, Generic[ClientT]):
|
||||
_commands: list[Command[ClientT]]
|
||||
qualified_name: str
|
||||
|
||||
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any], *, qualified_name: Optional[str] = None):
|
||||
def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any], *, qualified_name: Optional[str] = None) -> Self:
|
||||
commands: list[Command[ClientT]] = []
|
||||
self = super().__new__(cls, name, bases, attrs)
|
||||
|
||||
@@ -29,15 +30,15 @@ class Cog(Generic[ClientT], metaclass=CogMeta):
|
||||
_commands: list[Command[ClientT]]
|
||||
qualified_name: str
|
||||
|
||||
def cog_load(self):
|
||||
def cog_load(self) -> None:
|
||||
"""A special method that is called when the cog gets loaded."""
|
||||
pass
|
||||
|
||||
def cog_unload(self):
|
||||
def cog_unload(self) -> None:
|
||||
"""A special method that is called when the cog gets removed."""
|
||||
pass
|
||||
|
||||
def _inject(self, client: ClientT):
|
||||
def _inject(self, client: ClientT) -> None:
|
||||
client.cogs[self.qualified_name] = self
|
||||
|
||||
for command in self._commands:
|
||||
@@ -46,7 +47,7 @@ class Cog(Generic[ClientT], metaclass=CogMeta):
|
||||
|
||||
self.cog_load()
|
||||
|
||||
def _uninject(self, client: ClientT):
|
||||
def _uninject(self, client: ClientT) -> None:
|
||||
for name, command in client.all_commands.copy().items():
|
||||
if command in self._commands:
|
||||
del client.all_commands[name]
|
||||
|
||||
@@ -5,6 +5,7 @@ import traceback
|
||||
from contextlib import suppress
|
||||
from typing import (TYPE_CHECKING, Annotated, Any, Callable, Coroutine,
|
||||
Generic, Literal, Optional, Union, get_args, get_origin)
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from revolt.utils import copy_doc, maybe_coroutine
|
||||
|
||||
@@ -17,13 +18,13 @@ if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
from .group import Group
|
||||
|
||||
__all__ = (
|
||||
__all__: tuple[str, ...] = (
|
||||
"Command",
|
||||
"command"
|
||||
)
|
||||
|
||||
NoneType = type(None)
|
||||
|
||||
NoneType: type[None] = type(None)
|
||||
P = ParamSpec("P")
|
||||
|
||||
class Command(Generic[ClientCoT]):
|
||||
"""Class for holding info about a command.
|
||||
@@ -46,17 +47,17 @@ class Command(Generic[ClientCoT]):
|
||||
__slots__ = ("callback", "name", "aliases", "signature", "checks", "parent", "_error_handler", "cog", "description", "usage", "parameters")
|
||||
|
||||
def __init__(self, callback: Callable[..., Coroutine[Any, Any, Any]], name: str, aliases: list[str], usage: Optional[str] = None):
|
||||
self.callback = callback
|
||||
self.name = name
|
||||
self.aliases = aliases
|
||||
self.usage = usage
|
||||
self.signature = inspect.signature(self.callback)
|
||||
self.parameters = evaluate_parameters(self.signature.parameters.values(), getattr(callback, "__globals__", {}))
|
||||
self.callback: Callable[..., Coroutine[Any, Any, Any]] = callback
|
||||
self.name: str = name
|
||||
self.aliases: list[str] = aliases
|
||||
self.usage: str | None = usage
|
||||
self.signature: inspect.Signature = inspect.signature(self.callback)
|
||||
self.parameters: list[inspect.Parameter] = evaluate_parameters(self.signature.parameters.values(), getattr(callback, "__globals__", {}))
|
||||
self.checks: list[Check[ClientCoT]] = getattr(callback, "_checks", [])
|
||||
self.parent: Optional[Group[ClientCoT]] = None
|
||||
self.cog: Optional[Cog[ClientCoT]] = None
|
||||
self._error_handler: Callable[[Any, Context[ClientCoT], Exception], Coroutine[Any, Any, Any]] = type(self)._default_error_handler
|
||||
self.description = callback.__doc__
|
||||
self.description: str | None = callback.__doc__
|
||||
|
||||
async def invoke(self, context: Context[ClientCoT], *args: Any, **kwargs: Any) -> Any:
|
||||
"""Runs the command and calls the error handler if the command errors.
|
||||
@@ -77,7 +78,7 @@ class Command(Generic[ClientCoT]):
|
||||
def __call__(self, context: Context[ClientCoT], *args: Any, **kwargs: Any) -> Any:
|
||||
return self.invoke(context, *args, **kwargs)
|
||||
|
||||
def error(self, func: Callable[..., Coroutine[Any, Any, Any]]):
|
||||
def error(self, func: Callable[..., Coroutine[Any, Any, Any]]) -> Callable[..., Coroutine[Any, Any, Any]]:
|
||||
"""Sets the error handler for the command.
|
||||
|
||||
Parameters
|
||||
@@ -141,7 +142,7 @@ class Command(Generic[ClientCoT]):
|
||||
else:
|
||||
return arg
|
||||
|
||||
async def parse_arguments(self, context: Context[ClientCoT]):
|
||||
async def parse_arguments(self, context: Context[ClientCoT]) -> None:
|
||||
# please pr if you can think of a better way to do this
|
||||
|
||||
for parameter in self.parameters[2:]:
|
||||
@@ -214,8 +215,8 @@ class Command(Generic[ClientCoT]):
|
||||
|
||||
return f"{' '.join(parents[::-1])} {self.name} {' '.join(parameters)}"
|
||||
|
||||
def command(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientCoT]] = Command, usage: Optional[str] = None):
|
||||
"""A decorator that turns a function into a :class:`Command`.
|
||||
def command(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientCoT]] = Command, usage: Optional[str] = None) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Command[ClientCoT]]:
|
||||
"""A decorator that turns a function into a :class:`Command`.n
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
|
||||
@@ -11,6 +11,7 @@ from .utils import ClientCoT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .view import StringView
|
||||
from revolt.state import State
|
||||
|
||||
__all__ = (
|
||||
"Context",
|
||||
@@ -46,17 +47,17 @@ class Context(revolt.Messageable, Generic[ClientCoT]):
|
||||
return self.channel.id
|
||||
|
||||
def __init__(self, command: Optional[Command[ClientCoT]], invoked_with: str, view: StringView, message: revolt.Message, client: ClientCoT):
|
||||
self.command = command
|
||||
self.invoked_with = invoked_with
|
||||
self.view = view
|
||||
self.message = message
|
||||
self.client = client
|
||||
self.command: Command[ClientCoT] | None = command
|
||||
self.invoked_with: str = invoked_with
|
||||
self.view: StringView = view
|
||||
self.message: revolt.Message = message
|
||||
self.client: ClientCoT = client
|
||||
self.args: list[Any] = []
|
||||
self.kwargs: dict[str, Any] = {}
|
||||
self.server_id = message.server_id
|
||||
self.channel = message.channel
|
||||
self.author = message.author
|
||||
self.state = message.state
|
||||
self.server_id: str | None = message.server_id
|
||||
self.channel: revolt.TextChannel | revolt.GroupDMChannel | revolt.DMChannel = message.channel
|
||||
self.author: revolt.Member | revolt.User = message.author
|
||||
self.state: State = message.state
|
||||
|
||||
@property
|
||||
def server(self) -> revolt.Server:
|
||||
@@ -105,7 +106,7 @@ class Context(revolt.Messageable, Generic[ClientCoT]):
|
||||
|
||||
return all([await maybe_coroutine(check, self) for check in (command.checks if command else [])])
|
||||
|
||||
async def send_help(self, argument: Command[Any] | Group[Any] | ClientCoT | None = None):
|
||||
async def send_help(self, argument: Command[Any] | Group[Any] | ClientCoT | None = None) -> None:
|
||||
argument = argument or self.client
|
||||
|
||||
command = self.client.get_command("help")
|
||||
|
||||
@@ -13,14 +13,14 @@ from .errors import (BadBoolArgument, CategoryConverterError,
|
||||
if TYPE_CHECKING:
|
||||
from .client import CommandsClient
|
||||
|
||||
__all__ = ("bool_converter", "category_converter", "channel_converter", "user_converter", "member_converter", "IntConverter", "BoolConverter", "CategoryConverter", "UserConverter", "MemberConverter", "ChannelConverter")
|
||||
__all__: tuple[str, ...] = ("bool_converter", "category_converter", "channel_converter", "user_converter", "member_converter", "IntConverter", "BoolConverter", "CategoryConverter", "UserConverter", "MemberConverter", "ChannelConverter")
|
||||
|
||||
channel_regex = re.compile("<#([A-z0-9]{26})>")
|
||||
user_regex = re.compile("<@([A-z0-9]{26})>")
|
||||
channel_regex: re.Pattern[str] = re.compile("<#([A-z0-9]{26})>")
|
||||
user_regex: re.Pattern[str] = re.compile("<@([A-z0-9]{26})>")
|
||||
|
||||
ClientT = TypeVar("ClientT", bound="CommandsClient")
|
||||
|
||||
def bool_converter(arg: str, _):
|
||||
def bool_converter(arg: str, _: Context[ClientT]) -> bool:
|
||||
lowered = arg.lower()
|
||||
if lowered in ("yes", "true", "ye", "y", "1", "on", "enable"):
|
||||
return True
|
||||
|
||||
@@ -32,7 +32,7 @@ class CommandNotFound(CommandError):
|
||||
__slots__ = ("command_name",)
|
||||
|
||||
def __init__(self, command_name: str):
|
||||
self.command_name = command_name
|
||||
self.command_name: str = command_name
|
||||
|
||||
class NoClosingQuote(CommandError):
|
||||
"""Raised when there is no closing quote for a command argument"""
|
||||
|
||||
@@ -26,13 +26,13 @@ class Group(Command[ClientCoT]):
|
||||
The group's subcommands.
|
||||
"""
|
||||
|
||||
__slots__ = ("subcommands",)
|
||||
__slots__: tuple[str, ...] = ("subcommands",)
|
||||
|
||||
def __init__(self, callback: Callable[..., Coroutine[Any, Any, Any]], name: str, aliases: list[str]):
|
||||
self.subcommands: dict[str, Command[ClientCoT]] = {}
|
||||
super().__init__(callback, name, aliases)
|
||||
|
||||
def command(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientCoT]] = Command[ClientCoT]):
|
||||
def command(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientCoT]] = Command[ClientCoT]) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Command[ClientCoT]]:
|
||||
"""A decorator that turns a function into a :class:`Command` and registers the command as a subcommand.
|
||||
|
||||
Parameters
|
||||
@@ -57,7 +57,7 @@ class Group(Command[ClientCoT]):
|
||||
|
||||
return inner
|
||||
|
||||
def group(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: Optional[type[Group[ClientCoT]]] = None):
|
||||
def group(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: Optional[type[Group[ClientCoT]]] = None) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Group[ClientCoT]]:
|
||||
"""A decorator that turns a function into a :class:`Group` and registers the command as a subcommand
|
||||
|
||||
Parameters
|
||||
@@ -91,7 +91,7 @@ class Group(Command[ClientCoT]):
|
||||
def commands(self) -> list[Command[ClientCoT]]:
|
||||
return list(self.subcommands.values())
|
||||
|
||||
def group(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Group[ClientT]] = Group):
|
||||
def group(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Group[ClientT]] = Group) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Group[ClientT]]:
|
||||
"""A decorator that turns a function into a :class:`Group`
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -66,7 +66,7 @@ class HelpCommand(ABC, Generic[ClientCoT]):
|
||||
|
||||
return cogs
|
||||
|
||||
async def handle_message(self, context: Context[ClientCoT], message: Message):
|
||||
async def handle_message(self, context: Context[ClientCoT], message: Message) -> None:
|
||||
pass
|
||||
|
||||
async def get_channel(self, context: Context) -> Messageable:
|
||||
@@ -145,11 +145,11 @@ class DefaultHelpCommand(HelpCommand[ClientCoT]):
|
||||
lines.append("```")
|
||||
return "\n".join(lines)
|
||||
|
||||
async def handle_no_command_found(self, context: Context[ClientCoT], name: str):
|
||||
async def handle_no_command_found(self, context: Context[ClientCoT], name: str) -> None:
|
||||
channel = await self.get_channel(context)
|
||||
await channel.send(f"Command `{name}` not found.")
|
||||
|
||||
async def handle_no_cog_found(self, context: Context[ClientCoT], name: str):
|
||||
async def handle_no_cog_found(self, context: Context[ClientCoT], name: str) -> None:
|
||||
channel = await self.get_channel(context)
|
||||
await channel.send(f"Cog `{name}` not found.")
|
||||
|
||||
@@ -158,14 +158,14 @@ class HelpCommandImpl(Command[ClientCoT]):
|
||||
def __init__(self, client: ClientCoT):
|
||||
self.client = client
|
||||
|
||||
async def callback(_: Union[ClientCoT, Cog[ClientCoT]], context: Context[ClientCoT], *args: str):
|
||||
async def callback(_: Union[ClientCoT, Cog[ClientCoT]], context: Context[ClientCoT], *args: str) -> None:
|
||||
await help_command_impl(context.client, context, *args)
|
||||
|
||||
super().__init__(callback=callback, name="help", aliases=[])
|
||||
self.description = "Shows help for a command, cog or the entire bot"
|
||||
self.description: str | None = "Shows help for a command, cog or the entire bot"
|
||||
|
||||
|
||||
async def help_command_impl(self: ClientT, context: Context[ClientT], *arguments: str):
|
||||
async def help_command_impl(self: ClientT, context: Context[ClientT], *arguments: str) -> None:
|
||||
help_command = self.help_command
|
||||
|
||||
if not help_command:
|
||||
@@ -202,4 +202,5 @@ async def help_command_impl(self: ClientT, context: Context[ClientT], *arguments
|
||||
else:
|
||||
msg_payload = payload
|
||||
|
||||
await help_command.send_help_command(context, msg_payload)
|
||||
message = await help_command.send_help_command(context, msg_payload)
|
||||
await help_command.handle_message(context, message)
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from typing import Iterator
|
||||
from .errors import NoClosingQuote
|
||||
|
||||
|
||||
class StringView:
|
||||
def __init__(self, string: str):
|
||||
self.value = iter(string)
|
||||
self.temp = ""
|
||||
self.should_undo = False
|
||||
self.value: Iterator[str] = iter(string)
|
||||
self.temp: str = ""
|
||||
self.should_undo: bool = False
|
||||
|
||||
def undo(self):
|
||||
def undo(self) -> None:
|
||||
self.should_undo = True
|
||||
|
||||
def next_char(self) -> str:
|
||||
|
||||
+8
-4
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
__all__ = ("File",)
|
||||
|
||||
@@ -18,17 +20,19 @@ class File:
|
||||
__slots__ = ("f", "spoiler", "filename")
|
||||
|
||||
def __init__(self, file: Union[str, bytes], *, filename: Optional[str] = None, spoiler: bool = False):
|
||||
self.f: io.BufferedIOBase
|
||||
|
||||
if isinstance(file, str):
|
||||
self.f = open(file, "rb")
|
||||
else:
|
||||
self.f = io.BytesIO(file)
|
||||
|
||||
if filename is None and isinstance(file, str):
|
||||
filename = self.f.name
|
||||
filename = cast(Optional[str], self.f.name)
|
||||
|
||||
self.spoiler = spoiler or (filename and filename.startswith("SPOILER_"))
|
||||
self.spoiler: bool = spoiler or (bool(filename) and filename.startswith("SPOILER_"))
|
||||
|
||||
if self.spoiler and (filename and not filename.startswith("SPOILER_")):
|
||||
filename = f"SPOILER_{filename}"
|
||||
|
||||
self.filename = filename
|
||||
self.filename: str | None = filename
|
||||
|
||||
+5
-5
@@ -11,8 +11,8 @@ class Flag:
|
||||
__slots__ = ("flag", "__doc__")
|
||||
|
||||
def __init__(self, func: Callable[[], int]):
|
||||
self.flag = func()
|
||||
self.__doc__ = func.__doc__
|
||||
self.flag: int = func()
|
||||
self.__doc__: str | None = func.__doc__
|
||||
|
||||
@overload
|
||||
def __get__(self: Self, instance: None, owner: type[Flags]) -> Self:
|
||||
@@ -28,7 +28,7 @@ class Flag:
|
||||
|
||||
return instance._check_flag(self.flag)
|
||||
|
||||
def __set__(self, instance: Flags, value: bool):
|
||||
def __set__(self, instance: Flags, value: bool) -> None:
|
||||
instance._set_flag(self.flag, value)
|
||||
|
||||
class Flags:
|
||||
@@ -52,7 +52,7 @@ class Flags:
|
||||
def _check_flag(self, flag: int) -> bool:
|
||||
return (self.value & flag) == flag
|
||||
|
||||
def _set_flag(self, flag: int, value: bool):
|
||||
def _set_flag(self, flag: int, value: bool) -> None:
|
||||
if value:
|
||||
self.value |= flag
|
||||
else:
|
||||
@@ -85,7 +85,7 @@ class Flags:
|
||||
def __gt__(self, other: Self) -> bool:
|
||||
return self.value > other.value
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} value={self.value}>"
|
||||
|
||||
def __iter__(self) -> Iterator[tuple[str, bool]]:
|
||||
|
||||
+16
-16
@@ -46,11 +46,11 @@ class HttpClient:
|
||||
__slots__ = ("session", "token", "api_url", "api_info", "auth_header")
|
||||
|
||||
def __init__(self, session: aiohttp.ClientSession, token: str, api_url: str, api_info: ApiInfo, bot: bool = True):
|
||||
self.session = session
|
||||
self.token = token
|
||||
self.api_url = api_url
|
||||
self.api_info = api_info
|
||||
self.auth_header = "x-bot-token" if bot else "x-session-token"
|
||||
self.session: aiohttp.ClientSession = session
|
||||
self.token: str = token
|
||||
self.api_url: str = api_url
|
||||
self.api_info: ApiInfo = api_info
|
||||
self.auth_header: str = "x-bot-token" if bot else "x-session-token"
|
||||
|
||||
async def request(self, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"], route: str, *, json: Optional[dict[str, Any]] = None, nonce: bool = True, params: Optional[dict[str, Any]] = None) -> Any:
|
||||
url = f"{self.api_url}{route}"
|
||||
@@ -359,19 +359,19 @@ class HttpClient:
|
||||
def delete_invite(self, code: str) -> Request[None]:
|
||||
return self.request("DELETE", f"/invites/{code}")
|
||||
|
||||
def edit_channel(self, channel_id: str, remove: list[str] | None, values: dict[str, Any]):
|
||||
def edit_channel(self, channel_id: str, remove: list[str] | None, values: dict[str, Any]) -> Request[None]:
|
||||
if remove:
|
||||
values["remove"] = remove
|
||||
|
||||
return self.request("PATCH", f"/channels/{channel_id}", json=values)
|
||||
|
||||
def edit_role(self, server_id: str, role_id: str, remove: list[str] | None, values: dict[str, Any]):
|
||||
def edit_role(self, server_id: str, role_id: str, remove: list[str] | None, values: dict[str, Any]) -> Request[None]:
|
||||
if remove:
|
||||
values["remove"] = remove
|
||||
|
||||
return self.request("PATCH", f"/servers/{server_id}/roles/{role_id}", json=values)
|
||||
|
||||
async def edit_self(self, remove: list[str] | None, values: dict[str, Any]):
|
||||
async def edit_self(self, remove: list[str] | None, values: dict[str, Any]) -> Request[None]:
|
||||
if remove:
|
||||
values["remove"] = remove
|
||||
|
||||
@@ -392,25 +392,25 @@ class HttpClient:
|
||||
def set_guild_channel_role_permissions(self, channel_id: str, role_id: str, allow: int, deny: int) -> Request[None]:
|
||||
return self.request("PUT", f"/channels/{channel_id}/permissions/{role_id}", json={"permissions": {"allow": allow, "deny": deny}})
|
||||
|
||||
def set_group_channel_default_permissions(self, channel_id: str, value: int):
|
||||
def set_group_channel_default_permissions(self, channel_id: str, value: int) -> Request[None]:
|
||||
return self.request("PUT", f"/channels/{channel_id}/permissions/default", json={"permissions": value})
|
||||
|
||||
def set_server_role_permissions(self, server_id: str, role_id: str, allow: int, deny: int):
|
||||
def set_server_role_permissions(self, server_id: str, role_id: str, allow: int, deny: int) -> Request[None]:
|
||||
return self.request("PUT", f"/servers/{server_id}/permissions/{role_id}", json={"permissions": {"allow": allow, "deny": deny}})
|
||||
|
||||
def set_server_default_permissions(self, server_id: str, value: int):
|
||||
def set_server_default_permissions(self, server_id: str, value: int) -> Request[None]:
|
||||
return self.request("PUT", f"/servers/{server_id}/permissions/default", json={"permissions": value})
|
||||
|
||||
def add_reaction(self, channel_id: str, message_id: str, emoji: str):
|
||||
def add_reaction(self, channel_id: str, message_id: str, emoji: str) -> Request[None]:
|
||||
return self.request("PUT", f"/channels/{channel_id}/messages/{message_id}/reactions/{emoji}")
|
||||
|
||||
def remove_reaction(self, channel_id: str, message_id: str, emoji: str, user_id: Optional[str], remove_all: bool):
|
||||
def remove_reaction(self, channel_id: str, message_id: str, emoji: str, user_id: Optional[str], remove_all: bool) -> Request[None]:
|
||||
return self.request("PUT", f"/channels/{channel_id}/messages/{message_id}/reactions/{emoji}")
|
||||
|
||||
def remove_all_reactions(self, channel_id: str, message_id: str):
|
||||
def remove_all_reactions(self, channel_id: str, message_id: str) -> Request[None]:
|
||||
return self.request("DELETE", f"/channels/{channel_id}/messages/{message_id}/reactions")
|
||||
|
||||
def delete_emoji(self, emoji_id: str):
|
||||
def delete_emoji(self, emoji_id: str) -> Request[None]:
|
||||
return self.request("DELETE", f"/custom/emoji/{emoji_id}")
|
||||
|
||||
def fetch_emoji(self, emoji_id: str) -> Request[EmojiPayload]:
|
||||
@@ -424,5 +424,5 @@ class HttpClient:
|
||||
def edit_member(self, server_id: str, member_id: str, remove: list[str] | None, values: dict[str, Any]) -> Request[MemberPayload]:
|
||||
return self.request("PATCH", f"/servers/{server_id}/members/{member_id}", json={"remove": remove, **values})
|
||||
|
||||
def delete_messages(self, channel_id: str, messages: list[str]):
|
||||
def delete_messages(self, channel_id: str, messages: list[str]) -> Request[None]:
|
||||
return self.request("DELETE", f"/channels/{channel_id}/messages/bulk", json={"ids": messages})
|
||||
|
||||
+16
-9
@@ -2,12 +2,17 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
from .asset import Asset
|
||||
from .utils import Ulid
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .state import State
|
||||
from .channel import Channel
|
||||
from .server import Server
|
||||
from .types import Invite as InvitePayload
|
||||
from .user import User
|
||||
|
||||
|
||||
__all__ = ("Invite",)
|
||||
|
||||
@@ -37,22 +42,24 @@ class Invite(Ulid):
|
||||
__slots__ = ("state", "code", "id", "server", "channel", "user_name", "user_avatar", "user", "member_count")
|
||||
|
||||
def __init__(self, data: InvitePayload, code: str, state: State):
|
||||
self.state = state
|
||||
self.state: State = state
|
||||
|
||||
self.code = code
|
||||
self.id = code
|
||||
self.server = state.get_server(data["server_id"])
|
||||
self.channel = self.server.get_channel(data["channel_id"])
|
||||
self.code: str = code
|
||||
self.id: str = code
|
||||
self.server: Server = state.get_server(data["server_id"])
|
||||
self.channel: Channel = self.server.get_channel(data["channel_id"])
|
||||
|
||||
self.user_name = data["user_name"]
|
||||
self.user = None
|
||||
self.user_name: str = data["user_name"]
|
||||
self.user: User | None = None
|
||||
|
||||
self.user_avatar: Asset | None
|
||||
|
||||
if avatar := data.get("user_avatar"):
|
||||
self.user_avatar = Asset(avatar, state)
|
||||
else:
|
||||
self.user_avatar = None
|
||||
|
||||
self.member_count = data["member_count"]
|
||||
self.member_count: int = data["member_count"]
|
||||
|
||||
@staticmethod
|
||||
def _from_partial(code: str, server: str, creator: str, channel: str, state: State) -> Invite:
|
||||
@@ -69,6 +76,6 @@ class Invite(Ulid):
|
||||
|
||||
return invite
|
||||
|
||||
async def delete(self):
|
||||
async def delete(self) -> None:
|
||||
"""Deletes the invite"""
|
||||
await self.state.http.delete_invite(self.code)
|
||||
|
||||
+19
-14
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
|
||||
|
||||
__all__ = ("Member",)
|
||||
|
||||
def flattern_user(member: Member, user: User):
|
||||
def flattern_user(member: Member, user: User) -> None:
|
||||
for attr in user.__flattern_attributes__:
|
||||
setattr(member, attr, getattr(user, attr))
|
||||
|
||||
@@ -49,7 +49,9 @@ class Member(User):
|
||||
flattern_user(self, user)
|
||||
user._members[server.id] = self
|
||||
|
||||
self.state = state
|
||||
self.state: State = state
|
||||
|
||||
self.guild_avatar: Asset | None
|
||||
|
||||
if avatar := data.get("avatar"):
|
||||
self.guild_avatar = Asset(avatar, state)
|
||||
@@ -57,20 +59,23 @@ class Member(User):
|
||||
self.guild_avatar = None
|
||||
|
||||
roles = [server.get_role(role_id) for role_id in data.get("roles", [])]
|
||||
self.roles = sorted(roles, key=lambda role: role.rank, reverse=True)
|
||||
self.roles: list[Role] = sorted(roles, key=lambda role: role.rank, reverse=True)
|
||||
|
||||
self.server = server
|
||||
self.nickname = data.get("nickname")
|
||||
self.server: Server = server
|
||||
self.nickname: str | None = data.get("nickname")
|
||||
joined_at = data["joined_at"]
|
||||
|
||||
if isinstance(joined_at, int):
|
||||
self.joined_at = datetime.datetime.fromtimestamp(joined_at / 1000)
|
||||
self.joined_at: datetime.datetime = datetime.datetime.fromtimestamp(joined_at / 1000)
|
||||
else:
|
||||
self.joined_at = datetime.datetime.strptime(joined_at, "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
self.current_timeout = None
|
||||
self.joined_at: datetime.datetime = datetime.datetime.strptime(joined_at, "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
|
||||
self.current_timeout: datetime.datetime | None
|
||||
|
||||
if current_timeout := data.get("timeout"):
|
||||
self.current_timeout = datetime.datetime.strptime(current_timeout, "%Y-%m-%dT%H:%M:%S.%f%z")
|
||||
else:
|
||||
self.current_timeout = None
|
||||
|
||||
@property
|
||||
def avatar(self) -> Optional[Asset]:
|
||||
@@ -93,11 +98,11 @@ class Member(User):
|
||||
member_roles = [self.server.get_role(role_id) for role_id in roles]
|
||||
self.roles = sorted(member_roles, key=lambda role: role.rank, reverse=True)
|
||||
|
||||
async def kick(self):
|
||||
async def kick(self) -> None:
|
||||
"""Kicks the member from the server"""
|
||||
await self.state.http.kick_member(self.server.id, self.id)
|
||||
|
||||
async def ban(self, *, reason: Optional[str] = None):
|
||||
async def ban(self, *, reason: Optional[str] = None) -> None:
|
||||
"""Bans the member from the server
|
||||
|
||||
Parameters
|
||||
@@ -107,7 +112,7 @@ class Member(User):
|
||||
"""
|
||||
await self.state.http.ban_member(self.server.id, self.id, reason)
|
||||
|
||||
async def unban(self):
|
||||
async def unban(self) -> None:
|
||||
"""Unbans the member from the server"""
|
||||
await self.state.http.unban_member(self.server.id, self.id)
|
||||
|
||||
@@ -118,7 +123,7 @@ class Member(User):
|
||||
roles: list[Role] | None | _Missing = Missing,
|
||||
avatar: File | None | _Missing = Missing,
|
||||
timeout: datetime.timedelta | None | _Missing = Missing
|
||||
):
|
||||
) -> None:
|
||||
remove: list[str] = []
|
||||
data: dict[str, Any] = {}
|
||||
|
||||
@@ -148,7 +153,7 @@ class Member(User):
|
||||
|
||||
await self.state.http.edit_member(self.server.id, self.id, remove, data)
|
||||
|
||||
async def timeout(self, length: datetime.timedelta):
|
||||
async def timeout(self, length: datetime.timedelta) -> None:
|
||||
"""Timeouts the member
|
||||
|
||||
Parameters
|
||||
@@ -170,7 +175,7 @@ class Member(User):
|
||||
"""
|
||||
return calculate_permissions(self, self.server)
|
||||
|
||||
def get_channel_permissions(self, channel: Channel):
|
||||
def get_channel_permissions(self, channel: Channel) -> Permissions:
|
||||
"""Gets the permissions for the member in the server taking into account the channel as well
|
||||
|
||||
Parameters
|
||||
|
||||
+39
-24
@@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Union
|
||||
|
||||
from revolt.types.message import SystemMessageContent
|
||||
|
||||
from .asset import Asset, PartialAsset
|
||||
from .channel import Messageable
|
||||
from .embed import SendableEmbed, to_embed
|
||||
from .channel import DMChannel, GroupDMChannel, TextChannel
|
||||
from .embed import Embed, SendableEmbed, to_embed
|
||||
from .utils import Ulid
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -17,6 +19,7 @@ if TYPE_CHECKING:
|
||||
from .types import Message as MessagePayload
|
||||
from .types import MessageReplyPayload
|
||||
from .user import User
|
||||
from .member import Member
|
||||
|
||||
__all__ = (
|
||||
"Message",
|
||||
@@ -58,18 +61,28 @@ class Message(Ulid):
|
||||
__slots__ = ("state", "id", "content", "attachments", "embeds", "channel", "author", "edited_at", "mentions", "replies", "reply_ids", "reactions", "interactions")
|
||||
|
||||
def __init__(self, data: MessagePayload, state: State):
|
||||
self.state = state
|
||||
self.state: State = state
|
||||
|
||||
self.id = data["_id"]
|
||||
self.content = data.get("content", "")
|
||||
self.attachments = [Asset(attachment, state) for attachment in data.get("attachments", [])]
|
||||
self.embeds = [to_embed(embed, state) for embed in data.get("embeds", [])]
|
||||
self.id: str = data["_id"]
|
||||
|
||||
content = data.get("content", "")
|
||||
|
||||
if not isinstance(content, str):
|
||||
self.system_content: SystemMessageContent = content
|
||||
self.content: str = ""
|
||||
else:
|
||||
self.content = content
|
||||
|
||||
self.attachments: list[Asset] = [Asset(attachment, state) for attachment in data.get("attachments", [])]
|
||||
self.embeds: list[Embed] = [to_embed(embed, state) for embed in data.get("embeds", [])]
|
||||
|
||||
channel = state.get_channel(data["channel"])
|
||||
assert isinstance(channel, Messageable)
|
||||
self.channel = channel
|
||||
assert isinstance(channel, Union[TextChannel, GroupDMChannel, DMChannel])
|
||||
self.channel: TextChannel | GroupDMChannel | DMChannel = channel
|
||||
|
||||
self.server_id = self.channel.server_id
|
||||
self.server_id: str | None = self.channel.server_id
|
||||
|
||||
self.mentions: list[Member | User]
|
||||
|
||||
if self.server_id:
|
||||
author = state.get_member(self.server_id, data["author"])
|
||||
@@ -78,7 +91,7 @@ class Message(Ulid):
|
||||
author = state.get_user(data["author"])
|
||||
self.mentions = [state.get_user(member_id) for member_id in data.get("mentions", [])]
|
||||
|
||||
self.author = author
|
||||
self.author: Member | User = author
|
||||
|
||||
if masquerade := data.get("masquerade"):
|
||||
if name := masquerade.get("name"):
|
||||
@@ -109,6 +122,8 @@ class Message(Ulid):
|
||||
for emoji, users in reactions.items():
|
||||
self.reactions[emoji] = [self.state.get_user(user_id) for user_id in users]
|
||||
|
||||
self.interactions: MessageInteractions | None
|
||||
|
||||
if interactions := data.get("interactions"):
|
||||
self.interactions = MessageInteractions(reactions=interactions.get("reactions"), restrict_reactions=interactions.get("restrict_reactions", False))
|
||||
else:
|
||||
@@ -143,7 +158,7 @@ class Message(Ulid):
|
||||
"""Deletes the message. The bot can only delete its own messages and messages it has permission to delete """
|
||||
await self.state.http.delete_message(self.channel.id, self.id)
|
||||
|
||||
def reply(self, *args: Any, mention: bool = False, **kwargs: Any):
|
||||
def reply(self, *args: Any, mention: bool = False, **kwargs: Any) -> Coroutine[Any, Any, Message]:
|
||||
"""Replies to this message, equivilant to:
|
||||
|
||||
.. code-block:: python
|
||||
@@ -153,13 +168,13 @@ class Message(Ulid):
|
||||
"""
|
||||
return self.channel.send(*args, **kwargs, replies=[MessageReply(self, mention)])
|
||||
|
||||
async def add_reaction(self, emoji: str):
|
||||
async def add_reaction(self, emoji: str) -> None:
|
||||
await self.state.http.add_reaction(self.channel.id, self.id, emoji)
|
||||
|
||||
async def remove_reaction(self, emoji: str, user: Optional[User] = None, remove_all: bool = False):
|
||||
async def remove_reaction(self, emoji: str, user: Optional[User] = None, remove_all: bool = False) -> None:
|
||||
await self.state.http.remove_reaction(self.channel.id, self.id, emoji, user.id if user else None, remove_all)
|
||||
|
||||
async def remove_all_reactions(self):
|
||||
async def remove_all_reactions(self) -> None:
|
||||
await self.state.http.remove_all_reactions(self.channel.id, self.id)
|
||||
|
||||
|
||||
@@ -187,8 +202,8 @@ class MessageReply:
|
||||
__slots__ = ("message", "mention")
|
||||
|
||||
def __init__(self, message: Message, mention: bool = False):
|
||||
self.message = message
|
||||
self.mention = mention
|
||||
self.message: Message = message
|
||||
self.mention: bool = mention
|
||||
|
||||
def to_dict(self) -> MessageReplyPayload:
|
||||
return { "id": self.message.id, "mention": self.mention }
|
||||
@@ -208,9 +223,9 @@ class Masquerade:
|
||||
__slots__ = ("name", "avatar", "colour")
|
||||
|
||||
def __init__(self, name: Optional[str] = None, avatar: Optional[str] = None, colour: Optional[str] = None):
|
||||
self.name = name
|
||||
self.avatar = avatar
|
||||
self.colour = colour
|
||||
self.name: str | None = name
|
||||
self.avatar: str | None = avatar
|
||||
self.colour: str | None = colour
|
||||
|
||||
def to_dict(self) -> MasqueradePayload:
|
||||
output: MasqueradePayload = {}
|
||||
@@ -239,10 +254,10 @@ class MessageInteractions:
|
||||
__slots__ = ("reactions", "restrict_reactions")
|
||||
|
||||
def __init__(self, *, reactions: Optional[list[str]] = None, restrict_reactions: bool = False):
|
||||
self.reactions = reactions
|
||||
self.restrict_reactions = restrict_reactions
|
||||
self.reactions: list[str] | None = reactions
|
||||
self.restrict_reactions: bool = restrict_reactions
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> InteractionsPayload:
|
||||
output: InteractionsPayload = {}
|
||||
|
||||
if reactions := self.reactions:
|
||||
|
||||
@@ -138,7 +138,7 @@ class Messageable:
|
||||
payloads = await self.state.http.search_messages(await self._get_channel_id(), query, sort=sort, limit=limit, before=before, after=after)
|
||||
return [Message(payload, self.state) for payload in payloads]
|
||||
|
||||
async def delete_messages(self, messages: list[Message]):
|
||||
async def delete_messages(self, messages: list[Message]) -> None:
|
||||
"""Bulk deletes messages from the channel
|
||||
|
||||
.. note:: The messages must have been sent in the last 7 days.
|
||||
|
||||
@@ -176,7 +176,7 @@ class PermissionsOverwrite:
|
||||
|
||||
super().__setattr__(perm, value)
|
||||
|
||||
def __setattr__(self, key: str, value: Any):
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
if key in Permissions.FLAG_NAMES:
|
||||
if key is True:
|
||||
setattr(self._allow, key, True)
|
||||
|
||||
+13
-13
@@ -35,20 +35,20 @@ class Role(Ulid):
|
||||
channel_permissions: :class:`ChannelPermissions`
|
||||
The channel permissions for the role
|
||||
"""
|
||||
__slots__ = ("id", "name", "colour", "hoist", "rank", "state", "server", "permissions")
|
||||
__slots__: tuple[str, ...] = ("id", "name", "colour", "hoist", "rank", "state", "server", "permissions")
|
||||
|
||||
def __init__(self, data: RolePayload, role_id: str, server: Server, state: State):
|
||||
self.state = state
|
||||
self.id = role_id
|
||||
self.name = data["name"]
|
||||
self.colour = data.get("colour", None)
|
||||
self.hoist = False
|
||||
self.rank = 0
|
||||
self.server = server
|
||||
self.permissions = PermissionsOverwrite._from_overwrite(data.get("permissions", {"a": 0, "d": 0}))
|
||||
self.state: State = state
|
||||
self.id: str = role_id
|
||||
self.name: str = data["name"]
|
||||
self.colour: str | None = data.get("colour", None)
|
||||
self.hoist: bool = data.get("hoist", False)
|
||||
self.rank: int = data["rank"]
|
||||
self.server: Server = server
|
||||
self.permissions: PermissionsOverwrite = PermissionsOverwrite._from_overwrite(data.get("permissions", {"a": 0, "d": 0}))
|
||||
|
||||
@property
|
||||
def color(self):
|
||||
def color(self) -> str | None:
|
||||
return self.colour
|
||||
|
||||
async def set_permissions_overwrite(self, *, permissions: PermissionsOverwrite) -> None:
|
||||
@@ -63,7 +63,7 @@ class Role(Ulid):
|
||||
allow, deny = permissions.to_pair()
|
||||
await self.state.http.set_server_role_permissions(self.server.id, self.id, allow.value, deny.value)
|
||||
|
||||
def _update(self, *, name: Optional[str] = None, colour: Optional[str] = None, hoist: Optional[bool] = None, rank: Optional[int] = None, permissions: Optional[Overwrite] = None):
|
||||
def _update(self, *, name: Optional[str] = None, colour: Optional[str] = None, hoist: Optional[bool] = None, rank: Optional[int] = None, permissions: Optional[Overwrite] = None) -> None:
|
||||
if name is not None:
|
||||
self.name = name
|
||||
|
||||
@@ -79,11 +79,11 @@ class Role(Ulid):
|
||||
if permissions is not None:
|
||||
self.permissions = PermissionsOverwrite._from_overwrite(permissions)
|
||||
|
||||
async def delete(self):
|
||||
async def delete(self) -> None:
|
||||
"""Deletes the role"""
|
||||
await self.state.http.delete_role(self.server.id, self.id)
|
||||
|
||||
async def edit(self, **kwargs: Any):
|
||||
async def edit(self, **kwargs: Any) -> None:
|
||||
"""Edits the role
|
||||
|
||||
Parameters
|
||||
|
||||
+26
-22
@@ -25,11 +25,11 @@ __all__ = ("Server", "SystemMessages", "ServerBan")
|
||||
|
||||
class SystemMessages:
|
||||
def __init__(self, data: SystemMessagesConfig, state: State):
|
||||
self.state = state
|
||||
self.user_joined_id = data.get("user_joined")
|
||||
self.user_left_id = data.get("user_left")
|
||||
self.user_kicked_id = data.get("user_kicked")
|
||||
self.user_banned_id = data.get("user_banned")
|
||||
self.state: State = state
|
||||
self.user_joined_id: str | None = data.get("user_joined")
|
||||
self.user_left_id: str | None = data.get("user_left")
|
||||
self.user_kicked_id: str | None = data.get("user_kicked")
|
||||
self.user_banned_id: str | None = data.get("user_banned")
|
||||
|
||||
@property
|
||||
def user_joined(self) -> Optional[TextChannel]:
|
||||
@@ -94,21 +94,25 @@ class Server(Ulid):
|
||||
__slots__ = ("state", "id", "name", "owner_id", "default_permissions", "_members", "_roles", "_channels", "description", "icon", "banner", "nsfw", "system_messages", "_categories", "_emojis")
|
||||
|
||||
def __init__(self, data: ServerPayload, state: State):
|
||||
self.state = state
|
||||
self.id = data["_id"]
|
||||
self.name = data["name"]
|
||||
self.owner_id = data["owner"]
|
||||
self.description = data.get("description") or None
|
||||
self.nsfw = data.get("nsfw", False)
|
||||
self.system_messages = SystemMessages(data.get("system_messages", cast("SystemMessagesConfig", {})), state)
|
||||
self._categories = {data["id"]: Category(data, state) for data in data.get("categories", [])}
|
||||
self.default_permissions = Permissions(data["default_permissions"])
|
||||
self.state: State = state
|
||||
self.id: str = data["_id"]
|
||||
self.name: str = data["name"]
|
||||
self.owner_id: str = data["owner"]
|
||||
self.description: str | None = data.get("description") or None
|
||||
self.nsfw: bool = data.get("nsfw", False)
|
||||
self.system_messages: SystemMessages = SystemMessages(data.get("system_messages", cast("SystemMessagesConfig", {})), state)
|
||||
self._categories: dict[str, Category] = {data["id"]: Category(data, state) for data in data.get("categories", [])}
|
||||
self.default_permissions: Permissions = Permissions(data["default_permissions"])
|
||||
|
||||
self.icon: Asset | None
|
||||
|
||||
if icon := data.get("icon"):
|
||||
self.icon = Asset(icon, state)
|
||||
else:
|
||||
self.icon = None
|
||||
|
||||
self.banner: Asset | None
|
||||
|
||||
if banner := data.get("banner"):
|
||||
self.banner = Asset(banner, state)
|
||||
else:
|
||||
@@ -279,11 +283,11 @@ class Server(Ulid):
|
||||
|
||||
await self.state.http.set_server_default_permissions(self.id, permissions.value)
|
||||
|
||||
async def leave_server(self):
|
||||
async def leave_server(self) -> None:
|
||||
"""Leaves or deletes the server"""
|
||||
await self.state.http.delete_leave_server(self.id)
|
||||
|
||||
async def delete_server(self):
|
||||
async def delete_server(self) -> None:
|
||||
"""Leaves or deletes a server, alias to :meth`Server.leave_server`"""
|
||||
await self.leave_server()
|
||||
|
||||
@@ -388,7 +392,7 @@ class Server(Ulid):
|
||||
|
||||
return Role(payload, name, self, self.state)
|
||||
|
||||
async def create_emoji(self, name: str, file: File, *, nsfw: bool = False):
|
||||
async def create_emoji(self, name: str, file: File, *, nsfw: bool = False) -> Emoji:
|
||||
"""Creates an emoji
|
||||
|
||||
Parameters
|
||||
@@ -421,11 +425,11 @@ class ServerBan:
|
||||
__slots__ = ("reason", "server", "user_id", "state")
|
||||
|
||||
def __init__(self, ban: Ban, state: State):
|
||||
self.reason = ban.get("reason")
|
||||
self.server = state.get_server(ban["_id"]["server"])
|
||||
self.user_id = ban["_id"]["user"]
|
||||
self.state = state
|
||||
self.reason: str | None = ban.get("reason")
|
||||
self.server: Server = state.get_server(ban["_id"]["server"])
|
||||
self.user_id: str = ban["_id"]["user"]
|
||||
self.state: State = state
|
||||
|
||||
async def unban(self):
|
||||
async def unban(self) -> None:
|
||||
"""Unbans the user"""
|
||||
await self.state.http.unban_member(self.server.id, self.user_id)
|
||||
|
||||
+5
-5
@@ -26,9 +26,9 @@ class State:
|
||||
__slots__ = ("http", "api_info", "max_messages", "users", "channels", "servers", "messages", "global_emojis", "user_id", "me")
|
||||
|
||||
def __init__(self, http: HttpClient, api_info: ApiInfo, max_messages: int):
|
||||
self.http = http
|
||||
self.api_info = api_info
|
||||
self.max_messages = max_messages
|
||||
self.http: HttpClient = http
|
||||
self.api_info: ApiInfo = api_info
|
||||
self.max_messages: int = max_messages
|
||||
|
||||
self.me: User
|
||||
|
||||
@@ -114,7 +114,7 @@ class State:
|
||||
|
||||
raise LookupError
|
||||
|
||||
async def fetch_server_members(self, server_id: str):
|
||||
async def fetch_server_members(self, server_id: str) -> None:
|
||||
data = await self.http.fetch_members(server_id)
|
||||
|
||||
for user in data["users"]:
|
||||
@@ -123,6 +123,6 @@ class State:
|
||||
for member in data["members"]:
|
||||
self.add_member(server_id, member)
|
||||
|
||||
async def fetch_all_server_members(self):
|
||||
async def fetch_all_server_members(self) -> None:
|
||||
for server_id in self.servers:
|
||||
await self.fetch_server_members(server_id)
|
||||
|
||||
@@ -65,11 +65,13 @@ class Interactions(TypedDict):
|
||||
reactions: NotRequired[list[str]]
|
||||
restrict_reactions: NotRequired[bool]
|
||||
|
||||
SystemMessageContent = Union[UserAddContent, UserRemoveContent, UserJoinedContent, UserLeftContent, UserKickedContent, UserBannedContent, ChannelRenameContent, ChannelDescriptionChangeContent, ChannelIconChangeContent]
|
||||
|
||||
class Message(TypedDict):
|
||||
_id: str
|
||||
channel: str
|
||||
author: str
|
||||
content: Union[str, UserAddContent, UserRemoveContent, UserJoinedContent, UserLeftContent, UserKickedContent, UserBannedContent, ChannelRenameContent, ChannelDescriptionChangeContent, ChannelIconChangeContent]
|
||||
content: Union[str, SystemMessageContent]
|
||||
attachments: NotRequired[list[File]]
|
||||
embeds: NotRequired[list[Embed]]
|
||||
mentions: NotRequired[list[str]]
|
||||
|
||||
+18
-12
@@ -65,17 +65,21 @@ class User(Messageable, Ulid):
|
||||
privileged: :class:`bool`
|
||||
Whether the user is privileged
|
||||
"""
|
||||
__flattern_attributes__ = ("id", "bot", "owner_id", "badges", "online", "flags", "relations", "relationship", "status", "masquerade_avatar", "masquerade_name", "original_name", "original_avatar", "profile", "dm_channel", "privileged")
|
||||
__slots__ = (*__flattern_attributes__, "state", "_members")
|
||||
__flattern_attributes__: tuple[str, ...] = ("id", "bot", "owner_id", "badges", "online", "flags", "relations", "relationship", "status", "masquerade_avatar", "masquerade_name", "original_name", "original_avatar", "profile", "dm_channel", "privileged")
|
||||
__slots__: tuple[str, ...] = (*__flattern_attributes__, "state", "_members")
|
||||
|
||||
def __init__(self, data: UserPayload, state: State):
|
||||
self.state = state
|
||||
self._members: WeakValueDictionary[str, Member] = WeakValueDictionary() # we store all member versions of this user to avoid having to check every guild when needing to update.
|
||||
self.id = data["_id"]
|
||||
self.original_name = data["username"]
|
||||
self.dm_channel = None
|
||||
self.id: str = data["_id"]
|
||||
self.original_name: str = data["username"]
|
||||
self.dm_channel: DMChannel | None = None
|
||||
|
||||
bot = data.get("bot")
|
||||
|
||||
self.bot: bool
|
||||
self.owner_id: str | None
|
||||
|
||||
if bot:
|
||||
self.bot = True
|
||||
self.owner_id = bot["owner"]
|
||||
@@ -83,13 +87,13 @@ class User(Messageable, Ulid):
|
||||
self.bot = False
|
||||
self.owner_id = None
|
||||
|
||||
self.badges = UserBadges._from_value(data.get("badges", 0))
|
||||
self.online = data.get("online", False)
|
||||
self.flags = data.get("flags", 0)
|
||||
self.privileged = data.get("privileged", False)
|
||||
self.badges: UserBadges = UserBadges._from_value(data.get("badges", 0))
|
||||
self.online: bool = data.get("online", False)
|
||||
self.flags: int = data.get("flags", 0)
|
||||
self.privileged: bool = data.get("privileged", False)
|
||||
|
||||
avatar = data.get("avatar")
|
||||
self.original_avatar = Asset(avatar, state) if avatar else None
|
||||
self.original_avatar: Asset | None = Asset(avatar, state) if avatar else None
|
||||
|
||||
relations: list[Relation] = []
|
||||
|
||||
@@ -97,12 +101,14 @@ class User(Messageable, Ulid):
|
||||
user = state.get_user(relation["_id"])
|
||||
if user:
|
||||
relations.append(Relation(RelationshipType(relation["status"]), user))
|
||||
self.relations = relations
|
||||
self.relations: list[Relation] = relations
|
||||
|
||||
relationship = data.get("relationship")
|
||||
self.relationship = RelationshipType(relationship) if relationship else None
|
||||
self.relationship: RelationshipType | None = RelationshipType(relationship) if relationship else None
|
||||
|
||||
status = data.get("status")
|
||||
self.status: Status | None
|
||||
|
||||
if status:
|
||||
presence = status.get("presence")
|
||||
self.status = Status(status.get("text"), PresenceType(presence) if presence else None) if status else None
|
||||
|
||||
+2
-2
@@ -11,13 +11,13 @@ from typing_extensions import ParamSpec
|
||||
__all__ = ("Missing", "copy_doc", "maybe_coroutine", "get", "client_session")
|
||||
|
||||
class _Missing:
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "<Missing>"
|
||||
|
||||
def __bool__(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
Missing = _Missing()
|
||||
Missing: _Missing = _Missing()
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
+43
-41
@@ -27,7 +27,7 @@ from .types import (ServerCreateEventPayload, ServerDeleteEventPayload,
|
||||
ServerRoleDeleteEventPayload, ServerRoleUpdateEventPayload,
|
||||
ServerUpdateEventPayload, UserRelationshipEventPayload,
|
||||
UserUpdateEventPayload)
|
||||
from .user import Status, UserProfile
|
||||
from .user import Status, User, UserProfile
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -36,6 +36,8 @@ try:
|
||||
except ImportError:
|
||||
import json
|
||||
|
||||
use_msgpack: bool
|
||||
|
||||
try:
|
||||
import msgpack
|
||||
use_msgpack = True
|
||||
@@ -54,42 +56,42 @@ class WSMessage(NamedTuple):
|
||||
type: aiohttp.WSMsgType
|
||||
data: str | bytes | aiohttp.WSCloseCode
|
||||
|
||||
__all__ = ("WebsocketHandler",)
|
||||
__all__: tuple[str, ...] = ("WebsocketHandler",)
|
||||
|
||||
logger = logging.getLogger("revolt")
|
||||
logger: logging.Logger = logging.getLogger("revolt")
|
||||
|
||||
class WebsocketHandler:
|
||||
__slots__ = ("session", "token", "ws_url", "dispatch", "state", "websocket", "loop", "user", "ready", "server_events")
|
||||
|
||||
def __init__(self, session: aiohttp.ClientSession, token: str, ws_url: str, dispatch: Callable[..., None], state: State):
|
||||
self.session = session
|
||||
self.token = token
|
||||
self.ws_url = ws_url
|
||||
self.dispatch = dispatch
|
||||
self.state = state
|
||||
self.session: aiohttp.ClientSession = session
|
||||
self.token: str = token
|
||||
self.ws_url: str = ws_url
|
||||
self.dispatch: Callable[..., None] = dispatch
|
||||
self.state: State = state
|
||||
self.websocket: aiohttp.ClientWebSocketResponse
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.user = None
|
||||
self.ready = asyncio.Event()
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
|
||||
self.user: User | None = None
|
||||
self.ready: asyncio.Event = asyncio.Event()
|
||||
self.server_events: dict[str, asyncio.Event] = {}
|
||||
|
||||
async def _wait_for_server_ready(self, server_id: str):
|
||||
async def _wait_for_server_ready(self, server_id: str) -> None:
|
||||
if event := self.server_events.get(server_id):
|
||||
await event.wait()
|
||||
|
||||
async def send_payload(self, payload: BasePayload):
|
||||
async def send_payload(self, payload: BasePayload) -> None:
|
||||
if use_msgpack:
|
||||
await self.websocket.send_bytes(msgpack.packb(payload))
|
||||
else:
|
||||
await self.websocket.send_str(json.dumps(payload))
|
||||
|
||||
async def heartbeat(self):
|
||||
async def heartbeat(self) -> None:
|
||||
while not self.websocket.closed:
|
||||
logger.info("Sending hearbeat")
|
||||
await self.websocket.ping()
|
||||
await asyncio.sleep(15)
|
||||
|
||||
async def send_authenticate(self):
|
||||
async def send_authenticate(self) -> None:
|
||||
payload: AuthenticatePayload = {
|
||||
"type": "Authenticate",
|
||||
"token": self.token
|
||||
@@ -97,7 +99,7 @@ class WebsocketHandler:
|
||||
|
||||
await self.send_payload(payload)
|
||||
|
||||
async def handle_event(self, payload: BasePayload):
|
||||
async def handle_event(self, payload: BasePayload) -> None:
|
||||
event_type = payload["type"].lower()
|
||||
logger.debug("Recieved event %s %s", event_type, payload)
|
||||
try:
|
||||
@@ -110,10 +112,10 @@ class WebsocketHandler:
|
||||
|
||||
await func(payload)
|
||||
|
||||
async def handle_authenticated(self, _):
|
||||
async def handle_authenticated(self, _: BasePayload) -> None:
|
||||
logger.info("Successfully authenticated")
|
||||
|
||||
async def handle_ready(self, payload: ReadyEventPayload):
|
||||
async def handle_ready(self, payload: ReadyEventPayload) -> None:
|
||||
for user_payload in payload["users"]:
|
||||
user = self.state.add_user(user_payload)
|
||||
|
||||
@@ -138,7 +140,7 @@ class WebsocketHandler:
|
||||
self.ready.set()
|
||||
self.dispatch("ready")
|
||||
|
||||
async def handle_message(self, payload: MessageEventPayload):
|
||||
async def handle_message(self, payload: MessageEventPayload) -> None:
|
||||
if server := self.state.get_channel(payload["channel"]).server_id:
|
||||
await self._wait_for_server_ready(server)
|
||||
|
||||
@@ -147,7 +149,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("message", message)
|
||||
|
||||
async def handle_messageupdate(self, payload: MessageUpdateEventPayload):
|
||||
async def handle_messageupdate(self, payload: MessageUpdateEventPayload) -> None:
|
||||
self.dispatch("raw_message_update", payload)
|
||||
|
||||
try:
|
||||
@@ -162,7 +164,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("message_update", message)
|
||||
|
||||
async def handle_messagedelete(self, payload: MessageDeleteEventPayload):
|
||||
async def handle_messagedelete(self, payload: MessageDeleteEventPayload) -> None:
|
||||
self.dispatch("raw_message_delete", payload)
|
||||
|
||||
try:
|
||||
@@ -178,7 +180,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("message_delete", message)
|
||||
|
||||
async def handle_channelcreate(self, payload: ChannelCreateEventPayload):
|
||||
async def handle_channelcreate(self, payload: ChannelCreateEventPayload) -> None:
|
||||
channel = self.state.add_channel(payload)
|
||||
|
||||
if server_id := channel.server_id:
|
||||
@@ -186,7 +188,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("channel_create", channel)
|
||||
|
||||
async def handle_channelupdate(self, payload: ChannelUpdateEventPayload):
|
||||
async def handle_channelupdate(self, payload: ChannelUpdateEventPayload) -> None:
|
||||
# Revolt sends channel updates for channels we dont have permissions to see, a bug, but still can cause issues as its not in the cache
|
||||
|
||||
if not (channel := self.state.channels.get(payload["id"], None)):
|
||||
@@ -211,7 +213,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("channel_update", old_channel, channel)
|
||||
|
||||
async def handle_channeldelete(self, payload: ChannelDeleteEventPayload):
|
||||
async def handle_channeldelete(self, payload: ChannelDeleteEventPayload) -> None:
|
||||
channel = self.state.channels.pop(payload["id"])
|
||||
|
||||
if server_id := channel.server_id:
|
||||
@@ -219,7 +221,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("channel_delete", channel)
|
||||
|
||||
async def handle_channelstarttyping(self, payload: ChannelStartTypingEventPayload):
|
||||
async def handle_channelstarttyping(self, payload: ChannelStartTypingEventPayload) -> None:
|
||||
channel = self.state.get_channel(payload["id"])
|
||||
|
||||
if server_id := channel.server_id:
|
||||
@@ -229,7 +231,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("typing_start", channel, user)
|
||||
|
||||
async def handle_channelstoptyping(self, payload: ChannelDeleteTypingEventPayload):
|
||||
async def handle_channelstoptyping(self, payload: ChannelDeleteTypingEventPayload) -> None:
|
||||
channel = self.state.get_channel(payload["id"])
|
||||
|
||||
if server_id := channel.server_id:
|
||||
@@ -239,7 +241,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("typing_stop", channel, user)
|
||||
|
||||
async def handle_serverupdate(self, payload: ServerUpdateEventPayload):
|
||||
async def handle_serverupdate(self, payload: ServerUpdateEventPayload) -> None:
|
||||
await self._wait_for_server_ready(payload["id"])
|
||||
|
||||
server = self.state.get_server(payload["id"])
|
||||
@@ -261,7 +263,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("server_update", old_server, server)
|
||||
|
||||
async def handle_serverdelete(self, payload: ServerDeleteEventPayload):
|
||||
async def handle_serverdelete(self, payload: ServerDeleteEventPayload) -> None:
|
||||
server = self.state.servers.pop(payload["id"])
|
||||
|
||||
for channel in server.channels:
|
||||
@@ -271,7 +273,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("server_delete", server)
|
||||
|
||||
async def handle_servercreate(self, payload: ServerCreateEventPayload):
|
||||
async def handle_servercreate(self, payload: ServerCreateEventPayload) -> None:
|
||||
for channel in payload["channels"]:
|
||||
self.state.add_channel(channel)
|
||||
|
||||
@@ -284,7 +286,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("server_join", server)
|
||||
|
||||
async def handle_servermemberupdate(self, payload: ServerMemberUpdateEventPayload):
|
||||
async def handle_servermemberupdate(self, payload: ServerMemberUpdateEventPayload) -> None:
|
||||
await self._wait_for_server_ready(payload["id"]["server"])
|
||||
|
||||
member = self.state.get_member(payload["id"]["server"], payload["id"]["user"])
|
||||
@@ -300,7 +302,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("member_update", old_member, member)
|
||||
|
||||
async def handle_servermemberjoin(self, payload: ServerMemberJoinEventPayload):
|
||||
async def handle_servermemberjoin(self, payload: ServerMemberJoinEventPayload) -> None:
|
||||
# avoid an api request if possible
|
||||
if payload["user"] not in self.state.users:
|
||||
user = await self.state.http.fetch_user(payload["user"])
|
||||
@@ -310,7 +312,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("member_join", member)
|
||||
|
||||
async def handle_memberleave(self, payload: ServerMemberLeaveEventPayload):
|
||||
async def handle_memberleave(self, payload: ServerMemberLeaveEventPayload) -> None:
|
||||
await self._wait_for_server_ready(payload["id"])
|
||||
|
||||
server = self.state.get_server(payload["id"])
|
||||
@@ -323,7 +325,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("member_leave", member)
|
||||
|
||||
async def handle_serverroleupdate(self, payload: ServerRoleUpdateEventPayload):
|
||||
async def handle_serverroleupdate(self, payload: ServerRoleUpdateEventPayload) -> None:
|
||||
server = self.state.get_server(payload["id"])
|
||||
await self._wait_for_server_ready(server.id)
|
||||
|
||||
@@ -346,7 +348,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("role_update", old_role, role)
|
||||
|
||||
async def handle_serverroledelete(self, payload: ServerRoleDeleteEventPayload):
|
||||
async def handle_serverroledelete(self, payload: ServerRoleDeleteEventPayload) -> None:
|
||||
server = self.state.get_server(payload["id"])
|
||||
role = server._roles.pop(payload["role_id"])
|
||||
|
||||
@@ -354,7 +356,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("role_delete", role)
|
||||
|
||||
async def handle_userupdate(self, payload: UserUpdateEventPayload):
|
||||
async def handle_userupdate(self, payload: UserUpdateEventPayload) -> None:
|
||||
user = self.state.get_user(payload["id"])
|
||||
old_user = copy(user)
|
||||
|
||||
@@ -377,14 +379,14 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("user_update", old_user, user)
|
||||
|
||||
async def handle_userrelationship(self, payload: UserRelationshipEventPayload):
|
||||
async def handle_userrelationship(self, payload: UserRelationshipEventPayload) -> None:
|
||||
user = self.state.get_user(payload["user"])
|
||||
old_relationship = user.relationship
|
||||
user.relationship = RelationshipType(payload["status"])
|
||||
|
||||
self.dispatch("user_relationship_update", user, old_relationship, user.relationship)
|
||||
|
||||
async def handle_messagereact(self, payload: MessageReactEventPayload):
|
||||
async def handle_messagereact(self, payload: MessageReactEventPayload) -> None:
|
||||
if server := self.state.get_channel(payload["channel_id"]).server_id:
|
||||
await self._wait_for_server_ready(server)
|
||||
|
||||
@@ -401,7 +403,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("reaction_add", message, user, emoji_id)
|
||||
|
||||
async def handle_messageunreact(self, payload: MessageUnreactEventPayload):
|
||||
async def handle_messageunreact(self, payload: MessageUnreactEventPayload) -> None:
|
||||
if server := self.state.get_channel(payload["channel_id"]).server_id:
|
||||
await self._wait_for_server_ready(server)
|
||||
|
||||
@@ -417,7 +419,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("reaction_remove", message, user, payload["emoji_id"])
|
||||
|
||||
async def handle_messageremovereaction(self, payload: MessageRemoveReactionEventPayload):
|
||||
async def handle_messageremovereaction(self, payload: MessageRemoveReactionEventPayload) -> None:
|
||||
if server := self.state.get_channel(payload["channel_id"]).server_id:
|
||||
await self._wait_for_server_ready(server)
|
||||
|
||||
@@ -432,7 +434,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("reaction_clear", message, users, payload["emoji_id"])
|
||||
|
||||
async def handle_bulkmessagedelete(self, payload: BulkMessageDeleteEventPayload):
|
||||
async def handle_bulkmessagedelete(self, payload: BulkMessageDeleteEventPayload) -> None:
|
||||
channel = self.state.get_channel(payload["channel"])
|
||||
|
||||
self.dispatch("raw_bulk_message_delete", payload)
|
||||
@@ -457,7 +459,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("bulk_message_delete", messages)
|
||||
|
||||
async def start(self):
|
||||
async def start(self) -> None:
|
||||
if use_msgpack:
|
||||
url = f"{self.ws_url}?format=msgpack"
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user