From 8d2e1ad72a9357b61bfd4e6357b5782b8c54ec55 Mon Sep 17 00:00:00 2001 From: retke Date: Mon, 26 Sep 2022 00:20:04 +0200 Subject: [PATCH] bot: add blacklist logic and global check --- ballsdex/core/bot.py | 31 ++++++++++++++++++++++++++++++- ballsdex/core/commands.py | 6 ++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/ballsdex/core/bot.py b/ballsdex/core/bot.py index 629771c16..8e4621404 100755 --- a/ballsdex/core/bot.py +++ b/ballsdex/core/bot.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import discord import logging @@ -8,6 +10,7 @@ from discord.ext import commands from ballsdex.core.dev import Dev +from ballsdex.core.models import BlacklistedID from ballsdex.core.commands import Core log = logging.getLogger("ballsdex.core.bot") @@ -15,6 +18,12 @@ PACKAGES = ["config", "players", "countryballs", "info"] +class CommandTree(app_commands.CommandTree): + async def interaction_check(self, interaction: discord.Interaction, /) -> bool: + bot = cast(BallsDexBot, interaction.client) + return await bot.blacklist_check(interaction) + + class BallsDexBot(commands.Bot): """ BallsDex Discord bot @@ -22,10 +31,11 @@ class BallsDexBot(commands.Bot): def __init__(self, command_prefix: str, dev: bool = False, **options): intents = discord.Intents(guilds=True, guild_messages=True, message_content=True) - super().__init__(command_prefix, intents=intents, **options) + super().__init__(command_prefix, intents=intents, tree_cls=CommandTree, **options) self._shutdown = 0 self.dev = dev self.tree.error(self.on_application_command_error) + self.blacklist: list[int] = [] async def on_shard_ready(self, shard_id: int): log.debug(f"Connected to shard #{shard_id}") @@ -54,9 +64,19 @@ def assign_ids_to_app_commands(self, synced_commands: list[app_commands.AppComma bot_command, cast(list[app_commands.AppCommandGroup], synced_command.options) ) + async def load_blacklist(self): + self.blacklist = ( + await BlacklistedID.all().only("discord_id").values_list("discord_id", flat=True) + ) # type: ignore + async def on_ready(self): assert self.user log.info(f"Successfully logged in as {self.user} ({self.user.id})!") + + await self.load_blacklist() + if self.blacklist: + log.info(f"{len(self.blacklist)} blacklisted users.") + log.info("Loading packages...") await self.add_cog(Core(self)) if self.dev: @@ -84,6 +104,14 @@ async def on_ready(self): log.info("No command to sync.") print("\n [bold][red]BallsDex bot[/red] [green]is now operational![/green][/bold]\n") + async def blacklist_check(self, interaction: discord.Interaction) -> bool: + if interaction.user.id in self.blacklist: + await interaction.response.send_message( + "You are blacklisted from the bot.", ephemeral=True + ) + return False + return True + async def on_command_error( self, context: commands.Context, exception: commands.errors.CommandError ): @@ -167,3 +195,4 @@ async def on_error(self, event_method: str, /, *args, **kwargs): f"Error in event {event_method}. Args: {formatted_args}. Kwargs: {formatted_kwargs}", exc_info=True, ) + self.tree.interaction_check diff --git a/ballsdex/core/commands.py b/ballsdex/core/commands.py index 56220952e..e97c7c439 100755 --- a/ballsdex/core/commands.py +++ b/ballsdex/core/commands.py @@ -52,3 +52,9 @@ async def reload(self, ctx: commands.Context, package: str): log.error(f"Failed to reload extension {package}", exc_info=True) else: await ctx.send("Extension reloaded.") + + @commands.command() + @commands.is_owner() + async def reloadblacklist(self, ctx: commands.Context): + await self.bot.load_blacklist() + await ctx.message.add_reaction("✅")