Add GeeksbotAPI class to manage queries to the API

This commit is contained in:
Dustin Pianalto 2019-12-30 22:40:31 -09:00
parent 41cc1a5367
commit 0abe44cb05
2 changed files with 80 additions and 1 deletions

View File

@ -12,6 +12,7 @@ import discord
from discord.ext import commands from discord.ext import commands
from discord.ext.commands.context import Context from discord.ext.commands.context import Context
from geeksbot.imports.strings import MyStringView from geeksbot.imports.strings import MyStringView
from geeksbot.imports.geeksbot_api import GeeksbotAPI
geeksbot_logger = logging.getLogger("Geeksbot") geeksbot_logger = logging.getLogger("Geeksbot")
@ -45,8 +46,8 @@ class Geeksbot(commands.Bot):
self.token = self.settings_cache.get("DISCORD_TOKEN") self.token = self.settings_cache.get("DISCORD_TOKEN")
self.api_token = self.settings_cache.get("API_TOKEN") self.api_token = self.settings_cache.get("API_TOKEN")
self.aio_session = aiohttp.ClientSession(loop=self.loop) self.aio_session = aiohttp.ClientSession(loop=self.loop)
self.auth_header = {"Authorization": f"Token {self.api_token}"}
self.api_base = "https://geeksbot.app/api" self.api_base = "https://geeksbot.app/api"
self.api = GeeksbotAPI(self.api_token, self.api_base, self.loop)
with open(f"{self.config_dir}/{self.config_file}") as f: with open(f"{self.config_dir}/{self.config_file}") as f:
self.bot_config = json.load(f) self.bot_config = json.load(f)
self.embed_color = discord.Colour.from_rgb(49, 107, 111) self.embed_color = discord.Colour.from_rgb(49, 107, 111)

View File

@ -0,0 +1,78 @@
import aiohttp
import asyncio
import logging
logger = logging.getLogger("Geeksbot API")
class APIError500(Exception):
pass
class APIError(Exception):
pass
class GeeksbotAPI:
def __init__(
self,
token: str,
base_url: str = "https://geeksbot.app/api",
loop: asyncio.AbstractEventLoop = None,
):
self.loop = loop or asyncio.get_event_loop()
self.session = aiohttp.ClientSession(loop=self.loop)
self.base_url = base_url
self.token = token
self.auth_header = {"Authorization": f"Token {self.token}"}
def clean_endpoint(endpoint: str) -> str:
endpoint = endpoint[1:] if endpoint.startswith("/") else endpoint
endpoint += "/" if not endpoint.endswith("/") else ""
return endpoint
async def query(
self, method: str, endpoint: str, data: dict = None, query: dict = None
):
endpoint = self.clean_endpoint(endpoint)
if method.lower() == "get":
resp = await self.session.get(
f"{self.base_url}/{endpoint}{query_str}",
headers=self.auth_header,
params=query,
json=data,
)
elif method.lower() == "post":
resp = await self.session.post(
f"{self.base_url}/{endpoint}{query_str}",
headers=self.auth_header,
params=query,
json=data,
)
elif method.lower() == "put":
resp = await self.session.put(
f"{self.base_url}/{endpoint}{query_str}",
headers=self.auth_header,
params=query,
json=data,
)
elif method.lower() == "delete":
resp = await self.session.delete(
f"{self.base_url}/{endpoint}{query_str}",
headers=self.auth_header,
params=query,
json=data,
)
else:
raise APIError(f"That is not a valid method. {method}")
if resp.status == 200:
return await resp.json()
elif resp.status >= 500:
raise APIError500(
"The server returned a 500 error. "
"The developers have been notified of the issue."
)
else:
details = (await resp.json())["details"]
raise APIError(details)