From 47ca65da91deddb7c19cd75a21ffb3cbcb216488 Mon Sep 17 00:00:00 2001 From: Dustin Pianalto Date: Mon, 21 May 2018 19:13:40 -0800 Subject: [PATCH] Switch to Asyncpg for db con --- geeksbot.py | 21 ++++++++++++--------- shared_libs/__init__.py | 0 shared_libs/database.py | 22 ++++++++++++++++++++++ 3 files changed, 34 insertions(+), 9 deletions(-) create mode 100644 shared_libs/__init__.py create mode 100644 shared_libs/database.py diff --git a/geeksbot.py b/geeksbot.py index c0af47b..57b2d4c 100644 --- a/geeksbot.py +++ b/geeksbot.py @@ -6,9 +6,8 @@ from datetime import datetime import json import aiohttp from googleapiclient.discovery import build -import asyncpg from concurrent import futures -import asyncio +from shared_libs import database log_format = '{asctime}.{msecs:03.0f}|{levelname:<8}|{name}::{message}' @@ -58,14 +57,14 @@ class Geeksbot(commands.Bot): self.infected = {} self.TOKEN = self.bot_secrets['token'] - async def connect_db(): - return await asyncpg.create_pool(host=self.bot_secrets['db_con']['host'], - database=self.bot_secrets['db_con']['db_name'], - user=self.bot_secrets['db_con']['user'], - password=self.bot_secrets['db_con']['password'], - loop=asyncio.get_event_loop()) + # async def connect_db(): + # return await asyncpg.create_pool(host=self.bot_secrets['db_con']['host'], + # database=self.bot_secrets['db_con']['db_name'], + # user=self.bot_secrets['db_con']['user'], + # password=self.bot_secrets['db_con']['password'], + # loop=asyncio.get_event_loop()) del self.bot_secrets['token'] - self.db_con = asyncio.get_event_loop().create_task(connect_db()) + self.db_con = database.DatabaseConnection(**self.bot_secrets['db_con']) self.default_prefix = 'g~' self.voice_chans = {} self.spam_list = {} @@ -80,6 +79,10 @@ class Geeksbot(commands.Bot): 'left_fist': '🤛', } + async def logout(self): + await self.db_con.close() + super().logout() + @staticmethod async def get_custom_prefix(bot_inst, message): return await bot_inst.db_con.fetchval('select prefix from guild_config where guild_id = $1', diff --git a/shared_libs/__init__.py b/shared_libs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shared_libs/database.py b/shared_libs/database.py new file mode 100644 index 0000000..52c6395 --- /dev/null +++ b/shared_libs/database.py @@ -0,0 +1,22 @@ +import asyncpg +import asyncio + + +class DatabaseConnection: + def __init__(self, host: str='localhost', port: int=5432, database: str='', username: str='', password: str=''): + if username == '' or password == '' or database == '': + raise RuntimeError('Username or Password are blank') + self.kwargs = {'host': host, 'port': port, 'database': database, 'username': username, 'password': password} + self._conn = asyncio.get_event_loop().run_until_complete(self.acquire()) + self.fetchval = self._conn.fetchval + self.execute = self._conn.execute + self.fetchall = self._conn.fetchall + self.fetchone = self._conn.fetchone + + async def acquire(self): + if not self._conn: + self._conn = await asyncpg.create_pool(**self.kwargs) + + async def close(self): + await self._conn.close() + self._conn = None