mirror of
https://github.com/stoatchat/python-client-sdk.git
synced 2026-07-01 20:44:04 -04:00
clean up code
This commit is contained in:
@@ -54,6 +54,7 @@ only-packages = true
|
||||
reportPrivateUsage = false
|
||||
reportImportCycles = false
|
||||
reportIncompatibleMethodOverride = false
|
||||
typeCheckingMode = "strict"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
||||
@@ -19,3 +19,4 @@ from .server import *
|
||||
from .user import *
|
||||
|
||||
__version__ = "0.1.9"
|
||||
|
||||
+1
-1
@@ -347,7 +347,7 @@ class Client:
|
||||
background: Optional[:class:`File`]
|
||||
The new background for the profile, passing in ``None`` will remove the profile background
|
||||
"""
|
||||
remove = []
|
||||
remove: list[str] = []
|
||||
|
||||
if kwargs.get("content", Missing) is None:
|
||||
del kwargs["content"]
|
||||
|
||||
@@ -68,7 +68,7 @@ def has_permissions(**permissions: bool):
|
||||
author = context.author
|
||||
|
||||
if not author.has_permissions(**permissions):
|
||||
raise MissingPermissionsError
|
||||
raise MissingPermissionsError(permissions)
|
||||
|
||||
return inner
|
||||
|
||||
@@ -77,10 +77,10 @@ def has_channel_permissions(**permissions: bool):
|
||||
def inner(context: Context[ClientT]):
|
||||
author = context.author
|
||||
|
||||
if isinstance(author, revolt.User):
|
||||
raise MissingPermissionsError
|
||||
if not isinstance(author, revolt.Member):
|
||||
raise ServerOnly
|
||||
|
||||
if not author.has_channel_permissions(context.channel, **permissions):
|
||||
raise MissingPermissionsError
|
||||
raise MissingPermissionsError(permissions)
|
||||
|
||||
return inner
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import (TYPE_CHECKING, Annotated, Any, Callable, Coroutine,
|
||||
from revolt.utils import copy_doc, maybe_coroutine
|
||||
|
||||
from .errors import InvalidLiteralArgument, UnionConverterError
|
||||
from .utils import ClientT, evaluate_parameters
|
||||
from .utils import ClientCoT, evaluate_parameters
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .checks import Check
|
||||
@@ -25,7 +25,7 @@ __all__ = (
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
class Command(Generic[ClientT]):
|
||||
class Command(Generic[ClientCoT]):
|
||||
"""Class for holding info about a command.
|
||||
|
||||
Parameters
|
||||
@@ -52,13 +52,13 @@ class Command(Generic[ClientT]):
|
||||
self.usage = usage
|
||||
self.signature = inspect.signature(self.callback)
|
||||
self.parameters = evaluate_parameters(self.signature.parameters.values(), getattr(callback, "__globals__", {}))
|
||||
self.checks: list[Check[ClientT]] = getattr(callback, "_checks", [])
|
||||
self.parent: Optional[Group[ClientT]] = None
|
||||
self.cog: Optional[Cog[ClientT]] = None
|
||||
self._error_handler: Callable[[Any, Context[ClientT], Exception], Coroutine[Any, Any, Any]] = type(self)._default_error_handler
|
||||
self.checks: list[Check[ClientCoT]] = getattr(callback, "_checks", [])
|
||||
self.parent: Optional[Group[ClientCoT]] = None
|
||||
self.cog: Optional[Cog[ClientCoT]] = None
|
||||
self._error_handler: Callable[[Any, Context[ClientCoT], Exception], Coroutine[Any, Any, Any]] = type(self)._default_error_handler
|
||||
self.description = callback.__doc__
|
||||
|
||||
async def invoke(self, context: Context[ClientT], *args: Any, **kwargs: Any) -> Any:
|
||||
async def invoke(self, context: Context[ClientCoT], *args: Any, **kwargs: Any) -> Any:
|
||||
"""Runs the command and calls the error handler if the command errors.
|
||||
|
||||
Parameters
|
||||
@@ -74,7 +74,7 @@ class Command(Generic[ClientT]):
|
||||
return await self._error_handler(self.cog or context.client, context, err)
|
||||
|
||||
@copy_doc(invoke)
|
||||
def __call__(self, context: Context[ClientT], *args: Any, **kwargs: Any) -> Any:
|
||||
def __call__(self, context: Context[ClientCoT], *args: Any, **kwargs: Any) -> Any:
|
||||
return self.invoke(context, *args, **kwargs)
|
||||
|
||||
def error(self, func: Callable[..., Coroutine[Any, Any, Any]]):
|
||||
@@ -97,11 +97,11 @@ class Command(Generic[ClientT]):
|
||||
self._error_handler = func
|
||||
return func
|
||||
|
||||
async def _default_error_handler(self, ctx: Context[ClientT], error: Exception):
|
||||
async def _default_error_handler(self, ctx: Context[ClientCoT], error: Exception):
|
||||
traceback.print_exception(type(error), error, error.__traceback__)
|
||||
|
||||
@classmethod
|
||||
async def handle_origin(cls, context: Context[ClientT], origin: Any, annotation: Any, arg: str) -> Any:
|
||||
async def handle_origin(cls, context: Context[ClientCoT], origin: Any, annotation: Any, arg: str) -> Any:
|
||||
if origin is Union:
|
||||
for converter in get_args(annotation):
|
||||
try:
|
||||
@@ -128,11 +128,12 @@ class Command(Generic[ClientT]):
|
||||
raise InvalidLiteralArgument(arg)
|
||||
|
||||
@classmethod
|
||||
async def convert_argument(cls, arg: str, annotation: Any, context: Context[ClientT]) -> Any:
|
||||
async def convert_argument(cls, arg: str, annotation: Any, context: Context[ClientCoT]) -> Any:
|
||||
if annotation is not inspect.Signature.empty:
|
||||
if annotation is str: # no converting is needed - its already a string
|
||||
return arg
|
||||
|
||||
origin: Any
|
||||
if origin := get_origin(annotation):
|
||||
return await cls.handle_origin(context, origin, annotation, arg)
|
||||
else:
|
||||
@@ -140,7 +141,7 @@ class Command(Generic[ClientT]):
|
||||
else:
|
||||
return arg
|
||||
|
||||
async def parse_arguments(self, context: Context[ClientT]):
|
||||
async def parse_arguments(self, context: Context[ClientCoT]):
|
||||
# please pr if you can think of a better way to do this
|
||||
|
||||
for parameter in self.parameters[2:]:
|
||||
@@ -213,7 +214,7 @@ class Command(Generic[ClientT]):
|
||||
|
||||
return f"{' '.join(parents[::-1])} {self.name} {' '.join(parameters)}"
|
||||
|
||||
def command(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientT]] = Command, usage: Optional[str] = None):
|
||||
def command(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientCoT]] = Command, usage: Optional[str] = None):
|
||||
"""A decorator that turns a function into a :class:`Command`.
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -7,7 +7,7 @@ from revolt.utils import maybe_coroutine
|
||||
|
||||
from .command import Command
|
||||
from .group import Group
|
||||
from .utils import ClientT
|
||||
from .utils import ClientCoT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .view import StringView
|
||||
@@ -16,7 +16,7 @@ __all__ = (
|
||||
"Context",
|
||||
)
|
||||
|
||||
class Context(revolt.Messageable, Generic[ClientT]):
|
||||
class Context(revolt.Messageable, Generic[ClientCoT]):
|
||||
"""Stores metadata the commands execution.
|
||||
|
||||
Attributes
|
||||
@@ -45,7 +45,7 @@ class Context(revolt.Messageable, Generic[ClientT]):
|
||||
async def _get_channel_id(self) -> str:
|
||||
return self.channel.id
|
||||
|
||||
def __init__(self, command: Optional[Command[ClientT]], invoked_with: str, view: StringView, message: revolt.Message, client: ClientT):
|
||||
def __init__(self, command: Optional[Command[ClientCoT]], invoked_with: str, view: StringView, message: revolt.Message, client: ClientCoT):
|
||||
self.command = command
|
||||
self.invoked_with = invoked_with
|
||||
self.view = view
|
||||
@@ -85,8 +85,14 @@ class Context(revolt.Messageable, Generic[ClientT]):
|
||||
await command.parse_arguments(self)
|
||||
return await command.invoke(self, *self.args, **self.kwargs)
|
||||
|
||||
async def can_run(self, command: Optional[Command[ClientT]] = None) -> bool:
|
||||
async def can_run(self, command: Optional[Command[ClientCoT]] = None) -> bool:
|
||||
"""Runs all of the commands checks, and returns true if all of them pass"""
|
||||
command = command or self.command
|
||||
|
||||
return all([await maybe_coroutine(check, self) for check in (command.checks if command else [])])
|
||||
|
||||
async def send_help(self, argument: Command[Any] | Group[Any] | ClientCoT | None = None):
|
||||
argument = argument or self.client
|
||||
|
||||
command = self.client.get_command("help")
|
||||
await command.invoke(self, argument)
|
||||
|
||||
@@ -50,7 +50,16 @@ class ServerOnly(CheckError):
|
||||
"""Raised when a check requires the command to be ran in a server"""
|
||||
|
||||
class MissingPermissionsError(CheckError):
|
||||
"""Raised when a check requires permissions the user does not have"""
|
||||
"""Raised when a check requires permissions the user does not have
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
permissions: :class:`dict[str, bool]`
|
||||
The permissions which the user did not have
|
||||
"""
|
||||
|
||||
def __init__(self, permissions: dict[str, bool]):
|
||||
self.permissions = permissions
|
||||
|
||||
class ConverterError(CommandError):
|
||||
"""Base class for all converter errors"""
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar
|
||||
from typing import Any, Callable, Coroutine, Optional
|
||||
|
||||
from .command import Command
|
||||
from .utils import ClientCoT, ClientT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import CommandsClient
|
||||
|
||||
__all__ = (
|
||||
"Group",
|
||||
"group"
|
||||
)
|
||||
|
||||
ClientT = TypeVar("ClientT", bound="CommandsClient")
|
||||
|
||||
|
||||
class Group(Command[ClientT]):
|
||||
class Group(Command[ClientCoT]):
|
||||
"""Class for holding info about a group command.
|
||||
|
||||
Parameters
|
||||
@@ -33,10 +29,10 @@ class Group(Command[ClientT]):
|
||||
__slots__ = ("subcommands",)
|
||||
|
||||
def __init__(self, callback: Callable[..., Coroutine[Any, Any, Any]], name: str, aliases: list[str]):
|
||||
self.subcommands: dict[str, Command[ClientT]] = {}
|
||||
self.subcommands: dict[str, Command[ClientCoT]] = {}
|
||||
super().__init__(callback, name, aliases)
|
||||
|
||||
def command(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientT]] = Command[ClientT]):
|
||||
def command(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientCoT]] = Command[ClientCoT]):
|
||||
"""A decorator that turns a function into a :class:`Command` and registers the command as a subcommand.
|
||||
|
||||
Parameters
|
||||
@@ -61,7 +57,7 @@ class Group(Command[ClientT]):
|
||||
|
||||
return inner
|
||||
|
||||
def group(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: Optional[type[Group[ClientT]]] = None):
|
||||
def group(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: Optional[type[Group[ClientCoT]]] = None):
|
||||
"""A decorator that turns a function into a :class:`Group` and registers the command as a subcommand
|
||||
|
||||
Parameters
|
||||
@@ -92,7 +88,7 @@ class Group(Command[ClientT]):
|
||||
return f"<Group name=\"{self.name}\">"
|
||||
|
||||
@property
|
||||
def commands(self) -> list[Command[ClientT]]:
|
||||
def commands(self) -> list[Command[ClientCoT]]:
|
||||
return list(self.subcommands.values())
|
||||
|
||||
def group(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Group[ClientT]] = Group):
|
||||
|
||||
+26
-25
@@ -9,7 +9,7 @@ from .cog import Cog
|
||||
from .command import Command
|
||||
from .context import Context
|
||||
from .group import Group
|
||||
from .utils import ClientT
|
||||
from .utils import ClientCoT, ClientT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from revolt import File, Message, Messageable, MessageReply, SendableEmbed
|
||||
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
|
||||
|
||||
__all__ = ("MessagePayload", "HelpCommand", "DefaultHelpCommand", "help_command_impl")
|
||||
|
||||
|
||||
class MessagePayload(TypedDict):
|
||||
content: str
|
||||
embed: NotRequired[SendableEmbed]
|
||||
@@ -25,28 +26,28 @@ class MessagePayload(TypedDict):
|
||||
attachments: NotRequired[list[File]]
|
||||
replies: NotRequired[list[MessageReply]]
|
||||
|
||||
class HelpCommand(ABC, Generic[ClientT]):
|
||||
class HelpCommand(ABC, Generic[ClientCoT]):
|
||||
@abstractmethod
|
||||
async def create_bot_help(self, context: Context[ClientT], commands: dict[Optional[Cog[ClientT]], list[Command[ClientT]]]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
async def create_bot_help(self, context: Context[ClientCoT], commands: dict[Optional[Cog[ClientCoT]], list[Command[ClientCoT]]]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def create_command_help(self, context: Context[ClientT], command: Command[ClientT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
async def create_command_help(self, context: Context[ClientCoT], command: Command[ClientCoT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def create_group_help(self, context: Context[ClientT], group: Group[ClientT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
async def create_group_help(self, context: Context[ClientCoT], group: Group[ClientCoT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def create_cog_help(self, context: Context[ClientT], cog: Cog[ClientT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
async def create_cog_help(self, context: Context[ClientCoT], cog: Cog[ClientCoT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def send_help_command(self, context: Context[ClientT], message_payload: MessagePayload) -> Message:
|
||||
async def send_help_command(self, context: Context[ClientCoT], message_payload: MessagePayload) -> Message:
|
||||
return await context.send(**message_payload)
|
||||
|
||||
async def filter_commands(self, context: Context[ClientT], commands: list[Command[ClientT]]) -> list[Command[ClientT]]:
|
||||
filtered: list[Command[ClientT]] = []
|
||||
async def filter_commands(self, context: Context[ClientCoT], commands: list[Command[ClientCoT]]) -> list[Command[ClientCoT]]:
|
||||
filtered: list[Command[ClientCoT]] = []
|
||||
|
||||
for command in commands:
|
||||
try:
|
||||
@@ -57,34 +58,34 @@ class HelpCommand(ABC, Generic[ClientT]):
|
||||
|
||||
return filtered
|
||||
|
||||
async def group_commands(self, context: Context[ClientT], commands: list[Command[ClientT]]) -> dict[Optional[Cog[ClientT]], list[Command[ClientT]]]:
|
||||
cogs: dict[Optional[Cog[ClientT]], list[Command[ClientT]]] = {}
|
||||
async def group_commands(self, context: Context[ClientCoT], commands: list[Command[ClientCoT]]) -> dict[Optional[Cog[ClientCoT]], list[Command[ClientCoT]]]:
|
||||
cogs: dict[Optional[Cog[ClientCoT]], list[Command[ClientCoT]]] = {}
|
||||
|
||||
for command in commands:
|
||||
cogs.setdefault(command.cog, []).append(command)
|
||||
|
||||
return cogs
|
||||
|
||||
async def handle_message(self, context: Context[ClientT], message: Message):
|
||||
async def handle_message(self, context: Context[ClientCoT], message: Message):
|
||||
pass
|
||||
|
||||
async def get_channel(self, context: Context[ClientT]) -> Messageable:
|
||||
async def get_channel(self, context: Context) -> Messageable:
|
||||
return context
|
||||
|
||||
@abstractmethod
|
||||
async def handle_no_command_found(self, context: Context[ClientT], name: str) -> Any:
|
||||
async def handle_no_command_found(self, context: Context[ClientCoT], name: str) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def handle_no_cog_found(self, context: Context[ClientT], name: str) -> Any:
|
||||
async def handle_no_cog_found(self, context: Context[ClientCoT], name: str) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DefaultHelpCommand(HelpCommand[ClientT]):
|
||||
class DefaultHelpCommand(HelpCommand[ClientCoT]):
|
||||
def __init__(self, default_cog_name: str = "No Cog"):
|
||||
self.default_cog_name = default_cog_name
|
||||
|
||||
async def create_bot_help(self, context: Context[ClientT], commands: dict[Optional[Cog[ClientT]], list[Command[ClientT]]]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
async def create_bot_help(self, context: Context[ClientCoT], commands: dict[Optional[Cog[ClientCoT]], list[Command[ClientCoT]]]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
lines = ["```"]
|
||||
|
||||
for cog, cog_commands in commands.items():
|
||||
@@ -99,7 +100,7 @@ class DefaultHelpCommand(HelpCommand[ClientT]):
|
||||
lines.append("```")
|
||||
return "\n".join(lines)
|
||||
|
||||
async def create_cog_help(self, context: Context[ClientT], cog: Cog[ClientT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
async def create_cog_help(self, context: Context[ClientCoT], cog: Cog[ClientCoT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
lines = ["```"]
|
||||
|
||||
lines.append(f"{cog.qualified_name}:")
|
||||
@@ -110,7 +111,7 @@ class DefaultHelpCommand(HelpCommand[ClientT]):
|
||||
lines.append("```")
|
||||
return "\n".join(lines)
|
||||
|
||||
async def create_command_help(self, context: Context[ClientT], command: Command[ClientT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
async def create_command_help(self, context: Context[ClientCoT], command: Command[ClientCoT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
lines = ["```"]
|
||||
|
||||
lines.append(f"{command.name}:")
|
||||
@@ -126,7 +127,7 @@ class DefaultHelpCommand(HelpCommand[ClientT]):
|
||||
lines.append("```")
|
||||
return "\n".join(lines)
|
||||
|
||||
async def create_group_help(self, context: Context[ClientT], group: Group[ClientT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
async def create_group_help(self, context: Context[ClientCoT], group: Group[ClientCoT]) -> Union[str, SendableEmbed, MessagePayload]:
|
||||
lines = ["```"]
|
||||
|
||||
lines.append(f"{group.name}:")
|
||||
@@ -144,20 +145,20 @@ class DefaultHelpCommand(HelpCommand[ClientT]):
|
||||
lines.append("```")
|
||||
return "\n".join(lines)
|
||||
|
||||
async def handle_no_command_found(self, context: Context[ClientT], name: str):
|
||||
async def handle_no_command_found(self, context: Context[ClientCoT], name: str):
|
||||
channel = await self.get_channel(context)
|
||||
await channel.send(f"Command `{name}` not found.")
|
||||
|
||||
async def handle_no_cog_found(self, context: Context[ClientT], name: str):
|
||||
async def handle_no_cog_found(self, context: Context[ClientCoT], name: str):
|
||||
channel = await self.get_channel(context)
|
||||
await channel.send(f"Cog `{name}` not found.")
|
||||
|
||||
|
||||
class HelpCommandImpl(Command[ClientT]):
|
||||
def __init__(self, client: ClientT):
|
||||
class HelpCommandImpl(Command[ClientCoT]):
|
||||
def __init__(self, client: ClientCoT):
|
||||
self.client = client
|
||||
|
||||
async def callback(_: Union[ClientT, Cog[ClientT]], context: Context[ClientT], *args: str):
|
||||
async def callback(_: Union[ClientCoT, Cog[ClientCoT]], context: Context[ClientCoT], *args: str):
|
||||
await help_command_impl(context.client, context, *args)
|
||||
|
||||
super().__init__(callback=callback, name="help", aliases=[])
|
||||
|
||||
@@ -7,12 +7,14 @@ from typing_extensions import TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import CommandsClient
|
||||
from .context import Context
|
||||
|
||||
|
||||
__all__ = ("evaluate_parameters",)
|
||||
|
||||
ClientT = TypeVar("ClientT", bound="CommandsClient", default="CommandsClient")
|
||||
|
||||
ClientCoT = TypeVar("ClientCoT", bound="CommandsClient", default="CommandsClient", covariant=True)
|
||||
ContextT = TypeVar("ContextT", bound="Context")
|
||||
|
||||
def evaluate_parameters(parameters: Iterable[Parameter], globals: dict[str, Any]) -> list[Parameter]:
|
||||
new_parameters: list[Parameter] = []
|
||||
|
||||
+1
-2
@@ -8,7 +8,6 @@ import ulid
|
||||
|
||||
from .errors import Forbidden, HTTPError, ServerError
|
||||
from .file import File
|
||||
from .utils import Missing
|
||||
|
||||
try:
|
||||
import ujson as _json
|
||||
@@ -203,7 +202,7 @@ class HttpClient:
|
||||
include_users: bool = False
|
||||
) -> Request[Union[list[MessagePayload], MessageWithUserData]]:
|
||||
|
||||
json = {"sort": sort.value, "include_users": str(include_users)}
|
||||
json: dict[str, Any] = {"sort": sort.value, "include_users": str(include_users)}
|
||||
|
||||
if limit:
|
||||
json["limit"] = limit
|
||||
|
||||
+46
-4
@@ -1,19 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
|
||||
from .utils import _Missing, Missing
|
||||
|
||||
from .asset import Asset
|
||||
from .permissions import Permissions
|
||||
from .permissions_calculator import calculate_permissions
|
||||
from .user import User
|
||||
from .file import File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .channel import Channel
|
||||
from .server import Server
|
||||
from .state import State
|
||||
from .types import File
|
||||
from .types import File as FilePayload
|
||||
from .types import Member as MemberPayload
|
||||
from .role import Role
|
||||
|
||||
__all__ = ("Member",)
|
||||
|
||||
@@ -42,7 +47,7 @@ class Member(User):
|
||||
|
||||
# due to not having a user payload and only a user object we have to manually add all the attributes instead of calling User.__init__
|
||||
flattern_user(self, user)
|
||||
user._members.add(self)
|
||||
user._members[server.id] = self
|
||||
|
||||
self.state = state
|
||||
|
||||
@@ -77,7 +82,7 @@ class Member(User):
|
||||
""":class:`str`: Returns a string that allows you to mention the given member."""
|
||||
return f"<@{self.id}>"
|
||||
|
||||
def _update(self, *, nickname: Optional[str] = None, avatar: Optional[File] = None, roles: Optional[list[str]] = None):
|
||||
def _update(self, *, nickname: Optional[str] = None, avatar: Optional[FilePayload] = None, roles: Optional[list[str]] = None):
|
||||
if nickname is not None:
|
||||
self.nickname = nickname
|
||||
|
||||
@@ -106,6 +111,43 @@ class Member(User):
|
||||
"""Unbans the member from the server"""
|
||||
await self.state.http.unban_member(self.server.id, self.id)
|
||||
|
||||
async def edit(
|
||||
self,
|
||||
*,
|
||||
nickname: str | None | _Missing = Missing,
|
||||
roles: list[Role] | None | _Missing = Missing,
|
||||
avatar: File | None | _Missing = Missing,
|
||||
timeout: datetime.timedelta | None | _Missing = Missing
|
||||
):
|
||||
remove: list[str] = []
|
||||
data: dict[str, Any] = {}
|
||||
|
||||
if nickname is None:
|
||||
remove.append("Nickname")
|
||||
elif nickname is not Missing:
|
||||
data["nickname"] = nickname
|
||||
|
||||
if roles is None:
|
||||
remove.append("Roles")
|
||||
elif roles is not Missing:
|
||||
data["roles"] = roles
|
||||
|
||||
if avatar is None:
|
||||
remove.append("Avatar")
|
||||
elif avatar is not Missing:
|
||||
# pyright cant understand custom singletons - it doesnt know this will never be an instance of _Missing here because Missing is the only instance
|
||||
assert not isinstance(avatar, _Missing)
|
||||
|
||||
data["avatar"] = (await self.state.http.upload_file(avatar, "avatars"))["id"]
|
||||
|
||||
if timeout is None:
|
||||
remove.append("Timeout")
|
||||
elif timeout is not Missing:
|
||||
assert not isinstance(timeout, _Missing)
|
||||
data["timeout"] = (datetime.datetime.now(datetime.timezone.utc) + timeout).isoformat()
|
||||
|
||||
await self.state.http.edit_member(self.server.id, self.id, remove, data)
|
||||
|
||||
async def timeout(self, length: datetime.timedelta):
|
||||
"""Timeouts the member
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
+34
-3
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
|
||||
from weakref import WeakSet
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from .asset import Asset, PartialAsset
|
||||
from .channel import DMChannel, GroupDMChannel
|
||||
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
|
||||
from .types import Status as StatusPayload
|
||||
from .types import User as UserPayload
|
||||
from .types import UserProfile as UserProfileData
|
||||
from .server import Server
|
||||
|
||||
__all__ = ("User", "Status", "Relation", "UserProfile")
|
||||
|
||||
@@ -69,7 +70,7 @@ class User(Messageable, Ulid):
|
||||
|
||||
def __init__(self, data: UserPayload, state: State):
|
||||
self.state = state
|
||||
self._members: WeakSet[Member] = WeakSet() # we store all member versions of this user to avoid having to check every guild when needing to update.
|
||||
self._members: WeakValueDictionary[str, Member] = WeakValueDictionary() # we store all member versions of this user to avoid having to check every guild when needing to update.
|
||||
self.id = data["_id"]
|
||||
self.original_name = data["username"]
|
||||
self.dm_channel = None
|
||||
@@ -212,7 +213,7 @@ class User(Messageable, Ulid):
|
||||
# update user infomation for all members
|
||||
|
||||
if self.__class__ is User:
|
||||
for member in self._members:
|
||||
for member in self._members.values():
|
||||
User._update(member, status=status, profile=profile, avatar=avatar, online=online)
|
||||
|
||||
async def default_avatar(self) -> bytes:
|
||||
@@ -245,3 +246,33 @@ class User(Messageable, Ulid):
|
||||
|
||||
self.profile = UserProfile(payload.get("content"), background)
|
||||
return self.profile
|
||||
|
||||
def to_member(self, server: Server) -> Member:
|
||||
"""Gets the member instance for this user for a specific server.
|
||||
|
||||
Roughly equivelent to:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
member = server.get_member(user.id)
|
||||
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
server: :class:`Server`
|
||||
The server to get the member for
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Member`
|
||||
The member
|
||||
|
||||
Raises
|
||||
-------
|
||||
:class:`LookupError`
|
||||
|
||||
"""
|
||||
try:
|
||||
return self._members[server.id]
|
||||
except IndexError:
|
||||
raise LookupError from None
|
||||
|
||||
+11
-3
@@ -4,7 +4,7 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from copy import copy
|
||||
from typing import TYPE_CHECKING, Callable, cast
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple, cast
|
||||
|
||||
from . import utils
|
||||
from .channel import GroupDMChannel, TextChannel, VoiceChannel
|
||||
@@ -29,6 +29,8 @@ from .types import (ServerCreateEventPayload, ServerDeleteEventPayload,
|
||||
UserUpdateEventPayload)
|
||||
from .user import Status, UserProfile
|
||||
|
||||
import aiohttp
|
||||
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
@@ -46,7 +48,11 @@ if TYPE_CHECKING:
|
||||
from .state import State
|
||||
from .types import (AuthenticatePayload, BasePayload, MessageEventPayload,
|
||||
ReadyEventPayload)
|
||||
from .message import Message
|
||||
|
||||
class WSMessage(NamedTuple):
|
||||
type: aiohttp.WSMsgType
|
||||
data: str | bytes | aiohttp.WSCloseCode
|
||||
|
||||
__all__ = ("WebsocketHandler",)
|
||||
|
||||
@@ -313,7 +319,7 @@ class WebsocketHandler:
|
||||
# remove the member from the user
|
||||
|
||||
user = self.state.get_user(payload["user"])
|
||||
user._members.remove(member)
|
||||
user._members.pop(server.id)
|
||||
|
||||
self.dispatch("member_leave", member)
|
||||
|
||||
@@ -431,7 +437,7 @@ class WebsocketHandler:
|
||||
|
||||
self.dispatch("raw_bulk_message_delete", payload)
|
||||
|
||||
messages = []
|
||||
messages: list[Message] = []
|
||||
|
||||
for message_id in payload["ids"]:
|
||||
if server_id := channel.server_id:
|
||||
@@ -462,6 +468,8 @@ class WebsocketHandler:
|
||||
asyncio.create_task(self.heartbeat())
|
||||
|
||||
async for msg in self.websocket:
|
||||
msg = cast(WSMessage, msg) # aiohttp doesnt use NamedTuple so the type info is missing
|
||||
|
||||
if use_msgpack:
|
||||
data = cast(bytes, msg.data)
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
def get_html_theme_path() -> str: ...
|
||||
Reference in New Issue
Block a user