diff --git a/sebimachine/__main__.py b/sebimachine/__main__.py index 4300b3b..c866a27 100644 --- a/sebimachine/__main__.py +++ b/sebimachine/__main__.py @@ -26,6 +26,8 @@ from .shared_libs.loggable import Loggable # Init logging to output on INFO level to stderr. logging.basicConfig(level="INFO") +REBOOT_FILE = "sebimachine/config/reboot" + # If uvloop is installed, change to that eventloop policy as it # is more efficient @@ -61,6 +63,7 @@ class SebiMachine(commands.Bot, LoadConfig, Loggable): with open(in_here("config", "PrivateConfig.json")) as fp: self.bot_secrets = json.load(fp) self.db_con = database.DatabaseConnection(**self.bot_secrets["db-con"]) + self.failed_cogs_on_startup = {} self.book_emojis: Dict[str, str] = { "unlock": "🔓", "start": "⏮", @@ -74,22 +77,31 @@ class SebiMachine(commands.Bot, LoadConfig, Loggable): # Load plugins # Add your cog file name in this list with open(in_here("extensions.txt")) as cog_file: - cogs = cog_file.readlines() + cogs = {f'sebimachine.cogs.{c.strip()}' for c in cog_file.readlines()} for cog in cogs: - # Could this just be replaced with `strip()`? - cog = cog.replace("\n", "") - self.load_extension(f"src.cogs.{cog}") - self.logger.info(f"Loaded: {cog}") - + try: + self.load_extension(cog) + self.logger.info(f"Loaded: {cog}") + except (ModuleNotFoundError, ImportError) as ex: + logging.exception(f'Could not load {cog}', exc_info=(type(ex), ex, ex.__traceback__)) + self.failed_cogs_on_startup[cog] = ex + async def on_ready(self): """On ready function""" self.maintenance and self.logger.warning("MAINTENANCE ACTIVE") - with open(f"src/config/reboot", "r") as f: - reboot = f.readlines() - if int(reboot[0]) == 1: - await self.get_channel(int(reboot[1])).send("Restart Finished.") - with open(f"src/config/reboot", "w") as f: + if os.path.exists(REBOOT_FILE): + with open(REBOOT_FILE, "r") as f: + reboot = f.readlines() + if int(reboot[0]) == 1: + await self.get_channel(int(reboot[1])).send("Restart Finished.") + for cog, ex in self.failed_cogs_on_startup.items(): + tb = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__))[-1500:] + await ctx.send( + f'FAILED TO LOAD {cog} BECAUSE OF {type(ex).__name__}: {ex}\n' + f'{tb}' + ) + with open(REBOOT_FILE, "w") as f: f.write(f"0") async def on_command_error(self, ctx, error): diff --git a/sebimachine/cogs/bot_management.py b/sebimachine/cogs/bot_management.py index 7f9de50..09bf211 100644 --- a/sebimachine/cogs/bot_management.py +++ b/sebimachine/cogs/bot_management.py @@ -10,12 +10,8 @@ class BotManager: if member.bot is False: return else: - # Checks if the bot is in the database - if await self.bot.db_con.fetch('select count(*) from bots where id = $1', member.id) != 1: - return await member.kick() - - bot_owner = member.guild.get_member((await self.bot.db_con.fetchval('select owner from bots where id = $1', member.id)) - await bot_owner.send("Your bot has been approved and invited") + # The member is a bot + bot_owner = member.guild.get_member(await self.bot.db_con.fetchval('select owner from bots where id = $1', member.id)) await bot_owner.add_roles(discord.utils.get(member.guild.roles, name='Bot Developers')) await member.add_roles(discord.utils.get(member.guild.roles, name='Bots')) @@ -65,7 +61,7 @@ class BotManager: await ctx.send(embed=em) em = discord.Embed(title="Bot invite", colour=discord.Color(0x363941)) - em.description = discord.utils.oauth_url(client_id, permissions=None, guild=ctx.guild)) + em.description = discord.utils.oauth_url(client_id, permissions=None, guild=ctx.guild) em.set_thumbnail(url=bot.avatar_url) em.add_field(name="Bot name", value=bot.name) em.add_field(name="Bot id", value="`" + str(bot.id) + "`") diff --git a/sebimachine/cogs/tag.py b/sebimachine/cogs/tag.py index f94ae06..08ee6fd 100644 --- a/sebimachine/cogs/tag.py +++ b/sebimachine/cogs/tag.py @@ -8,7 +8,7 @@ import asyncio class Tag: def __init__(self, bot): self.bot = bot - with open("src/shared_libs/tags.json", "r") as fp: + with open("sebimachine/shared_libs/tags.json", "r") as fp: json_data = fp.read() global tags tags = json.loads(json_data) diff --git a/sebimachine/config/config.py b/sebimachine/config/config.py index df4be54..7370fbc 100644 --- a/sebimachine/config/config.py +++ b/sebimachine/config/config.py @@ -13,7 +13,7 @@ class LoadConfig: def __init__(self): # Read our config file - with open("src/config/Config.json") as fp: + with open("sebimachine/config/Config.json") as fp: self.config = json.load(fp) # Initialize config diff --git a/sebimachine/shared_libs/utils.py b/sebimachine/shared_libs/utils.py index 6281f43..487153b 100644 --- a/sebimachine/shared_libs/utils.py +++ b/sebimachine/shared_libs/utils.py @@ -156,11 +156,11 @@ class Capturing(list): def __exit__(self, *args): self.extend(self._stringio.getvalue().splitlines()) - del self._stringio # free up some memory + del self._stringio # free up some memory sys.stdout = self._stdout -def to_list_of_str(items, out: list = list(), level=1, recurse=0): +def to_list_of_str(items, out: list=list(), level=1, recurse=0): def rec_loop(item, key, out, level): quote = '"' if type(item) == list: @@ -174,42 +174,40 @@ def to_list_of_str(items, out: list = list(), level=1, recurse=0): out = to_list_of_str(item, out, new_level, 1) out.append(f'{" "*level}}}') else: - out.append( - f'{" "*level}{quote+key+quote+": " if key else ""}{repr(item)},' - ) + out.append(f'{" "*level}{quote+key+quote+": " if key else ""}{repr(item)},') if type(items) == list: if not recurse: out = list() - out.append("[") + out.append('[') for item in items: rec_loop(item, None, out, level) if not recurse: - out.append("]") + out.append(']') elif type(items) == dict: if not recurse: out = list() - out.append("{") + out.append('{') for key in items: rec_loop(items[key], key, out, level) if not recurse: - out.append("}") + out.append('}') return out def paginate(text, maxlen=1990): - paginator = Paginator(prefix="```py", max_size=maxlen + 10) + paginator = Paginator(prefix='```py', max_size=maxlen+10) if type(text) == list: data = to_list_of_str(text) elif type(text) == dict: data = to_list_of_str(text) else: - data = str(text).split("\n") + data = str(text).split('\n') for line in data: if len(line) > maxlen: n = maxlen - for l in [line[i : i + n] for i in range(0, len(line), n)]: + for l in [line[i:i+n] for i in range(0, len(line), n)]: paginator.add_line(l) else: paginator.add_line(line) @@ -221,8 +219,7 @@ async def run_command(args): process = await asyncio.create_subprocess_shell( args, # stdout must a pipe to be accessible as process.stdout - stdout=asyncio.subprocess.PIPE, - ) + stdout=asyncio.subprocess.PIPE) # Wait for the subprocess to finish stdout, stderr = await process.communicate() # Return stdout