Merge pull request #2 from dustinpianalto/development

Merge in changes
This commit is contained in:
Dusty.P 2020-04-07 21:38:34 -08:00 committed by GitHub
commit 4e029e99ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 470 additions and 202 deletions

View File

@ -42,7 +42,7 @@ ENV REDIS_DB 0
ENV REDIS_HOST redis.geeksbot.com
ENV REDIS_PORT 6379
ENV USE_DOCKER yes
ENV DISCORD_DEFAULT_PREFIX g$
ENV DISCORD_DEFAULT_PREFIX g#
ENV PYTHONPATH /code
COPY entrypoint .

View File

@ -8,6 +8,7 @@
"command_events",
"tickets",
"inspect",
"rcon"
"rcon",
"git"
]
}

142
geeksbot/exts/git.py Normal file
View File

@ -0,0 +1,142 @@
import asyncio
import logging
import discord
from discord.ext import commands
from geeksbot.imports.checks import is_me
from geeksbot.imports.utils import Book, Paginator, run_command
git_log = logging.getLogger("git")
class Git:
def __init__(self, bot):
self.bot = bot
@commands.group(case_insensitive=True, invoke_without_command=True)
async def git(self, ctx):
"""Shows my Git link"""
branch = (
await asyncio.wait_for(
self.bot.loop.create_task(
run_command(
"git rev-parse --symbolic-full-name " "--abbrev-ref HEAD"
)
),
120,
)
).split("\n")[0]
url = f"{self.bot.git_url}/tree/{branch}"
em = discord.Embed(
title=f"Here is where you can find my code",
url=url,
color=self.bot.embed_color,
)
if branch == "master":
em.description = (
"I am Geeksbot. You can find my code here:\n" f"{self.bot.git_url}"
)
else:
em.description = (
f"I am the {branch} branch of Geeksbot. "
f"You can find the master branch here:\n"
f"{self.bot.git_url}/tree/master"
)
em.set_thumbnail(url=f"{ctx.guild.me.avatar_url}")
await ctx.send(embed=em)
@git.command()
@is_me()
async def pull(self, ctx):
"""Pulls updates from GitHub rebasing branch."""
pag = Paginator(self.bot, max_line_length=100, embed=True)
pag.set_embed_meta(
title="Git Pull",
color=self.bot.embed_color,
thumbnail=f"{ctx.guild.me.avatar_url}",
)
pag.add(
"\uFFF6"
+ await asyncio.wait_for(
self.bot.loop.create_task(run_command("git fetch --all")), 120
)
)
pag.add("\uFFF7\n\uFFF8")
pag.add(
await asyncio.wait_for(
self.bot.loop.create_task(
run_command(
"git reset --hard "
"origin/$(git "
"rev-parse --symbolic-full-name"
" --abbrev-ref HEAD)"
)
),
120,
)
)
pag.add("\uFFF7\n\uFFF8")
pag.add(
await asyncio.wait_for(
self.bot.loop.create_task(
run_command("git show --stat | " 'sed "s/.*@.*[.].*/ /g"')
),
10,
)
)
book = Book(pag, (None, ctx.channel, self.bot, ctx.message))
await book.create_book()
@git.command()
@is_me()
async def status(self, ctx):
"""Gets status of current branch."""
pag = Paginator(self.bot, max_line_length=44, max_lines=30, embed=True)
pag.set_embed_meta(
title="Git Status",
color=self.bot.embed_color,
thumbnail=f"{ctx.guild.me.avatar_url}",
)
result = await asyncio.wait_for(
self.bot.loop.create_task(run_command("git status")), 120
)
pag.add(result)
book = Book(pag, (None, ctx.channel, self.bot, ctx.message))
await book.create_book()
@git.command()
@is_me()
async def checkout(self, ctx, branch: str = "master"):
"""Checks out the requested branch.
If no branch name is provided will checkout the master branch"""
pag = Paginator(self.bot, max_line_length=44, max_lines=30, embed=True)
branches_str = await asyncio.wait_for(
self.bot.loop.create_task(run_command(f"git branch -a")), 120
)
existing_branches = [
b.strip().split("/")[-1]
for b in branches_str.replace("*", "").split("\n")[:-1]
]
if branch not in existing_branches:
pag.add(f"There is no existing branch named {branch}.")
pag.set_embed_meta(
title="Git Checkout",
color=self.bot.error_color,
thumbnail=f"{ctx.guild.me.avatar_url}",
)
else:
pag.set_embed_meta(
title="Git Checkout",
color=self.bot.embed_color,
thumbnail=f"{ctx.guild.me.avatar_url}",
)
result = await asyncio.wait_for(
self.bot.loop.create_task(run_command(f"git checkout -f {branch}")), 120
)
pag.add(result)
book = Book(pag, (None, ctx.channel, self.bot, ctx.message))
await book.create_book()
def setup(bot):
bot.add_cog(Git(bot))

View File

@ -34,3 +34,10 @@ def is_admin():
return True
return False
return discord.ext.commands.check(predicate)
def is_production():
def predicate(ctx):
return not os.environ['DEBUG']
return discord.ext.commands.check(predicate)

View File

@ -1,74 +1,90 @@
import json
import logging
import os
import time
import json
from concurrent import futures
from multiprocessing import Pool
import logging
import redis
import aiohttp
import discord
from discord.ext import commands
from discord.ext.commands.context import Context
from geeksbot.imports.strings import MyStringView
import redis
import aiohttp
from geeksbot.imports.geeksbot_api import GeeksbotAPI
geeksbot_logger = logging.getLogger('Geeksbot')
geeksbot_logger = logging.getLogger("Geeksbot")
class Geeksbot(commands.Bot):
def __init__(self, *args, **kwargs):
self.default_prefix = os.environ.get('DISCORD_DEFAULT_PREFIX', 'g$')
kwargs['command_prefix'] = self.get_prefixes
self.default_prefix = os.environ.get("DISCORD_DEFAULT_PREFIX", "g$")
kwargs["command_prefix"] = self.get_prefixes
self.description = "Geeksbot v2"
kwargs['description'] = self.description
kwargs["description"] = self.description
super().__init__(*args, **kwargs)
self.config_dir = 'geeksbot/config'
self.config_file = 'bot_config.json'
self.extension_dir = 'exts'
self.cache = redis.Redis(host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], db=os.environ['REDIS_DB'], charset="utf-8", decode_responses=True, health_check_interval=10)
self.settings_cache = redis.Redis(host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], db=1, charset="utf-8", decode_responses=True, health_check_interval=10)
self.token = self.settings_cache.get('DISCORD_TOKEN')
self.api_token = self.settings_cache.get('API_TOKEN')
self.config_dir = "geeksbot/config"
self.config_file = "bot_config.json"
self.extension_dir = "exts"
self.cache = redis.Redis(
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
db=os.environ["REDIS_DB"],
charset="utf-8",
decode_responses=True,
health_check_interval=10,
)
self.settings_cache = redis.Redis(
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
db=1,
charset="utf-8",
decode_responses=True,
health_check_interval=10,
)
self.token = self.settings_cache.get("DISCORD_TOKEN")
self.api_token = self.settings_cache.get("API_TOKEN")
self.aio_session = aiohttp.ClientSession(loop=self.loop)
self.auth_header = {'Authorization': f'Token {self.api_token}'}
self.api_base = 'https://geeksbot.app/api'
with open(f'{self.config_dir}/{self.config_file}') as f:
self.api_base = "https://geeksbot.app/api"
self.api = GeeksbotAPI(self.api_token, self.api_base, self.loop)
with open(f"{self.config_dir}/{self.config_file}") as f:
self.bot_config = json.load(f)
self.embed_color = discord.Colour.from_rgb(49, 107, 111)
self.error_color = discord.Colour.from_rgb(142, 29, 31)
self.tpe = futures.ThreadPoolExecutor(max_workers=20)
self.process_pool = Pool(processes=4)
self.geo_api = '2d4e419c2be04c8abe91cb5dd1548c72'
self.git_url = 'https://github.com/dustinpianalto/geeksbot_v2'
self.geo_api = "2d4e419c2be04c8abe91cb5dd1548c72"
self.git_url = "https://github.com/dustinpianalto/geeksbot_v2"
self.load_default_extensions()
self.owner_id = 351794468870946827
self.success_emoji = '\N{WHITE HEAVY CHECK MARK}'
self.success_emoji = "\N{WHITE HEAVY CHECK MARK}"
self.book_emojis = {
'unlock': '🔓',
'start': '',
'back': '',
'hash': '#\N{COMBINING ENCLOSING KEYCAP}',
'forward': '',
'end': '',
'close': '🇽',
"unlock": "🔓",
"start": "",
"back": "",
"hash": "#\N{COMBINING ENCLOSING KEYCAP}",
"forward": "",
"end": "",
"close": "🇽",
}
async def load_ext(self, mod):
try:
self.load_extension(f'geeksbot.{self.extension_dir}.{mod}')
geeksbot_logger.info(f'Extension Loaded: {mod}')
self.load_extension(f"geeksbot.{self.extension_dir}.{mod}")
geeksbot_logger.info(f"Extension Loaded: {mod}")
except Exception:
geeksbot_logger.exception(f"Error loading {mod}")
async def unload_ext(self, mod):
try:
self.unload_extension(f'geeksbot.{self.extension_dir}.{mod}')
geeksbot_logger.info(f'Extension Unloaded: {mod}')
self.unload_extension(f"geeksbot.{self.extension_dir}.{mod}")
geeksbot_logger.info(f"Extension Unloaded: {mod}")
except Exception:
geeksbot_logger.exception(f"Error loading {mod}")
def load_default_extensions(self):
for load_item in self.bot_config['load_list']:
for load_item in self.bot_config["load_list"]:
self.loop.create_task(self.load_ext(load_item))
async def get_prefixes(self, bot, message):
@ -94,7 +110,7 @@ class Geeksbot(commands.Bot):
cls
The factory class that will be used to create the context.
By default, this is :class:`.Context`. Should a custom
class be provided, it must be similar enough to :class:`.Context`\'s
class be provided, it must be similar enough to :class:`.Context`'s
interface.
Returns
@ -118,8 +134,8 @@ class Geeksbot(commands.Bot):
return ctx
else:
try:
# if the context class' __init__ consumes something from the view this
# will be wrong. That seems unreasonable though.
# if the context class' __init__ consumes something from the
# view this will be wrong. That seems unreasonable though.
if message.content.casefold().startswith(tuple(prefix)):
invoked_prefix = discord.utils.find(view.skip_string, prefix)
else:
@ -127,14 +143,22 @@ class Geeksbot(commands.Bot):
except TypeError:
if not isinstance(prefix, list):
raise TypeError("get_prefix must return either a string or a list of string, "
"not {}".format(prefix.__class__.__name__))
raise TypeError(
"get_prefix must return either a string "
"or a list of string, "
"not {}".format(prefix.__class__.__name__)
)
# It's possible a bad command_prefix got us here.
for value in prefix:
if not isinstance(value, str):
raise TypeError("Iterable command_prefix or list returned from get_prefix must "
"contain only strings, not {}".format(value.__class__.__name__))
raise TypeError(
"Iterable command_prefix or list "
"returned from get_prefix must "
"contain only strings, not {}".format(
value.__class__.__name__
)
)
# Getting here shouldn't happen
raise
@ -149,6 +173,6 @@ class Geeksbot(commands.Bot):
await self.aio_session.close()
await super().close()
time.sleep(5)
geeksbot_logger.info('Exiting...')
geeksbot_logger.info("Exiting...")
# noinspection PyProtectedMember
os._exit(1)

View File

@ -0,0 +1,78 @@
import aiohttp
import asyncio
import logging
logger = logging.getLogger("Geeksbot API")
class APIError500(Exception):
pass
class APIError(Exception):
pass
class GeeksbotAPI:
def __init__(
self,
token: str,
base_url: str = "https://geeksbot.app/api",
loop: asyncio.AbstractEventLoop = None,
):
self.loop = loop or asyncio.get_event_loop()
self.session = aiohttp.ClientSession(loop=self.loop)
self.base_url = base_url
self.token = token
self.auth_header = {"Authorization": f"Token {self.token}"}
def clean_endpoint(endpoint: str) -> str:
endpoint = endpoint[1:] if endpoint.startswith("/") else endpoint
endpoint += "/" if not endpoint.endswith("/") else ""
return endpoint
async def query(
self, method: str, endpoint: str, data: dict = None, query: dict = None
):
endpoint = self.clean_endpoint(endpoint)
if method.lower() == "get":
resp = await self.session.get(
f"{self.base_url}/{endpoint}{query_str}",
headers=self.auth_header,
params=query,
json=data,
)
elif method.lower() == "post":
resp = await self.session.post(
f"{self.base_url}/{endpoint}{query_str}",
headers=self.auth_header,
params=query,
json=data,
)
elif method.lower() == "put":
resp = await self.session.put(
f"{self.base_url}/{endpoint}{query_str}",
headers=self.auth_header,
params=query,
json=data,
)
elif method.lower() == "delete":
resp = await self.session.delete(
f"{self.base_url}/{endpoint}{query_str}",
headers=self.auth_header,
params=query,
json=data,
)
else:
raise APIError(f"That is not a valid method. {method}")
if resp.status == 200:
return await resp.json()
elif resp.status >= 500:
raise APIError500(
"The server returned a 500 error. "
"The developers have been notified of the issue."
)
else:
details = (await resp.json())["details"]
raise APIError(details)

View File

@ -4,17 +4,15 @@ import typing
from datetime import datetime
async def get_guild_config(bot, guild_id):
guild_config = bot.cache.get()
def create_date_string(time, time_now):
diff = (time_now - time)
date_str = time.strftime('%Y-%m-%d %H:%M:%S')
return f"{diff.days} {'day' if diff.days == 1 else 'days'} " \
f"{diff.seconds // 3600} {'hour' if diff.seconds // 3600 == 1 else 'hours'} " \
f"{diff.seconds % 3600 // 60} {'minute' if diff.seconds % 3600 // 60 == 1 else 'minutes'} " \
diff = time_now - time
date_str = time.strftime("%Y-%m-%d %H:%M:%S")
return (
f"{diff.days} {'day' if diff.days == 1 else 'days'} "
f"{diff.seconds // 3600} {'hour' if diff.seconds // 3600 == 1 else 'hours'} "
f"{diff.seconds % 3600 // 60} {'minute' if diff.seconds % 3600 // 60 == 1 else 'minutes'} "
f"{diff.seconds % 3600 % 60} {'second' if diff.seconds % 3600 % 60 == 1 else 'seconds'} ago.\n{date_str}"
)
def process_snowflake(snowflake: int) -> typing.Tuple[datetime, int, int, int]:
@ -25,71 +23,28 @@ def process_snowflake(snowflake: int) -> typing.Tuple[datetime, int, int, int]:
PROCESS_ID_LOC = 12
PROCESS_ID_MASK = 0x1F000
INCREMENT_MASK = 0xFFF
creation_time = datetime.fromtimestamp(((snowflake >> TIME_BITS_LOC) + DISCORD_EPOCH) / 1000.0)
creation_time = datetime.fromtimestamp(
((snowflake >> TIME_BITS_LOC) + DISCORD_EPOCH) / 1000.0
)
worker_id = (snowflake & WORKER_ID_MASK) >> WORKER_ID_LOC
process_id = (snowflake & PROCESS_ID_MASK) >> PROCESS_ID_LOC
counter = snowflake & INCREMENT_MASK
return creation_time, worker_id, process_id, counter
# noinspection PyDefaultArgument
def to_list_of_str(items, out: list = list(), level=1, recurse=0):
# noinspection PyShadowingNames
def rec_loop(item, key, out, level):
quote = '"'
if type(item) == list:
out.append(f'{" "*level}{quote+key+quote+": " if key else ""}[')
new_level = level + 1
out = to_list_of_str(item, out, new_level, 1)
out.append(f'{" "*level}]')
elif type(item) == dict:
out.append(f'{" "*level}{quote+key+quote+": " if key else ""}{{')
new_level = level + 1
out = to_list_of_str(item, out, new_level, 1)
out.append(f'{" "*level}}}')
else:
out.append(f'{" "*level}{quote+key+quote+": " if key else ""}{repr(item)},')
if type(items) == list:
if not recurse:
out = list()
out.append('[')
for item in items:
rec_loop(item, None, out, level)
if not recurse:
out.append(']')
elif type(items) == dict:
if not recurse:
out = list()
out.append('{')
for key in items:
rec_loop(items[key], key, out, level)
if not recurse:
out.append('}')
return out
def format_output(text):
if type(text) == list:
text = to_list_of_str(text)
elif type(text) == dict:
text = to_list_of_str(text)
return text
async def run_command(args):
# Create subprocess
process = await asyncio.create_subprocess_shell(
f'time -f "Process took %e seconds (%U user | %S system) and used %P of the CPU" {args}',
# stdout must a pipe to be accessible as process.stdout
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE)
stderr=asyncio.subprocess.PIPE,
)
# Wait for the subprocess to finish
stdout, stderr = await process.communicate()
# Return stdout
if stderr and stderr.strip() != '':
output = f'{stdout.decode().strip()}\n{stderr.decode().strip()}'
if stderr and stderr.strip() != "":
output = f"{stdout.decode().strip()}\n{stderr.decode().strip()}"
else:
output = stdout.decode().strip()
return output
@ -97,29 +52,34 @@ async def run_command(args):
# noinspection PyShadowingNames
class Paginator:
def __init__(self,
def __init__(
self,
bot: discord.ext.commands.Bot,
*,
max_chars: int = 1970,
max_lines: int = 20,
prefix: str = '```md',
suffix: str = '```',
page_break: str = '\uFFF8',
field_break: str = '\uFFF7',
field_name_char: str = '\uFFF6',
inline_char: str = '\uFFF5',
prefix: str = "```md",
suffix: str = "```",
page_break: str = "\uFFF8",
field_break: str = "\uFFF7",
field_name_char: str = "\uFFF6",
inline_char: str = "\uFFF5",
max_line_length: int = 100,
embed=False,
header: str = ''):
header: str = "",
):
_max_len = 6000 if embed else 1980
assert 0 < max_lines <= max_chars
self._parts = list()
self._prefix = prefix
self._suffix = suffix
self._max_chars = max_chars if max_chars + len(prefix) + len(suffix) + 2 <= _max_len \
self._max_chars = (
max_chars
if max_chars + len(prefix) + len(suffix) + 2 <= _max_len
else _max_len - len(prefix) - len(suffix) - 2
self._max_lines = max_lines - (prefix + suffix).count('\n') + 1
)
self._max_lines = max_lines - (prefix + suffix).count("\n") + 1
self._page_break = page_break
self._max_line_length = max_line_length
self._pages = list()
@ -130,26 +90,29 @@ class Paginator:
self._field_break = field_break
self._field_name_char = field_name_char
self._inline_char = inline_char
self._embed_title = ''
self._embed_description = ''
self._embed_title = ""
self._embed_description = ""
self._embed_color = None
self._embed_thumbnail = None
self._embed_url = None
self._bot = bot
self._header = header
def set_embed_meta(self, title: str = None,
def set_embed_meta(
self,
title: str = None,
description: str = None,
color: discord.Colour = None,
thumbnail: str = None,
footer: str = '',
url: str = None):
footer: str = "",
url: str = None,
):
if title and len(title) > self._max_field_name:
raise RuntimeError('Provided Title is too long')
raise RuntimeError("Provided Title is too long")
else:
self._embed_title = title
if description and len(description) > self._max_description:
raise RuntimeError('Provided Description is too long')
raise RuntimeError("Provided Description is too long")
else:
self._embed_description = description
self._embed_color = color
@ -159,10 +122,10 @@ class Paginator:
def pages(self, page_headers: bool = True) -> typing.List[str]:
_pages = list()
_fields = list()
_page = ''
_page = ""
_lines = 0
_field_name = ''
_field_value = ''
_field_name = ""
_field_value = ""
_inline = False
def open_page(initial: bool = False):
@ -173,7 +136,7 @@ class Paginator:
elif page_headers:
_page = self._header
else:
_page = ''
_page = ""
_page += self._prefix
_lines = 0
else:
@ -200,12 +163,13 @@ class Paginator:
if new_chars > self._max_chars:
close_page()
elif (_lines + (part.count('\n') + 1 or 1)) > self._max_lines:
elif (_lines + (part.count("\n") + 1 or 1)) > self._max_lines:
close_page()
_lines += (part.count('\n') + 1 or 1)
_page += '\n' + part
_lines += part.count("\n") + 1 or 1
_page += "\n" + part
else:
def open_field(name: str):
nonlocal _field_value, _field_name
_field_name = name
@ -215,28 +179,30 @@ class Paginator:
nonlocal _field_name, _field_value, _fields
_field_value += self._suffix
if _field_value != self._prefix + self._suffix:
_fields.append({'name': _field_name, 'value': _field_value, 'inline': _inline})
_fields.append(
{"name": _field_name, "value": _field_value, "inline": _inline}
)
if next_name:
open_field(next_name)
open_field('\uFFF0')
open_field("\uFFF0")
for part in [str(p) for p in self._parts]:
if part.strip().startswith(self._page_break):
close_page()
elif part == self._field_break:
if len(_fields) + 1 < 25:
close_field(next_name='\uFFF0')
close_field(next_name="\uFFF0")
else:
close_field()
close_page()
continue
if part.startswith(self._field_name_char):
part = part.replace(self._field_name_char, '')
part = part.replace(self._field_name_char, "")
if part.startswith(self._inline_char):
_inline = True
part = part.replace(self._inline_char, '')
part = part.replace(self._inline_char, "")
else:
_inline = False
if _field_value and _field_value != self._prefix:
@ -245,7 +211,7 @@ class Paginator:
_field_name = part
continue
_field_value += '\n' + part
_field_value += "\n" + part
close_field()
@ -257,14 +223,15 @@ class Paginator:
def process_pages(self) -> typing.List[str]:
_pages = self._pages or self.pages()
_len_pages = len(_pages)
_len_page_str = len(f'{_len_pages}/{_len_pages}')
_len_page_str = len(f"{_len_pages}/{_len_pages}")
if not self._embed:
for i, page in enumerate(_pages):
if len(page) + _len_page_str <= 2000:
_pages[i] = f'{i + 1}/{_len_pages}\n{page}'
_pages[i] = f"{i + 1}/{_len_pages}\n{page}"
else:
for i, page in enumerate(_pages):
em = discord.Embed(title=self._embed_title,
em = discord.Embed(
title=self._embed_title,
description=self._embed_description,
color=self._bot.embed_color,
)
@ -274,9 +241,11 @@ class Paginator:
em.url = self._embed_url
if self._embed_color:
em.color = self._embed_color
em.set_footer(text=f'{i + 1}/{_len_pages}')
em.set_footer(text=f"{i + 1}/{_len_pages}")
for field in page:
em.add_field(name=field['name'], value=field['value'], inline=field['inline'])
em.add_field(
name=field["name"], value=field["value"], inline=field["inline"]
)
_pages[i] = em
return _pages
@ -287,54 +256,71 @@ class Paginator:
# noinspection PyProtectedMember
return self.__class__ == other.__class__ and self._parts == other._parts
def set_header(self, header: str = ''):
def set_header(self, header: str = ""):
self._header = header
def add_page_break(self, *, to_beginning: bool = False) -> None:
self.add(self._page_break, to_beginning=to_beginning)
def add(self, item: typing.Any, *, to_beginning: bool = False, keep_intact: bool = False, truncate=False) -> None:
def add(
self,
item: typing.Any,
*,
to_beginning: bool = False,
keep_intact: bool = False,
truncate=False,
) -> None:
item = str(item)
i = 0
if not keep_intact and not item == self._page_break:
item_parts = item.strip('\n').split('\n')
item_parts = item.strip("\n").split("\n")
for part in item_parts:
if len(part) > self._max_line_length:
if not truncate:
length = 0
out_str = ''
out_str = ""
def close_line(line):
nonlocal i, out_str, length
self._parts.insert(i, out_str) if to_beginning else self._parts.append(out_str)
self._parts.insert(
i, out_str
) if to_beginning else self._parts.append(out_str)
i += 1
out_str = line + ' '
out_str = line + " "
length = len(out_str)
bits = part.split(' ')
bits = part.split(" ")
for bit in bits:
next_len = length + len(bit) + 1
if next_len <= self._max_line_length:
out_str += bit + ' '
out_str += bit + " "
length = next_len
elif len(bit) > self._max_line_length:
if out_str:
close_line(line='')
for out_str in [bit[i:i + self._max_line_length]
for i in range(0, len(bit), self._max_line_length)]:
close_line('')
close_line(line="")
for out_str in [
bit[i : i + self._max_line_length]
for i in range(0, len(bit), self._max_line_length)
]:
close_line("")
else:
close_line(bit)
close_line('')
close_line("")
else:
line = f'{part:.{self._max_line_length-3}}...'
self._parts.insert(i, line) if to_beginning else self._parts.append(line)
line = f"{part:.{self._max_line_length-3}}..."
self._parts.insert(
i, line
) if to_beginning else self._parts.append(line)
else:
self._parts.insert(i, part) if to_beginning else self._parts.append(part)
self._parts.insert(i, part) if to_beginning else self._parts.append(
part
)
i += 1
elif keep_intact and not item == self._page_break:
if len(item) >= self._max_chars or item.count('\n') > self._max_lines:
raise RuntimeError('{item} is too long to keep on a single page and is marked to keep intact.')
if len(item) >= self._max_chars or item.count("\n") > self._max_lines:
raise RuntimeError(
"{item} is too long to keep on a single page and is marked to keep intact."
)
if to_beginning:
self._parts.insert(0, item)
else:
@ -347,17 +333,23 @@ class Paginator:
class Book:
def __init__(self, pag: Paginator, ctx: typing.Tuple[typing.Optional[discord.Message],
def __init__(
self,
pag: Paginator,
ctx: typing.Tuple[
typing.Optional[discord.Message],
discord.TextChannel,
discord.ext.commands.Bot,
discord.Message]) -> None:
discord.Message,
],
) -> None:
self._pages = pag.process_pages()
self._len_pages = len(self._pages)
self._current_page = 0
self._message, self._channel, self._bot, self._calling_message = ctx
self._locked = True
if pag == Paginator(self._bot):
raise RuntimeError('Cannot create a book out of an empty Paginator.')
raise RuntimeError("Cannot create a book out of an empty Paginator.")
def advance_page(self) -> None:
self._current_page += 1
@ -372,14 +364,22 @@ class Book:
async def display_page(self) -> None:
if isinstance(self._pages[self._current_page], discord.Embed):
if self._message:
await self._message.edit(content=None, embed=self._pages[self._current_page])
await self._message.edit(
content=None, embed=self._pages[self._current_page]
)
else:
self._message = await self._channel.send(embed=self._pages[self._current_page])
self._message = await self._channel.send(
embed=self._pages[self._current_page]
)
else:
if self._message:
await self._message.edit(content=self._pages[self._current_page], embed=None)
await self._message.edit(
content=self._pages[self._current_page], embed=None
)
else:
self._message = await self._channel.send(self._pages[self._current_page])
self._message = await self._channel.send(
self._pages[self._current_page]
)
async def create_book(self) -> None:
# noinspection PyUnresolvedReferences
@ -387,12 +387,16 @@ class Book:
# noinspection PyShadowingNames
def check(reaction, user):
if self._locked:
return str(reaction.emoji) in self._bot.book_emojis.values() \
and user == self._calling_message.author \
return (
str(reaction.emoji) in self._bot.book_emojis.values()
and user == self._calling_message.author
and reaction.message.id == self._message.id
)
else:
return str(reaction.emoji) in self._bot.book_emojis.values() \
return (
str(reaction.emoji) in self._bot.book_emojis.values()
and reaction.message.id == self._message.id
)
await self.display_page()
@ -404,14 +408,16 @@ class Book:
pass
else:
try:
await self._message.add_reaction(self._bot.book_emojis['unlock'])
await self._message.add_reaction(self._bot.book_emojis['close'])
await self._message.add_reaction(self._bot.book_emojis["unlock"])
await self._message.add_reaction(self._bot.book_emojis["close"])
except (discord.Forbidden, KeyError):
pass
while True:
try:
reaction, user = await self._bot.wait_for('reaction_add', timeout=60, check=check)
reaction, user = await self._bot.wait_for(
"reaction_add", timeout=60, check=check
)
except asyncio.TimeoutError:
try:
await self._message.clear_reactions()
@ -420,34 +426,42 @@ class Book:
raise asyncio.CancelledError
else:
await self._message.remove_reaction(reaction, user)
if str(reaction.emoji) == self._bot.book_emojis['close']:
if str(reaction.emoji) == self._bot.book_emojis["close"]:
await self._calling_message.delete()
await self._message.delete()
raise asyncio.CancelledError
elif str(reaction.emoji) == self._bot.book_emojis['forward']:
elif str(reaction.emoji) == self._bot.book_emojis["forward"]:
self.advance_page()
elif str(reaction.emoji) == self._bot.book_emojis['back']:
elif str(reaction.emoji) == self._bot.book_emojis["back"]:
self.reverse_page()
elif str(reaction.emoji) == self._bot.book_emojis['end']:
elif str(reaction.emoji) == self._bot.book_emojis["end"]:
self._current_page = self._len_pages - 1
elif str(reaction.emoji) == self._bot.book_emojis['start']:
elif str(reaction.emoji) == self._bot.book_emojis["start"]:
self._current_page = 0
elif str(reaction.emoji) == self._bot.book_emojis['hash']:
m = await self._channel.send(f'Please enter a number in range 1 to {self._len_pages}')
elif str(reaction.emoji) == self._bot.book_emojis["hash"]:
m = await self._channel.send(
f"Please enter a number in range 1 to {self._len_pages}"
)
def num_check(message):
if self._locked:
return message.content.isdigit() \
and 0 < int(message.content) <= self._len_pages \
and message.author == self._calling_message.author
else:
return message.content.isdigit() \
return (
message.content.isdigit()
and 0 < int(message.content) <= self._len_pages
and message.author == self._calling_message.author
)
else:
return (
message.content.isdigit()
and 0 < int(message.content) <= self._len_pages
)
try:
msg = await self._bot.wait_for('message', timeout=30, check=num_check)
msg = await self._bot.wait_for(
"message", timeout=30, check=num_check
)
except asyncio.TimeoutError:
await m.edit(content='Message Timed out.')
await m.edit(content="Message Timed out.")
else:
self._current_page = int(msg.content) - 1
try:
@ -455,9 +469,11 @@ class Book:
await msg.delete()
except (discord.Forbidden, discord.NotFound):
pass
elif str(reaction.emoji) == self._bot.book_emojis['unlock']:
elif str(reaction.emoji) == self._bot.book_emojis["unlock"]:
self._locked = False
await self._message.remove_reaction(reaction, self._channel.guild.me)
await self._message.remove_reaction(
reaction, self._channel.guild.me
)
continue
await self.display_page()