Compare commits
No commits in common. "1af2716c544f467208fd8edf5b409686f95ff0bf" and "d8f53a35be36005f751c4fdcdfe2343b2f9c95ac" have entirely different histories.
1af2716c54
...
d8f53a35be
@ -1,5 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Union, Optional, Dict, List
|
from typing import Union, Optional, Dict
|
||||||
|
|
||||||
from .api import API
|
from .api import API
|
||||||
from .room import Room
|
from .room import Room
|
||||||
@ -30,16 +30,10 @@ class Client:
|
|||||||
"rooms": self.process_room_events,
|
"rooms": self.process_room_events,
|
||||||
"groups": self.process_group_events,
|
"groups": self.process_group_events,
|
||||||
}
|
}
|
||||||
self.event_dispatchers: Dict[str, List[callable]] = {}
|
self.event_dispatchers: Dict[str, callable] = {}
|
||||||
self.users = []
|
self.users = []
|
||||||
self.loop: Optional[asyncio.AbstractEventLoop] = None
|
|
||||||
|
|
||||||
async def run(self, user_id: str = None, password: str = None, token: str = None, loop: Optional[asyncio.AbstractEventLoop] = None):
|
|
||||||
if loop:
|
|
||||||
self.loop = loop
|
|
||||||
elif not self.loop:
|
|
||||||
self.loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
|
async def run(self, user_id: str = None, password: str = None, token: str = None):
|
||||||
if not password and not token:
|
if not password and not token:
|
||||||
raise RuntimeError("Either the password or a token is required")
|
raise RuntimeError("Either the password or a token is required")
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
@ -101,10 +95,9 @@ class Client:
|
|||||||
event_dict["room"] = room
|
event_dict["room"] = room
|
||||||
event = self.process_event(event_dict)
|
event = self.process_event(event_dict)
|
||||||
await room.update_state(event)
|
await room.update_state(event)
|
||||||
handlers = self.event_dispatchers.get(event.type)
|
handler = self.event_dispatchers.get(event.type)
|
||||||
if handlers:
|
if handler:
|
||||||
for handler in handlers:
|
await self.invoke(handler, event)
|
||||||
self.loop.create_task(self.invoke(handler, event))
|
|
||||||
|
|
||||||
# Process ephemeral events
|
# Process ephemeral events
|
||||||
for event in data['ephemeral']['events']:
|
for event in data['ephemeral']['events']:
|
||||||
@ -122,13 +115,12 @@ class Client:
|
|||||||
if isinstance(event, StateEvent):
|
if isinstance(event, StateEvent):
|
||||||
await room.update_state(event)
|
await room.update_state(event)
|
||||||
elif isinstance(event, MessageEvent):
|
elif isinstance(event, MessageEvent):
|
||||||
if event.event_id not in room.message_cache:
|
if event not in room.message_cache:
|
||||||
room.message_cache[event.event_id] = event
|
room.message_cache.append(event)
|
||||||
if room.read_receipts[self.user_id][1] < event.origin_server_ts:
|
if room.read_receipts[self.user_id][1] < event.origin_server_ts:
|
||||||
handlers = self.event_dispatchers.get(event.type)
|
handler = self.event_dispatchers.get(event.type)
|
||||||
if handlers:
|
if handler:
|
||||||
for handler in handlers:
|
await self.invoke(handler, event)
|
||||||
self.loop.create_task(self.invoke(handler, event))
|
|
||||||
try:
|
try:
|
||||||
await self.mark_event_read(event)
|
await self.mark_event_read(event)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
@ -160,8 +152,6 @@ class Client:
|
|||||||
return EventBase.from_dict(self, event)
|
return EventBase.from_dict(self, event)
|
||||||
elif event["type"] == "m.room.message":
|
elif event["type"] == "m.room.message":
|
||||||
return MessageEvent.from_dict(self, event)
|
return MessageEvent.from_dict(self, event)
|
||||||
elif event['type'] == 'm.room.redaction':
|
|
||||||
return RedactionEvent.from_dict(self, event)
|
|
||||||
else:
|
else:
|
||||||
return RoomEvent.from_dict(self, event)
|
return RoomEvent.from_dict(self, event)
|
||||||
|
|
||||||
@ -171,15 +161,9 @@ class Client:
|
|||||||
await handler(event)
|
await handler(event)
|
||||||
|
|
||||||
def register_handler(self, event_type, handler: callable):
|
def register_handler(self, event_type, handler: callable):
|
||||||
if not event_type:
|
|
||||||
event_type = handler.__name__.replace('_', '.')
|
|
||||||
|
|
||||||
if not callable(handler):
|
if not callable(handler):
|
||||||
raise TypeError(f'handler must be a callable not {type(handler)}')
|
raise TypeError(f'handler must be a callable not {type(handler)}')
|
||||||
if event_type in self.event_dispatchers:
|
self.event_dispatchers[event_type] = handler
|
||||||
self.event_dispatchers[event_type].append(handler)
|
|
||||||
else:
|
|
||||||
self.event_dispatchers[event_type] = [handler]
|
|
||||||
|
|
||||||
async def mark_event_read(self, event, receipt_type: str = 'm.read'):
|
async def mark_event_read(self, event, receipt_type: str = 'm.read'):
|
||||||
from .events import RoomEvent
|
from .events import RoomEvent
|
||||||
@ -194,7 +178,7 @@ class Client:
|
|||||||
|
|
||||||
async def send_text(self, room: Room, body: str, formatted_body: str = None, format_type: str = None):
|
async def send_text(self, room: Room, body: str, formatted_body: str = None, format_type: str = None):
|
||||||
content = {
|
content = {
|
||||||
'msgtype': 'm.notice',
|
'msgtype': 'm.text',
|
||||||
'body': body
|
'body': body
|
||||||
}
|
}
|
||||||
if formatted_body and format_type:
|
if formatted_body and format_type:
|
||||||
|
|||||||
@ -53,11 +53,6 @@ class MImageContent(MessageContentBase):
|
|||||||
file: Optional[EncryptedFile] = None
|
file: Optional[EncryptedFile] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MStickerContent(MImageContent):
|
|
||||||
msgtype = 'm.sticker'
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MFileContent(MessageContentBase):
|
class MFileContent(MessageContentBase):
|
||||||
msgtype = "m.file"
|
msgtype = "m.file"
|
||||||
@ -149,7 +144,7 @@ class MRoomPowerLevelsContent(ContentBase):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MRoomRedactionContent(ContentBase):
|
class MRoomRedactionContent(ContentBase):
|
||||||
reason: Optional[str] = None
|
reason: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -198,7 +193,6 @@ content_dispatcher = {
|
|||||||
"m.emote": MEmoteContent,
|
"m.emote": MEmoteContent,
|
||||||
"m.notice": MNoticeContent,
|
"m.notice": MNoticeContent,
|
||||||
"m.image": MImageContent,
|
"m.image": MImageContent,
|
||||||
"m.sticker": MStickerContent,
|
|
||||||
"m.file": MFileContent,
|
"m.file": MFileContent,
|
||||||
"m.location": MLocationContent,
|
"m.location": MLocationContent,
|
||||||
"m.video": MVideoContent,
|
"m.video": MVideoContent,
|
||||||
|
|||||||
@ -18,8 +18,7 @@ class EventBase:
|
|||||||
def from_dict(cls, client: Client, event_dict: dict):
|
def from_dict(cls, client: Client, event_dict: dict):
|
||||||
from .content import content_dispatcher
|
from .content import content_dispatcher
|
||||||
if event_dict['type'] == 'm.room.message':
|
if event_dict['type'] == 'm.room.message':
|
||||||
content_class = content_dispatcher[event_dict['content']['msgtype']] \
|
content_class = content_dispatcher[event_dict['content']['msgtype']]
|
||||||
if event_dict['content'].get('msgtype') else ContentBase
|
|
||||||
else:
|
else:
|
||||||
content_class = content_dispatcher[event_dict['type']]
|
content_class = content_dispatcher[event_dict['type']]
|
||||||
|
|
||||||
@ -29,8 +28,6 @@ class EventBase:
|
|||||||
content_dict = {'options': event_dict['content']}
|
content_dict = {'options': event_dict['content']}
|
||||||
else:
|
else:
|
||||||
content_dict = event_dict['content']
|
content_dict = event_dict['content']
|
||||||
if event_dict['type'] == 'm.sticker':
|
|
||||||
content_dict['msgtype'] = 'm.sticker'
|
|
||||||
|
|
||||||
if content_dict.get('m.relates_to'):
|
if content_dict.get('m.relates_to'):
|
||||||
if content_dict['m.relates_to'].get('m.in_reply_to'):
|
if content_dict['m.relates_to'].get('m.in_reply_to'):
|
||||||
@ -44,16 +41,11 @@ class EventBase:
|
|||||||
|
|
||||||
del event_dict['content']
|
del event_dict['content']
|
||||||
|
|
||||||
try:
|
|
||||||
return cls(
|
return cls(
|
||||||
client=client,
|
client=client,
|
||||||
content=content_class(**content_dict),
|
content=content_class(**content_dict),
|
||||||
**event_dict
|
**event_dict
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
print(content_dict)
|
|
||||||
print(event_dict)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
# TODO Add Room class
|
||||||
from typing import List, Optional, Dict, Tuple
|
from typing import List, Optional, Dict, Tuple
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@ -14,7 +15,7 @@ from .content import (
|
|||||||
MRoomRelatedGroupsContent,
|
MRoomRelatedGroupsContent,
|
||||||
MRoomTopicContent,
|
MRoomTopicContent,
|
||||||
)
|
)
|
||||||
from .utils import PreviousRoom, DequeDict
|
from .utils import PreviousRoom
|
||||||
|
|
||||||
|
|
||||||
class Room:
|
class Room:
|
||||||
@ -42,7 +43,7 @@ class Room:
|
|||||||
self.joined_member_count: Optional[int] = None
|
self.joined_member_count: Optional[int] = None
|
||||||
self.invited_member_count: Optional[int] = None
|
self.invited_member_count: Optional[int] = None
|
||||||
self.read_receipts: Dict[str, Tuple[str, int]] = {}
|
self.read_receipts: Dict[str, Tuple[str, int]] = {}
|
||||||
self.message_cache = DequeDict(max=1000)
|
self.message_cache = deque(maxlen=1000)
|
||||||
|
|
||||||
def update_read_receipts(self, receipts: Dict[str, Dict[str, Dict[str, Dict[str, int]]]]):
|
def update_read_receipts(self, receipts: Dict[str, Dict[str, Dict[str, Dict[str, int]]]]):
|
||||||
for event_id, receipt in receipts.items():
|
for event_id, receipt in receipts.items():
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
from inspect import isawaitable
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -99,25 +97,5 @@ class MessageRelation:
|
|||||||
event_id: str
|
event_id: str
|
||||||
|
|
||||||
|
|
||||||
async def maybe_coroutine(func, *args, **kwargs):
|
|
||||||
f = func(*args, **kwargs)
|
|
||||||
if isawaitable(f):
|
|
||||||
return await f
|
|
||||||
else:
|
|
||||||
return f
|
|
||||||
|
|
||||||
|
|
||||||
def notification_power_levels_default_factory():
|
def notification_power_levels_default_factory():
|
||||||
return {'room': 50}
|
return {'room': 50}
|
||||||
|
|
||||||
|
|
||||||
class DequeDict(OrderedDict):
|
|
||||||
def __init__(self, *args, max: int = 0, **kwargs):
|
|
||||||
self._max = max
|
|
||||||
super(DequeDict, self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
OrderedDict.__setitem__(self, key, value)
|
|
||||||
if self._max > 0:
|
|
||||||
if len(self) > self._max:
|
|
||||||
self.popitem(False)
|
|
||||||
|
|||||||
@ -1,112 +1,26 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Union, Optional, Dict, List
|
from typing import Union, Optional, Dict
|
||||||
from inspect import isawaitable
|
|
||||||
from argparse import ArgumentParser
|
|
||||||
|
|
||||||
from morpheus.core.client import Client
|
from morpheus.core.client import Client
|
||||||
from morpheus.core.room import Room
|
from morpheus.core.room import Room
|
||||||
from morpheus.core.utils import maybe_coroutine
|
|
||||||
from morpheus.core.events import RoomEvent
|
|
||||||
from morpheus.core.content import MessageContentBase
|
|
||||||
from .context import Context
|
from .context import Context
|
||||||
from .command import Command
|
|
||||||
|
|
||||||
|
|
||||||
class Bot(Client):
|
class Bot(Client):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
prefix: Union[str, list, tuple, callable],
|
prefix: Union[str, list, tuple],
|
||||||
homeserver: str = "https://matrixcoding.chat",
|
homeserver: str = "https://matrixcoding.chat",
|
||||||
):
|
):
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
super(Bot, self).__init__(prefix=prefix, homeserver=homeserver)
|
super(Bot, self).__init__(prefix=prefix, homeserver=homeserver)
|
||||||
self.commands: Dict[str, Command] = {}
|
|
||||||
|
|
||||||
def run(self, user_id: str = None, password: str = None, token: str = None, loop: Optional[asyncio.AbstractEventLoop] = None):
|
def run(self, user_id: str = None, password: str = None, token: str = None):
|
||||||
loop = loop or self.loop or asyncio.get_event_loop()
|
loop = self.loop or asyncio.get_event_loop()
|
||||||
loop.run_until_complete(super(Bot, self).run(user_id, password, token, loop=loop))
|
loop.run_until_complete(super(Bot, self).run(user_id, password, token))
|
||||||
|
|
||||||
async def get_context(self, event: RoomEvent):
|
async def get_context(self, event):
|
||||||
if not isinstance(event.content, MessageContentBase):
|
|
||||||
return None
|
|
||||||
|
|
||||||
if callable(self.prefix):
|
|
||||||
prefix = await maybe_coroutine(self.prefix, event)
|
|
||||||
elif isinstance(self.prefix, (str, list, tuple)):
|
|
||||||
prefix = self.prefix
|
|
||||||
else:
|
|
||||||
raise RuntimeError('Prefix must be a string, list of strings or callable')
|
|
||||||
|
|
||||||
if isinstance(prefix, str):
|
async def check_event(self, event):
|
||||||
return self._get_context(event, prefix)
|
|
||||||
elif isinstance(prefix, (list, tuple)):
|
|
||||||
prefixes = tuple(prefix)
|
|
||||||
for prefix in prefixes:
|
|
||||||
try:
|
|
||||||
ctx = self._get_context(event, prefix)
|
|
||||||
if ctx:
|
|
||||||
return ctx
|
|
||||||
except TypeError:
|
|
||||||
raise RuntimeError('Prefix must be a string or list of strings')
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
raise RuntimeError('Prefix must be a string or list of strings')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_context(event: RoomEvent, prefix: str):
|
|
||||||
if not isinstance(event.content, MessageContentBase):
|
|
||||||
return None
|
|
||||||
|
|
||||||
raw_body = event.content.body
|
|
||||||
if not raw_body.startswith(prefix):
|
|
||||||
return None
|
|
||||||
raw_body = raw_body.lstrip(prefix)
|
|
||||||
body_list = raw_body.split(' ', 1)
|
|
||||||
called_with = body_list[0]
|
|
||||||
body = body_list[1] if len(body_list) > 1 else None
|
|
||||||
return Context.get_context(event, prefix, called_with, body)
|
|
||||||
|
|
||||||
async def process_command(self, event):
|
|
||||||
if not event.content.msgtype == 'm.text':
|
|
||||||
return
|
|
||||||
|
|
||||||
ctx = await self.get_context(event)
|
|
||||||
if not ctx:
|
|
||||||
return
|
|
||||||
|
|
||||||
command = self.commands.get(ctx.called_with)
|
|
||||||
if not command:
|
|
||||||
return
|
|
||||||
await command.invoke(ctx, ctx.body.split(' ') if ctx.body else None)
|
|
||||||
|
|
||||||
def listener(self, name=None):
|
|
||||||
def decorator(func):
|
|
||||||
self.register_handler(name, func)
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def add_command(self, name: str, aliases: list, func: callable):
|
|
||||||
if not name:
|
|
||||||
name = func.__name__
|
|
||||||
|
|
||||||
if name.startswith('_'):
|
|
||||||
raise RuntimeWarning(f'Command names cannot start with an underscore')
|
|
||||||
|
|
||||||
if aliases is None:
|
|
||||||
aliases = []
|
|
||||||
|
|
||||||
if not isinstance(aliases, list) or any([not isinstance(alias, str) for alias in aliases]):
|
|
||||||
raise RuntimeWarning(f'Aliases must be a list of strings.')
|
|
||||||
|
|
||||||
if name in self.commands or any([alias in self.commands for alias in aliases]):
|
|
||||||
raise RuntimeWarning(f'Command {name} has already been registered')
|
|
||||||
|
|
||||||
command = Command(func)
|
|
||||||
self.commands[name] = command
|
|
||||||
for alias in aliases:
|
|
||||||
self.commands[alias] = command
|
|
||||||
|
|
||||||
def command(self, name: Optional[str] = None, aliases: Optional[list] = None):
|
|
||||||
def decorator(func):
|
|
||||||
self.add_command(name=name, aliases=aliases, func=func)
|
|
||||||
return decorator
|
|
||||||
|
|||||||
@ -1,83 +0,0 @@
|
|||||||
import inspect
|
|
||||||
from argparse import ArgumentParser
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class Command:
|
|
||||||
def __init__(self, function: callable, extension: str = None):
|
|
||||||
if not callable(function):
|
|
||||||
raise RuntimeError('The function to make a command from must be a callable')
|
|
||||||
|
|
||||||
if not inspect.iscoroutinefunction(function):
|
|
||||||
raise RuntimeError('The function to make a command from must be a coroutine')
|
|
||||||
|
|
||||||
self.extension = extension
|
|
||||||
self.signature = inspect.signature(function)
|
|
||||||
self.parser: ArgumentParser = self.process_parameters(self.signature.parameters)
|
|
||||||
self.function: callable = function
|
|
||||||
|
|
||||||
def process_parameters(self, params: dict) -> ArgumentParser:
|
|
||||||
iterator = iter(params.items())
|
|
||||||
|
|
||||||
if self.extension:
|
|
||||||
try:
|
|
||||||
next(iterator)
|
|
||||||
except StopIteration:
|
|
||||||
raise RuntimeError('self is missing from signature')
|
|
||||||
|
|
||||||
try:
|
|
||||||
next(iterator) # the next param should be ctx
|
|
||||||
except StopIteration:
|
|
||||||
raise RuntimeError('ctx is missing from signature')
|
|
||||||
|
|
||||||
parser = ArgumentParser()
|
|
||||||
for name, param in iterator:
|
|
||||||
param: inspect.Parameter
|
|
||||||
if param.kind == param.VAR_POSITIONAL:
|
|
||||||
nargs = '+'
|
|
||||||
else:
|
|
||||||
nargs = 1
|
|
||||||
|
|
||||||
if param.annotation == param.empty:
|
|
||||||
param_type = str
|
|
||||||
else:
|
|
||||||
param_type = param.annotation
|
|
||||||
|
|
||||||
if param.kind == param.KEYWORD_ONLY:
|
|
||||||
name = '--' + name
|
|
||||||
|
|
||||||
if param.default == param.empty:
|
|
||||||
parser.add_argument(name, nargs=nargs, type=param_type)
|
|
||||||
else:
|
|
||||||
parser.add_argument(name, nargs=nargs, type=param_type, default=param.default)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
async def invoke(self, ctx, args_list):
|
|
||||||
iterator = iter(self.signature.parameters.items())
|
|
||||||
|
|
||||||
if self.extension:
|
|
||||||
try:
|
|
||||||
next(iterator)
|
|
||||||
except StopIteration:
|
|
||||||
raise RuntimeError('self is missing from signature')
|
|
||||||
|
|
||||||
try:
|
|
||||||
next(iterator) # the next param should be ctx
|
|
||||||
except StopIteration:
|
|
||||||
raise RuntimeError('ctx is missing from signature')
|
|
||||||
|
|
||||||
args = []
|
|
||||||
kwargs = {}
|
|
||||||
if args_list:
|
|
||||||
params, ctx.extra_params = self.parser.parse_known_args(args_list)
|
|
||||||
|
|
||||||
for key, value in iterator:
|
|
||||||
value: inspect.Parameter
|
|
||||||
if value.kind == value.VAR_POSITIONAL or value.kind == value.POSITIONAL_OR_KEYWORD:
|
|
||||||
args.extend(params.__dict__[key])
|
|
||||||
else:
|
|
||||||
kwargs[key] = params.__dict__[key]
|
|
||||||
await self.function(ctx, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
await self.function(ctx)
|
|
||||||
@ -1,24 +1,6 @@
|
|||||||
from morpheus.core.client import Client
|
from morpheus.core.client import Client
|
||||||
from morpheus.core.room import Room
|
from morpheus.core.room import Room
|
||||||
from morpheus.core.events import RoomEvent
|
|
||||||
from morpheus.core.content import ContentBase
|
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
def __init__(self, client: Client, room: Room, calling_prefix: str, sender: str, event: RoomEvent, content: ContentBase, called_with: str, body: str):
|
def __init__(self, client: Client, room: Room, prefix: str, sender: str, ):
|
||||||
self.client: Client = client
|
self.client: Client
|
||||||
self.room: Room = room
|
|
||||||
self.calling_prefix: str = calling_prefix
|
|
||||||
self.sender: str = sender # TODO once the User class is created change this to type User
|
|
||||||
self.event: RoomEvent = event
|
|
||||||
self.content: ContentBase = content
|
|
||||||
self.called_with: str = called_with
|
|
||||||
self.body: str = body
|
|
||||||
self.extra_params: list = []
|
|
||||||
|
|
||||||
async def send_text(self, body: str, formatted_body: str = None, format_type: str = 'org.matrix.custom.html'):
|
|
||||||
await self.client.send_text(self.room, body, formatted_body, format_type)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_context(cls, event: RoomEvent, calling_prefix: str, called_with: str, body: str):
|
|
||||||
return cls(event.client, event.room, calling_prefix, event.sender, event, event.content, called_with, body)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user