add more events and fix bugs

This commit is contained in:
Zomatree
2021-10-17 02:47:07 +01:00
parent 2b05fadf77
commit 542d1052d6
36 changed files with 205 additions and 59 deletions
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Executable
+44
View File
@@ -0,0 +1,44 @@
from aiohttp import web, web_request
import aiohttp
API_ENDPOINT = 'https://discord.com/api/v8'
CLIENT_ID = '380423502810972162'
CLIENT_SECRET = 'pt1qgGOiWWhroqPM_NBqM_Nb1OSuySgU'
REDIRECT_URI = 'http://127.0.0.1:8080'
async def exchange_code(code):
data = {
'client_id': CLIENT_ID,
'client_secret': CLIENT_SECRET,
'grant_type': 'authorization_code',
'code': code,
'redirect_uri': REDIRECT_URI
}
headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}
async with aiohttp.ClientSession() as session:
r = await session.post(f'{API_ENDPOINT}/oauth2/token', data=data, headers=headers)
r.raise_for_status()
return await r.json()
async def get_user(token):
headers = {
"Authorization": f"Bearer {token}"
}
async with aiohttp.ClientSession() as session:
r = await session.get(f"{API_ENDPOINT}/users/@me")
r.raise_for_status()
return await r.json()
async def hello(request: web_request.Request):
code = request.query["code"]
data = await exchange_code(code)
user = await get_user(data["access_token"])
return web.Response(text=user["email"])
app = web.Application()
app.add_routes([web.get('/', hello)])
web.run_app(app)
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
+47 -10
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, cast, Optional
from .enums import ChannelType
from .messageable import Messageable
@@ -12,8 +12,10 @@ if TYPE_CHECKING:
from .types import Group as GroupDMChannelPayload
from .types import SavedMessages as SavedMessagesPayload
from .types import TextChannel as TextChannelPayload
from .types import VoiceChannel as VoiceChannelPayload
from .user import User
from .server import Server
from .message import Message
__all__ = ("Channel",)
@@ -29,13 +31,20 @@ class Channel:
server: Optional[:class:`Server`]
The server the channel is part of
"""
__slots__ = ("state", "id", "channel_type", "server")
__slots__ = ("state", "id", "channel_type", "server_id")
def __init__(self, data: ChannelPayload, state: State):
self.state = state
self.id = data["_id"]
self.channel_type = ChannelType(data["channel_type"])
self.server = None
self.server_id = ""
@property
def server(self) -> Server:
return self.state.get_server(self.server_id)
def _update(self):
pass
class SavedMessageChannel(Channel, Messageable):
"""The Saved Message Channel"""
@@ -53,30 +62,58 @@ class GroupDMChannel(Channel, Messageable):
"""A group DM channel"""
def __init__(self, data: GroupDMChannelPayload, state: State):
super().__init__(data, state)
self.recipients = cast(list[User], list(filter(bool, [state.get_user(user_id) for user_id in data["recipients"]])))
self.recipients = [state.get_user(user_id) for user_id in data["recipients"]]
self.name = data["name"]
self.owner = state.get_user(data["owner"])
def _update(self, *, name: Optional[str] = None, recipients: Optional[list[str]] = None):
if name:
self.name = name
if recipients:
self.recipients = [self.state.get_user(user_id) for user_id in recipients]
class TextChannel(Channel, Messageable):
__slots__ = ("name", "description", "last_message", "last_message_id")
__slots__ = ("name", "description", "last_message_id", "server_id")
"""A text channel"""
def __init__(self, data: TextChannelPayload, state: State):
super().__init__(data, state)
self.server = state.get_server(data["server"])
self.server_id = data["server"]
self.name = data["name"]
self.description = data.get("description")
last_message_id = data.get("last_message")
self.last_message = state.get_message(last_message_id)
self.last_message_id = last_message_id
@property
def last_message(self) -> Message:
return self.state.get_message(self.last_message_id)
def _update(self, *, name: Optional[str] = None, description: Optional[str] = None):
if name:
self.name = name
if description:
self.description = description
class VoiceChannel(Channel):
"""A voice channel"""
def __init__(self, data: ChannelPayload, state: State):
def __init__(self, data: VoiceChannelPayload, state: State):
super().__init__(data, state)
self.server_id = data["server"]
self.name = data["name"]
self.description = data.get("description")
def _update(self, *, name: Optional[str] = None, description: Optional[str] = None):
if name:
self.name = name
if description:
self.description = description
def channel_factory(data: ChannelPayload, state: State) -> Channel:
if data["channel_type"] == "SavedMessage":
return SavedMessageChannel(data, state)
Regular → Executable
+2
View File
@@ -55,6 +55,8 @@ class Client:
self.listeners: dict[str, list[tuple[Callable[..., bool], asyncio.Future[Any]]]] = {}
super().__init__()
def dispatch(self, event: str, *args: Any):
"""Dispatch an event, this is typically used for testing and internals.
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
+10 -1
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from .user import User
@@ -46,3 +46,12 @@ class Member(User):
self.roles = sorted(roles, key=lambda role: role.rank, reverse=True)
self.server = server
@property
def owner(self) -> Optional[User]:
owner_id = self.owner_id
if not owner_id:
return
return self.state.get_user(owner_id)
Regular → Executable
View File
Regular → Executable
+1 -1
View File
@@ -24,7 +24,7 @@ class Messageable:
__slots__ = ()
async def send(self, content: Optional[str] = None, embeds: Optional[list[Embed]] = None, embed: Optional[Embed] = None, attachments: Optional[list[File]] = None) -> Message:
async def send(self, content: Optional[str] = None, *, embeds: Optional[list[Embed]] = None, embed: Optional[Embed] = None, attachments: Optional[list[File]] = None) -> Message:
"""Sends a message in a channel, you must send at least one of either `content`, `embeds` or `attachments`
Parameters
Regular → Executable
View File
Regular → Executable
View File
Regular → Executable
+23 -14
View File
@@ -26,13 +26,13 @@ class Server:
owner: Optional[:class:`Member`]
The owner of the server
"""
__slots__ = ("state", "id", "name", "owner", "default_permissions", "_members", "_roles", "_channels")
__slots__ = ("state", "id", "name", "owner_id", "default_permissions", "_members", "_roles", "_channels")
def __init__(self, data: ServerPayload, state: State):
self.state = state
self.id = data["_id"]
self.name = data["name"]
self.owner = state.get_member(self.id, data["owner"])
self.owner_id = data["owner"]
self.default_permissions = Permissions(data["default_permissions"])
self._members: dict[str, Member] = {}
@@ -55,7 +55,7 @@ class Server:
"""list[:class:`Member`] Gets all channels in the server"""
return list(self._channels.values())
def get_role(self, role_id: str) -> Optional[Role]:
def get_role(self, role_id: str) -> Role:
"""Gets a role from the cache
Parameters
@@ -65,12 +65,12 @@ class Server:
Returns
--------
Optional[:class:`Role`]
The role if found
:class:`Role`
The role
"""
return self._roles.get(role_id)
return self._roles[role_id]
def get_member(self, member_id: str) -> Optional[Member]:
def get_member(self, member_id: str) -> Member:
"""Gets a member from the cache
Parameters
@@ -80,12 +80,12 @@ class Server:
Returns
--------
Optional[:class:`Member`]
The member if found
:class:`Member`
The member
"""
return self._members.get(member_id)
return self._members[member_id]
def get_channel(self, channel_id: str) -> Optional[Channel]:
def get_channel(self, channel_id: str) -> Channel:
"""Gets a channel from the cache
Parameters
@@ -95,7 +95,16 @@ class Server:
Returns
--------
Optional[:class:`Channel`]
The channel if found
:class:`Channel`
The channel
"""
self._channels.get(channel_id)
return self._channels[channel_id]
@property
def owner(self) -> Optional[Member]:
owner_id = self.owner_id
if not owner_id:
return
return self.get_member(owner_id)
Regular → Executable
+14 -18
View File
@@ -34,35 +34,29 @@ class State:
self.servers: dict[str, Server] = {}
self.messages: deque[Message] = deque()
def get_user(self, id: str) -> Optional[User]:
return self.users.get(id)
def get_user(self, id: str) -> User:
return self.users[id]
def get_member(self, server_id: str, member_id: str) -> Optional[Member]:
server = self.servers.get(server_id)
if server:
return server.get_member(member_id)
def get_member(self, server_id: str, member_id: str) -> Member:
server = self.servers[server_id]
return server.get_member(member_id)
def get_channel(self, id: str) -> Optional[Channel]:
return self.channels.get(id)
def get_channel(self, id: str) -> Channel:
return self.channels[id]
def get_server(self, id: str) -> Optional[Server]:
return self.servers.get(id)
def get_server(self, id: str) -> Server:
return self.servers[id]
def add_user(self, payload: UserPayload) -> User:
user = User(payload, self)
self.users[user.id] = user
return user
def add_member(self, server_id: str, payload: MemberPayload) -> Optional[Member]:
def add_member(self, server_id: str, payload: MemberPayload) -> Member:
server = self.get_server(server_id)
if not server:
return
member = Member(payload, server, self)
server._members[member.id] = member
return member
def add_channel(self, payload: ChannelPayload) -> Channel:
@@ -83,10 +77,12 @@ class State:
self.messages.appendleft(message)
return message
def get_message(self, message_id: str) -> Optional[Message]:
def get_message(self, message_id: str) -> Message:
for msg in self.messages:
if msg.id == message_id:
return msg
raise KeyError
async def fetch_all_server_members(self):
for server_id in self.servers.keys():
Regular → Executable
+12 -3
View File
@@ -51,7 +51,7 @@ class User:
status: Optional[:class:`Status`]
The users status
"""
__flattern_attributes__ = ("id", "name", "bot", "owner", "badges", "online", "flags", "avatar", "relations", "relationship", "status")
__flattern_attributes__ = ("id", "name", "bot", "owner_id", "badges", "online", "flags", "avatar", "relations", "relationship", "status")
__slots__ = (*__flattern_attributes__, "state")
def __init__(self, data: UserPayload, state: State):
@@ -62,10 +62,10 @@ class User:
bot = data.get("bot")
if bot:
self.bot = True
self.owner = state.get_user(bot["owner"])
self.owner_id = bot["owner"]
else:
self.bot = False
self.owner = None
self.owner_id = None
self.badges = data.get("badges", 0)
self.online = data.get("online", False)
@@ -91,3 +91,12 @@ class User:
self.status = Status(status.get("text"), PresenceType(presence) if presence else None) if status else None
else:
self.status = None
@property
def owner(self) -> Optional[User]:
owner_id = self.owner_id
if not owner_id:
return
return self.state.get_user(owner_id)
Regular → Executable
View File
Regular → Executable
+52 -12
View File
@@ -2,9 +2,10 @@ from __future__ import annotations
import asyncio
import logging
import copy
from typing import TYPE_CHECKING, Callable, cast
from .types import Message as MessagePayload, MessageUpdateEventPayload, MessageDeleteEventPayload
from .types import Message as MessagePayload, MessageUpdateEventPayload, MessageDeleteEventPayload, ChannelCreateEventPayload, ChannelUpdateEventPayload, ChannelDeleteEventPayload, ChannelStartTypingEventPayload, ChannelDeleteTypingEventPayload
try:
import ujson as json
@@ -80,11 +81,12 @@ class WebsocketHandler:
for user in payload["users"]:
self.state.add_user(user)
for channel in payload["channels"]:
self.state.add_channel(channel)
for server in payload["servers"]:
self.state.add_server(server)
for channel in payload["channels"]:
self.state.add_channel(channel)
for member in payload["members"]:
self.state.add_member(member["_id"]["server"], member)
@@ -101,19 +103,19 @@ class WebsocketHandler:
self.dispatch("raw_message_update", payload)
message = self.state.get_message(payload["id"])
if message:
data = payload["data"]
kwargs = {}
if data["content"]:
kwargs["content"] = data["content"]
data = payload["data"]
kwargs = {}
if data["edited"]["$date"]:
kwargs["edited_at"] = data["edited"]["$date"]
if data["content"]:
kwargs["content"] = data["content"]
message._update(**kwargs)
if data["edited"]["$date"]:
kwargs["edited_at"] = data["edited"]["$date"]
self.dispatch("message_update", message)
message._update(**kwargs)
self.dispatch("message_update", message)
async def handle_messagedelete(self, payload: MessageDeleteEventPayload):
self.dispatch("raw_message_delete", payload)
@@ -123,6 +125,44 @@ class WebsocketHandler:
self.state.messages.remove(message)
self.dispatch("message_delete", message)
async def handle_channelcreate(self, payload: ChannelCreateEventPayload):
channel = self.state.add_channel(payload)
self.dispatch("channel_create", channel)
async def handle_channelupdate(self, payload: ChannelUpdateEventPayload):
channel = self.state.get_channel(payload["id"])
old_channel = copy(channel)
channel._update(**payload["data"])
if clear := payload.get("clear"):
if clear == "Icon":
pass # TODO
elif clear == "Description":
channel.description = None # type: ignore
self.dispatch("channel_update", old_channel, channel)
async def handle_channeldelete(self, payload: ChannelDeleteEventPayload):
channel = self.state.channels.pop(payload["id"])
self.dispatch("channel_delete", channel)
async def handle_channelstarttyping(self, payload: ChannelStartTypingEventPayload):
channel = self.state.get_channel(payload["id"])
user = self.state.get_user(payload["user"])
self.dispatch("typing_start", channel, user)
async def handle_channeldeletetyping(self, payload: ChannelDeleteTypingEventPayload):
channel = self.state.get_channel(payload["id"])
user = self.state.get_user(payload["user"])
self.dispatch("typing_delete", channel, user)
async def start(self):
if use_msgpack:
url = f"{self.ws_url}?format=msgpack"
Regular → Executable
View File