Add git commands and reformatted files

This commit is contained in:
Dustin Pianalto 2019-12-26 20:56:00 -09:00
parent 2033e0767f
commit d1da4fc999
4 changed files with 388 additions and 200 deletions

142
geeksbot/exts/git.py Normal file
View File

@ -0,0 +1,142 @@
import asyncio
import logging
import discord
from discord.ext import commands
from geeksbot.imports.checks import is_me
from geeksbot.imports.utils import Book, Paginator, run_command
git_log = logging.getLogger("git")
class Git:
def __init__(self, bot):
self.bot = bot
@commands.group(case_insensitive=True, invoke_without_command=True)
async def git(self, ctx):
"""Shows my Git link"""
branch = (
await asyncio.wait_for(
self.bot.loop.create_task(
run_command(
"git rev-parse --symbolic-full-name " "--abbrev-ref HEAD"
)
),
120,
)
).split("\n")[0]
url = f"{self.bot.git_url}/tree/{branch}"
em = discord.Embed(
title=f"Here is where you can find my code",
url=url,
color=self.bot.embed_color,
)
if branch == "master":
em.description = (
"I am Geeksbot. You can find my code here:\n" f"{self.bot.git_url}"
)
else:
em.description = (
f"I am the {branch} branch of Geeksbot. "
f"You can find the master branch here:\n"
f"{self.bot.git_url}/tree/master"
)
em.set_thumbnail(url=f"{ctx.guild.me.avatar_url}")
await ctx.send(embed=em)
@git.command()
@is_me()
async def pull(self, ctx):
"""Pulls updates from GitHub rebasing branch."""
pag = Paginator(self.bot, max_line_length=100, embed=True)
pag.set_embed_meta(
title="Git Pull",
color=self.bot.embed_color,
thumbnail=f"{ctx.guild.me.avatar_url}",
)
pag.add(
"\uFFF6"
+ await asyncio.wait_for(
self.bot.loop.create_task(run_command("git fetch --all")), 120
)
)
pag.add("\uFFF7\n\uFFF8")
pag.add(
await asyncio.wait_for(
self.bot.loop.create_task(
run_command(
"git reset --hard "
"origin/$(git "
"rev-parse --symbolic-full-name"
" --abbrev-ref HEAD)"
)
),
120,
)
)
pag.add("\uFFF7\n\uFFF8")
pag.add(
await asyncio.wait_for(
self.bot.loop.create_task(
run_command("git show --stat | " 'sed "s/.*@.*[.].*/ /g"')
),
10,
)
)
book = Book(pag, (None, ctx.channel, self.bot, ctx.message))
await book.create_book()
@git.command()
@is_me()
async def status(self, ctx):
"""Gets status of current branch."""
pag = Paginator(self.bot, max_line_length=44, max_lines=30, embed=True)
pag.set_embed_meta(
title="Git Status",
color=self.bot.embed_color,
thumbnail=f"{ctx.guild.me.avatar_url}",
)
result = await asyncio.wait_for(
self.bot.loop.create_task(run_command("git status")), 120
)
pag.add(result)
book = Book(pag, (None, ctx.channel, self.bot, ctx.message))
await book.create_book()
@git.command()
@is_me()
async def checkout(self, ctx, branch: str = "master"):
"""Checks out the requested branch.
If no branch name is provided will checkout the master branch"""
pag = Paginator(self.bot, max_line_length=44, max_lines=30, embed=True)
branches_str = await asyncio.wait_for(
self.bot.loop.create_task(run_command(f"git branch -a")), 120
)
existing_branches = [
b.strip().split("/")[-1]
for b in branches_str.replace("*", "").split("\n")[:-1]
]
if branch not in existing_branches:
pag.add(f"There is no existing branch named {branch}.")
pag.set_embed_meta(
title="Git Checkout",
color=self.bot.error_color,
thumbnail=f"{ctx.guild.me.avatar_url}",
)
else:
pag.set_embed_meta(
title="Git Checkout",
color=self.bot.embed_color,
thumbnail=f"{ctx.guild.me.avatar_url}",
)
result = await asyncio.wait_for(
self.bot.loop.create_task(run_command(f"git checkout -f {branch}")), 120
)
pag.add(result)
book = Book(pag, (None, ctx.channel, self.bot, ctx.message))
await book.create_book()
def setup(bot):
bot.add_cog(Git(bot))

View File

@ -34,3 +34,10 @@ def is_admin():
return True return True
return False return False
return discord.ext.commands.check(predicate) return discord.ext.commands.check(predicate)
def is_production():
def predicate(ctx):
return not os.environ['DEBUG']
return discord.ext.commands.check(predicate)

View File

@ -1,74 +1,89 @@
import json
import logging
import os import os
import time import time
import json
from concurrent import futures from concurrent import futures
from multiprocessing import Pool from multiprocessing import Pool
import logging
import redis
import aiohttp
import discord import discord
from discord.ext import commands from discord.ext import commands
from discord.ext.commands.context import Context from discord.ext.commands.context import Context
from geeksbot.imports.strings import MyStringView from geeksbot.imports.strings import MyStringView
import redis
import aiohttp
geeksbot_logger = logging.getLogger('Geeksbot') geeksbot_logger = logging.getLogger("Geeksbot")
class Geeksbot(commands.Bot): class Geeksbot(commands.Bot):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.default_prefix = os.environ.get('DISCORD_DEFAULT_PREFIX', 'g$') self.default_prefix = os.environ.get("DISCORD_DEFAULT_PREFIX", "g$")
kwargs['command_prefix'] = self.get_prefixes kwargs["command_prefix"] = self.get_prefixes
self.description = "Geeksbot v2" self.description = "Geeksbot v2"
kwargs['description'] = self.description kwargs["description"] = self.description
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.config_dir = 'geeksbot/config' self.config_dir = "geeksbot/config"
self.config_file = 'bot_config.json' self.config_file = "bot_config.json"
self.extension_dir = 'exts' self.extension_dir = "exts"
self.cache = redis.Redis(host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], db=os.environ['REDIS_DB'], charset="utf-8", decode_responses=True, health_check_interval=10) self.cache = redis.Redis(
self.settings_cache = redis.Redis(host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], db=1, charset="utf-8", decode_responses=True, health_check_interval=10) host=os.environ["REDIS_HOST"],
self.token = self.settings_cache.get('DISCORD_TOKEN') port=os.environ["REDIS_PORT"],
self.api_token = self.settings_cache.get('API_TOKEN') db=os.environ["REDIS_DB"],
charset="utf-8",
decode_responses=True,
health_check_interval=10,
)
self.settings_cache = redis.Redis(
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
db=1,
charset="utf-8",
decode_responses=True,
health_check_interval=10,
)
self.token = self.settings_cache.get("DISCORD_TOKEN")
self.api_token = self.settings_cache.get("API_TOKEN")
self.aio_session = aiohttp.ClientSession(loop=self.loop) self.aio_session = aiohttp.ClientSession(loop=self.loop)
self.auth_header = {'Authorization': f'Token {self.api_token}'} self.auth_header = {"Authorization": f"Token {self.api_token}"}
self.api_base = 'https://geeksbot.app/api' self.api_base = "https://geeksbot.app/api"
with open(f'{self.config_dir}/{self.config_file}') as f: with open(f"{self.config_dir}/{self.config_file}") as f:
self.bot_config = json.load(f) self.bot_config = json.load(f)
self.embed_color = discord.Colour.from_rgb(49, 107, 111) self.embed_color = discord.Colour.from_rgb(49, 107, 111)
self.error_color = discord.Colour.from_rgb(142, 29, 31) self.error_color = discord.Colour.from_rgb(142, 29, 31)
self.tpe = futures.ThreadPoolExecutor(max_workers=20) self.tpe = futures.ThreadPoolExecutor(max_workers=20)
self.process_pool = Pool(processes=4) self.process_pool = Pool(processes=4)
self.geo_api = '2d4e419c2be04c8abe91cb5dd1548c72' self.geo_api = "2d4e419c2be04c8abe91cb5dd1548c72"
self.git_url = 'https://github.com/dustinpianalto/geeksbot_v2' self.git_url = "https://github.com/dustinpianalto/geeksbot_v2"
self.load_default_extensions() self.load_default_extensions()
self.owner_id = 351794468870946827 self.owner_id = 351794468870946827
self.success_emoji = '\N{WHITE HEAVY CHECK MARK}' self.success_emoji = "\N{WHITE HEAVY CHECK MARK}"
self.book_emojis = { self.book_emojis = {
'unlock': '🔓', "unlock": "🔓",
'start': '', "start": "",
'back': '', "back": "",
'hash': '#\N{COMBINING ENCLOSING KEYCAP}', "hash": "#\N{COMBINING ENCLOSING KEYCAP}",
'forward': '', "forward": "",
'end': '', "end": "",
'close': '🇽', "close": "🇽",
} }
async def load_ext(self, mod): async def load_ext(self, mod):
try: try:
self.load_extension(f'geeksbot.{self.extension_dir}.{mod}') self.load_extension(f"geeksbot.{self.extension_dir}.{mod}")
geeksbot_logger.info(f'Extension Loaded: {mod}') geeksbot_logger.info(f"Extension Loaded: {mod}")
except Exception: except Exception:
geeksbot_logger.exception(f"Error loading {mod}") geeksbot_logger.exception(f"Error loading {mod}")
async def unload_ext(self, mod): async def unload_ext(self, mod):
try: try:
self.unload_extension(f'geeksbot.{self.extension_dir}.{mod}') self.unload_extension(f"geeksbot.{self.extension_dir}.{mod}")
geeksbot_logger.info(f'Extension Unloaded: {mod}') geeksbot_logger.info(f"Extension Unloaded: {mod}")
except Exception: except Exception:
geeksbot_logger.exception(f"Error loading {mod}") geeksbot_logger.exception(f"Error loading {mod}")
def load_default_extensions(self): def load_default_extensions(self):
for load_item in self.bot_config['load_list']: for load_item in self.bot_config["load_list"]:
self.loop.create_task(self.load_ext(load_item)) self.loop.create_task(self.load_ext(load_item))
async def get_prefixes(self, bot, message): async def get_prefixes(self, bot, message):
@ -94,7 +109,7 @@ class Geeksbot(commands.Bot):
cls cls
The factory class that will be used to create the context. The factory class that will be used to create the context.
By default, this is :class:`.Context`. Should a custom By default, this is :class:`.Context`. Should a custom
class be provided, it must be similar enough to :class:`.Context`\'s class be provided, it must be similar enough to :class:`.Context`'s
interface. interface.
Returns Returns
@ -118,8 +133,8 @@ class Geeksbot(commands.Bot):
return ctx return ctx
else: else:
try: try:
# if the context class' __init__ consumes something from the view this # if the context class' __init__ consumes something from the
# will be wrong. That seems unreasonable though. # view this will be wrong. That seems unreasonable though.
if message.content.casefold().startswith(tuple(prefix)): if message.content.casefold().startswith(tuple(prefix)):
invoked_prefix = discord.utils.find(view.skip_string, prefix) invoked_prefix = discord.utils.find(view.skip_string, prefix)
else: else:
@ -127,14 +142,22 @@ class Geeksbot(commands.Bot):
except TypeError: except TypeError:
if not isinstance(prefix, list): if not isinstance(prefix, list):
raise TypeError("get_prefix must return either a string or a list of string, " raise TypeError(
"not {}".format(prefix.__class__.__name__)) "get_prefix must return either a string "
"or a list of string, "
"not {}".format(prefix.__class__.__name__)
)
# It's possible a bad command_prefix got us here. # It's possible a bad command_prefix got us here.
for value in prefix: for value in prefix:
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError("Iterable command_prefix or list returned from get_prefix must " raise TypeError(
"contain only strings, not {}".format(value.__class__.__name__)) "Iterable command_prefix or list "
"returned from get_prefix must "
"contain only strings, not {}".format(
value.__class__.__name__
)
)
# Getting here shouldn't happen # Getting here shouldn't happen
raise raise
@ -149,6 +172,6 @@ class Geeksbot(commands.Bot):
await self.aio_session.close() await self.aio_session.close()
await super().close() await super().close()
time.sleep(5) time.sleep(5)
geeksbot_logger.info('Exiting...') geeksbot_logger.info("Exiting...")
# noinspection PyProtectedMember # noinspection PyProtectedMember
os._exit(1) os._exit(1)

View File

@ -4,17 +4,15 @@ import typing
from datetime import datetime from datetime import datetime
async def get_guild_config(bot, guild_id):
guild_config = bot.cache.get()
def create_date_string(time, time_now): def create_date_string(time, time_now):
diff = (time_now - time) diff = time_now - time
date_str = time.strftime('%Y-%m-%d %H:%M:%S') date_str = time.strftime("%Y-%m-%d %H:%M:%S")
return f"{diff.days} {'day' if diff.days == 1 else 'days'} " \ return (
f"{diff.seconds // 3600} {'hour' if diff.seconds // 3600 == 1 else 'hours'} " \ f"{diff.days} {'day' if diff.days == 1 else 'days'} "
f"{diff.seconds % 3600 // 60} {'minute' if diff.seconds % 3600 // 60 == 1 else 'minutes'} " \ f"{diff.seconds // 3600} {'hour' if diff.seconds // 3600 == 1 else 'hours'} "
f"{diff.seconds % 3600 // 60} {'minute' if diff.seconds % 3600 // 60 == 1 else 'minutes'} "
f"{diff.seconds % 3600 % 60} {'second' if diff.seconds % 3600 % 60 == 1 else 'seconds'} ago.\n{date_str}" f"{diff.seconds % 3600 % 60} {'second' if diff.seconds % 3600 % 60 == 1 else 'seconds'} ago.\n{date_str}"
)
def process_snowflake(snowflake: int) -> typing.Tuple[datetime, int, int, int]: def process_snowflake(snowflake: int) -> typing.Tuple[datetime, int, int, int]:
@ -25,71 +23,28 @@ def process_snowflake(snowflake: int) -> typing.Tuple[datetime, int, int, int]:
PROCESS_ID_LOC = 12 PROCESS_ID_LOC = 12
PROCESS_ID_MASK = 0x1F000 PROCESS_ID_MASK = 0x1F000
INCREMENT_MASK = 0xFFF INCREMENT_MASK = 0xFFF
creation_time = datetime.fromtimestamp(((snowflake >> TIME_BITS_LOC) + DISCORD_EPOCH) / 1000.0) creation_time = datetime.fromtimestamp(
((snowflake >> TIME_BITS_LOC) + DISCORD_EPOCH) / 1000.0
)
worker_id = (snowflake & WORKER_ID_MASK) >> WORKER_ID_LOC worker_id = (snowflake & WORKER_ID_MASK) >> WORKER_ID_LOC
process_id = (snowflake & PROCESS_ID_MASK) >> PROCESS_ID_LOC process_id = (snowflake & PROCESS_ID_MASK) >> PROCESS_ID_LOC
counter = snowflake & INCREMENT_MASK counter = snowflake & INCREMENT_MASK
return creation_time, worker_id, process_id, counter return creation_time, worker_id, process_id, counter
# noinspection PyDefaultArgument
def to_list_of_str(items, out: list = list(), level=1, recurse=0):
# noinspection PyShadowingNames
def rec_loop(item, key, out, level):
quote = '"'
if type(item) == list:
out.append(f'{" "*level}{quote+key+quote+": " if key else ""}[')
new_level = level + 1
out = to_list_of_str(item, out, new_level, 1)
out.append(f'{" "*level}]')
elif type(item) == dict:
out.append(f'{" "*level}{quote+key+quote+": " if key else ""}{{')
new_level = level + 1
out = to_list_of_str(item, out, new_level, 1)
out.append(f'{" "*level}}}')
else:
out.append(f'{" "*level}{quote+key+quote+": " if key else ""}{repr(item)},')
if type(items) == list:
if not recurse:
out = list()
out.append('[')
for item in items:
rec_loop(item, None, out, level)
if not recurse:
out.append(']')
elif type(items) == dict:
if not recurse:
out = list()
out.append('{')
for key in items:
rec_loop(items[key], key, out, level)
if not recurse:
out.append('}')
return out
def format_output(text):
if type(text) == list:
text = to_list_of_str(text)
elif type(text) == dict:
text = to_list_of_str(text)
return text
async def run_command(args): async def run_command(args):
# Create subprocess # Create subprocess
process = await asyncio.create_subprocess_shell( process = await asyncio.create_subprocess_shell(
f'time -f "Process took %e seconds (%U user | %S system) and used %P of the CPU" {args}', f'time -f "Process took %e seconds (%U user | %S system) and used %P of the CPU" {args}',
# stdout must a pipe to be accessible as process.stdout # stdout must a pipe to be accessible as process.stdout
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE) stderr=asyncio.subprocess.PIPE,
)
# Wait for the subprocess to finish # Wait for the subprocess to finish
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
# Return stdout # Return stdout
if stderr and stderr.strip() != '': if stderr and stderr.strip() != "":
output = f'{stdout.decode().strip()}\n{stderr.decode().strip()}' output = f"{stdout.decode().strip()}\n{stderr.decode().strip()}"
else: else:
output = stdout.decode().strip() output = stdout.decode().strip()
return output return output
@ -97,29 +52,34 @@ async def run_command(args):
# noinspection PyShadowingNames # noinspection PyShadowingNames
class Paginator: class Paginator:
def __init__(self, def __init__(
self,
bot: discord.ext.commands.Bot, bot: discord.ext.commands.Bot,
*, *,
max_chars: int = 1970, max_chars: int = 1970,
max_lines: int = 20, max_lines: int = 20,
prefix: str = '```md', prefix: str = "```md",
suffix: str = '```', suffix: str = "```",
page_break: str = '\uFFF8', page_break: str = "\uFFF8",
field_break: str = '\uFFF7', field_break: str = "\uFFF7",
field_name_char: str = '\uFFF6', field_name_char: str = "\uFFF6",
inline_char: str = '\uFFF5', inline_char: str = "\uFFF5",
max_line_length: int = 100, max_line_length: int = 100,
embed=False, embed=False,
header: str = ''): header: str = "",
):
_max_len = 6000 if embed else 1980 _max_len = 6000 if embed else 1980
assert 0 < max_lines <= max_chars assert 0 < max_lines <= max_chars
self._parts = list() self._parts = list()
self._prefix = prefix self._prefix = prefix
self._suffix = suffix self._suffix = suffix
self._max_chars = max_chars if max_chars + len(prefix) + len(suffix) + 2 <= _max_len \ self._max_chars = (
max_chars
if max_chars + len(prefix) + len(suffix) + 2 <= _max_len
else _max_len - len(prefix) - len(suffix) - 2 else _max_len - len(prefix) - len(suffix) - 2
self._max_lines = max_lines - (prefix + suffix).count('\n') + 1 )
self._max_lines = max_lines - (prefix + suffix).count("\n") + 1
self._page_break = page_break self._page_break = page_break
self._max_line_length = max_line_length self._max_line_length = max_line_length
self._pages = list() self._pages = list()
@ -130,26 +90,29 @@ class Paginator:
self._field_break = field_break self._field_break = field_break
self._field_name_char = field_name_char self._field_name_char = field_name_char
self._inline_char = inline_char self._inline_char = inline_char
self._embed_title = '' self._embed_title = ""
self._embed_description = '' self._embed_description = ""
self._embed_color = None self._embed_color = None
self._embed_thumbnail = None self._embed_thumbnail = None
self._embed_url = None self._embed_url = None
self._bot = bot self._bot = bot
self._header = header self._header = header
def set_embed_meta(self, title: str = None, def set_embed_meta(
self,
title: str = None,
description: str = None, description: str = None,
color: discord.Colour = None, color: discord.Colour = None,
thumbnail: str = None, thumbnail: str = None,
footer: str = '', footer: str = "",
url: str = None): url: str = None,
):
if title and len(title) > self._max_field_name: if title and len(title) > self._max_field_name:
raise RuntimeError('Provided Title is too long') raise RuntimeError("Provided Title is too long")
else: else:
self._embed_title = title self._embed_title = title
if description and len(description) > self._max_description: if description and len(description) > self._max_description:
raise RuntimeError('Provided Description is too long') raise RuntimeError("Provided Description is too long")
else: else:
self._embed_description = description self._embed_description = description
self._embed_color = color self._embed_color = color
@ -159,10 +122,10 @@ class Paginator:
def pages(self, page_headers: bool = True) -> typing.List[str]: def pages(self, page_headers: bool = True) -> typing.List[str]:
_pages = list() _pages = list()
_fields = list() _fields = list()
_page = '' _page = ""
_lines = 0 _lines = 0
_field_name = '' _field_name = ""
_field_value = '' _field_value = ""
_inline = False _inline = False
def open_page(initial: bool = False): def open_page(initial: bool = False):
@ -173,7 +136,7 @@ class Paginator:
elif page_headers: elif page_headers:
_page = self._header _page = self._header
else: else:
_page = '' _page = ""
_page += self._prefix _page += self._prefix
_lines = 0 _lines = 0
else: else:
@ -200,12 +163,13 @@ class Paginator:
if new_chars > self._max_chars: if new_chars > self._max_chars:
close_page() close_page()
elif (_lines + (part.count('\n') + 1 or 1)) > self._max_lines: elif (_lines + (part.count("\n") + 1 or 1)) > self._max_lines:
close_page() close_page()
_lines += (part.count('\n') + 1 or 1) _lines += part.count("\n") + 1 or 1
_page += '\n' + part _page += "\n" + part
else: else:
def open_field(name: str): def open_field(name: str):
nonlocal _field_value, _field_name nonlocal _field_value, _field_name
_field_name = name _field_name = name
@ -215,28 +179,30 @@ class Paginator:
nonlocal _field_name, _field_value, _fields nonlocal _field_name, _field_value, _fields
_field_value += self._suffix _field_value += self._suffix
if _field_value != self._prefix + self._suffix: if _field_value != self._prefix + self._suffix:
_fields.append({'name': _field_name, 'value': _field_value, 'inline': _inline}) _fields.append(
{"name": _field_name, "value": _field_value, "inline": _inline}
)
if next_name: if next_name:
open_field(next_name) open_field(next_name)
open_field('\uFFF0') open_field("\uFFF0")
for part in [str(p) for p in self._parts]: for part in [str(p) for p in self._parts]:
if part.strip().startswith(self._page_break): if part.strip().startswith(self._page_break):
close_page() close_page()
elif part == self._field_break: elif part == self._field_break:
if len(_fields) + 1 < 25: if len(_fields) + 1 < 25:
close_field(next_name='\uFFF0') close_field(next_name="\uFFF0")
else: else:
close_field() close_field()
close_page() close_page()
continue continue
if part.startswith(self._field_name_char): if part.startswith(self._field_name_char):
part = part.replace(self._field_name_char, '') part = part.replace(self._field_name_char, "")
if part.startswith(self._inline_char): if part.startswith(self._inline_char):
_inline = True _inline = True
part = part.replace(self._inline_char, '') part = part.replace(self._inline_char, "")
else: else:
_inline = False _inline = False
if _field_value and _field_value != self._prefix: if _field_value and _field_value != self._prefix:
@ -245,7 +211,7 @@ class Paginator:
_field_name = part _field_name = part
continue continue
_field_value += '\n' + part _field_value += "\n" + part
close_field() close_field()
@ -257,14 +223,15 @@ class Paginator:
def process_pages(self) -> typing.List[str]: def process_pages(self) -> typing.List[str]:
_pages = self._pages or self.pages() _pages = self._pages or self.pages()
_len_pages = len(_pages) _len_pages = len(_pages)
_len_page_str = len(f'{_len_pages}/{_len_pages}') _len_page_str = len(f"{_len_pages}/{_len_pages}")
if not self._embed: if not self._embed:
for i, page in enumerate(_pages): for i, page in enumerate(_pages):
if len(page) + _len_page_str <= 2000: if len(page) + _len_page_str <= 2000:
_pages[i] = f'{i + 1}/{_len_pages}\n{page}' _pages[i] = f"{i + 1}/{_len_pages}\n{page}"
else: else:
for i, page in enumerate(_pages): for i, page in enumerate(_pages):
em = discord.Embed(title=self._embed_title, em = discord.Embed(
title=self._embed_title,
description=self._embed_description, description=self._embed_description,
color=self._bot.embed_color, color=self._bot.embed_color,
) )
@ -274,9 +241,11 @@ class Paginator:
em.url = self._embed_url em.url = self._embed_url
if self._embed_color: if self._embed_color:
em.color = self._embed_color em.color = self._embed_color
em.set_footer(text=f'{i + 1}/{_len_pages}') em.set_footer(text=f"{i + 1}/{_len_pages}")
for field in page: for field in page:
em.add_field(name=field['name'], value=field['value'], inline=field['inline']) em.add_field(
name=field["name"], value=field["value"], inline=field["inline"]
)
_pages[i] = em _pages[i] = em
return _pages return _pages
@ -287,54 +256,71 @@ class Paginator:
# noinspection PyProtectedMember # noinspection PyProtectedMember
return self.__class__ == other.__class__ and self._parts == other._parts return self.__class__ == other.__class__ and self._parts == other._parts
def set_header(self, header: str = ''): def set_header(self, header: str = ""):
self._header = header self._header = header
def add_page_break(self, *, to_beginning: bool = False) -> None: def add_page_break(self, *, to_beginning: bool = False) -> None:
self.add(self._page_break, to_beginning=to_beginning) self.add(self._page_break, to_beginning=to_beginning)
def add(self, item: typing.Any, *, to_beginning: bool = False, keep_intact: bool = False, truncate=False) -> None: def add(
self,
item: typing.Any,
*,
to_beginning: bool = False,
keep_intact: bool = False,
truncate=False,
) -> None:
item = str(item) item = str(item)
i = 0 i = 0
if not keep_intact and not item == self._page_break: if not keep_intact and not item == self._page_break:
item_parts = item.strip('\n').split('\n') item_parts = item.strip("\n").split("\n")
for part in item_parts: for part in item_parts:
if len(part) > self._max_line_length: if len(part) > self._max_line_length:
if not truncate: if not truncate:
length = 0 length = 0
out_str = '' out_str = ""
def close_line(line): def close_line(line):
nonlocal i, out_str, length nonlocal i, out_str, length
self._parts.insert(i, out_str) if to_beginning else self._parts.append(out_str) self._parts.insert(
i, out_str
) if to_beginning else self._parts.append(out_str)
i += 1 i += 1
out_str = line + ' ' out_str = line + " "
length = len(out_str) length = len(out_str)
bits = part.split(' ') bits = part.split(" ")
for bit in bits: for bit in bits:
next_len = length + len(bit) + 1 next_len = length + len(bit) + 1
if next_len <= self._max_line_length: if next_len <= self._max_line_length:
out_str += bit + ' ' out_str += bit + " "
length = next_len length = next_len
elif len(bit) > self._max_line_length: elif len(bit) > self._max_line_length:
if out_str: if out_str:
close_line(line='') close_line(line="")
for out_str in [bit[i:i + self._max_line_length] for out_str in [
for i in range(0, len(bit), self._max_line_length)]: bit[i : i + self._max_line_length]
close_line('') for i in range(0, len(bit), self._max_line_length)
]:
close_line("")
else: else:
close_line(bit) close_line(bit)
close_line('') close_line("")
else: else:
line = f'{part:.{self._max_line_length-3}}...' line = f"{part:.{self._max_line_length-3}}..."
self._parts.insert(i, line) if to_beginning else self._parts.append(line) self._parts.insert(
i, line
) if to_beginning else self._parts.append(line)
else: else:
self._parts.insert(i, part) if to_beginning else self._parts.append(part) self._parts.insert(i, part) if to_beginning else self._parts.append(
part
)
i += 1 i += 1
elif keep_intact and not item == self._page_break: elif keep_intact and not item == self._page_break:
if len(item) >= self._max_chars or item.count('\n') > self._max_lines: if len(item) >= self._max_chars or item.count("\n") > self._max_lines:
raise RuntimeError('{item} is too long to keep on a single page and is marked to keep intact.') raise RuntimeError(
"{item} is too long to keep on a single page and is marked to keep intact."
)
if to_beginning: if to_beginning:
self._parts.insert(0, item) self._parts.insert(0, item)
else: else:
@ -347,17 +333,23 @@ class Paginator:
class Book: class Book:
def __init__(self, pag: Paginator, ctx: typing.Tuple[typing.Optional[discord.Message], def __init__(
self,
pag: Paginator,
ctx: typing.Tuple[
typing.Optional[discord.Message],
discord.TextChannel, discord.TextChannel,
discord.ext.commands.Bot, discord.ext.commands.Bot,
discord.Message]) -> None: discord.Message,
],
) -> None:
self._pages = pag.process_pages() self._pages = pag.process_pages()
self._len_pages = len(self._pages) self._len_pages = len(self._pages)
self._current_page = 0 self._current_page = 0
self._message, self._channel, self._bot, self._calling_message = ctx self._message, self._channel, self._bot, self._calling_message = ctx
self._locked = True self._locked = True
if pag == Paginator(self._bot): if pag == Paginator(self._bot):
raise RuntimeError('Cannot create a book out of an empty Paginator.') raise RuntimeError("Cannot create a book out of an empty Paginator.")
def advance_page(self) -> None: def advance_page(self) -> None:
self._current_page += 1 self._current_page += 1
@ -372,14 +364,22 @@ class Book:
async def display_page(self) -> None: async def display_page(self) -> None:
if isinstance(self._pages[self._current_page], discord.Embed): if isinstance(self._pages[self._current_page], discord.Embed):
if self._message: if self._message:
await self._message.edit(content=None, embed=self._pages[self._current_page]) await self._message.edit(
content=None, embed=self._pages[self._current_page]
)
else: else:
self._message = await self._channel.send(embed=self._pages[self._current_page]) self._message = await self._channel.send(
embed=self._pages[self._current_page]
)
else: else:
if self._message: if self._message:
await self._message.edit(content=self._pages[self._current_page], embed=None) await self._message.edit(
content=self._pages[self._current_page], embed=None
)
else: else:
self._message = await self._channel.send(self._pages[self._current_page]) self._message = await self._channel.send(
self._pages[self._current_page]
)
async def create_book(self) -> None: async def create_book(self) -> None:
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
@ -387,12 +387,16 @@ class Book:
# noinspection PyShadowingNames # noinspection PyShadowingNames
def check(reaction, user): def check(reaction, user):
if self._locked: if self._locked:
return str(reaction.emoji) in self._bot.book_emojis.values() \ return (
and user == self._calling_message.author \ str(reaction.emoji) in self._bot.book_emojis.values()
and user == self._calling_message.author
and reaction.message.id == self._message.id and reaction.message.id == self._message.id
)
else: else:
return str(reaction.emoji) in self._bot.book_emojis.values() \ return (
str(reaction.emoji) in self._bot.book_emojis.values()
and reaction.message.id == self._message.id and reaction.message.id == self._message.id
)
await self.display_page() await self.display_page()
@ -404,14 +408,16 @@ class Book:
pass pass
else: else:
try: try:
await self._message.add_reaction(self._bot.book_emojis['unlock']) await self._message.add_reaction(self._bot.book_emojis["unlock"])
await self._message.add_reaction(self._bot.book_emojis['close']) await self._message.add_reaction(self._bot.book_emojis["close"])
except (discord.Forbidden, KeyError): except (discord.Forbidden, KeyError):
pass pass
while True: while True:
try: try:
reaction, user = await self._bot.wait_for('reaction_add', timeout=60, check=check) reaction, user = await self._bot.wait_for(
"reaction_add", timeout=60, check=check
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
try: try:
await self._message.clear_reactions() await self._message.clear_reactions()
@ -420,34 +426,42 @@ class Book:
raise asyncio.CancelledError raise asyncio.CancelledError
else: else:
await self._message.remove_reaction(reaction, user) await self._message.remove_reaction(reaction, user)
if str(reaction.emoji) == self._bot.book_emojis['close']: if str(reaction.emoji) == self._bot.book_emojis["close"]:
await self._calling_message.delete() await self._calling_message.delete()
await self._message.delete() await self._message.delete()
raise asyncio.CancelledError raise asyncio.CancelledError
elif str(reaction.emoji) == self._bot.book_emojis['forward']: elif str(reaction.emoji) == self._bot.book_emojis["forward"]:
self.advance_page() self.advance_page()
elif str(reaction.emoji) == self._bot.book_emojis['back']: elif str(reaction.emoji) == self._bot.book_emojis["back"]:
self.reverse_page() self.reverse_page()
elif str(reaction.emoji) == self._bot.book_emojis['end']: elif str(reaction.emoji) == self._bot.book_emojis["end"]:
self._current_page = self._len_pages - 1 self._current_page = self._len_pages - 1
elif str(reaction.emoji) == self._bot.book_emojis['start']: elif str(reaction.emoji) == self._bot.book_emojis["start"]:
self._current_page = 0 self._current_page = 0
elif str(reaction.emoji) == self._bot.book_emojis['hash']: elif str(reaction.emoji) == self._bot.book_emojis["hash"]:
m = await self._channel.send(f'Please enter a number in range 1 to {self._len_pages}') m = await self._channel.send(
f"Please enter a number in range 1 to {self._len_pages}"
)
def num_check(message): def num_check(message):
if self._locked: if self._locked:
return message.content.isdigit() \ return (
and 0 < int(message.content) <= self._len_pages \ message.content.isdigit()
and message.author == self._calling_message.author
else:
return message.content.isdigit() \
and 0 < int(message.content) <= self._len_pages and 0 < int(message.content) <= self._len_pages
and message.author == self._calling_message.author
)
else:
return (
message.content.isdigit()
and 0 < int(message.content) <= self._len_pages
)
try: try:
msg = await self._bot.wait_for('message', timeout=30, check=num_check) msg = await self._bot.wait_for(
"message", timeout=30, check=num_check
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
await m.edit(content='Message Timed out.') await m.edit(content="Message Timed out.")
else: else:
self._current_page = int(msg.content) - 1 self._current_page = int(msg.content) - 1
try: try:
@ -455,9 +469,11 @@ class Book:
await msg.delete() await msg.delete()
except (discord.Forbidden, discord.NotFound): except (discord.Forbidden, discord.NotFound):
pass pass
elif str(reaction.emoji) == self._bot.book_emojis['unlock']: elif str(reaction.emoji) == self._bot.book_emojis["unlock"]:
self._locked = False self._locked = False
await self._message.remove_reaction(reaction, self._channel.guild.me) await self._message.remove_reaction(
reaction, self._channel.guild.me
)
continue continue
await self.display_page() await self.display_page()