From e144014afb81039ae74bedc02d8deaf53860a555 Mon Sep 17 00:00:00 2001 From: 0xry4n Date: Fri, 20 Mar 2026 15:27:10 +0000 Subject: [PATCH 1/2] =?UTF-8?q?=F0=9F=90=9B=20Fix=20interaction=20failures?= =?UTF-8?q?=20stemming=20from=20bot=20reset=20/=20state=20loss?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/bot.py | 89 ++++++++++--- src/helpers/ban.py | 72 ++++++---- src/views/bandecisionview.py | 249 +++++++++++++++++++++++------------ 3 files changed, 287 insertions(+), 123 deletions(-) diff --git a/src/bot.py b/src/bot.py index f598844..c54d329 100644 --- a/src/bot.py +++ b/src/bot.py @@ -4,12 +4,27 @@ import discord from aiohttp import AsyncResolver, ClientSession, TCPConnector from discord import ( - ApplicationContext, Cog, DiscordException, Embed, Forbidden, Guild, HTTPException, Member, NotFound, User, + ApplicationContext, + Cog, + DiscordException, + Embed, + Forbidden, + Guild, + HTTPException, + Member, + NotFound, + User, ) from discord.ext.commands import Bot as DiscordBot from discord.ext.commands import ( - CommandNotFound, CommandOnCooldown, DefaultHelpCommand, MissingAnyRole, MissingPermissions, - MissingRequiredArgument, NoPrivateMessage, UserInputError, + CommandNotFound, + CommandOnCooldown, + DefaultHelpCommand, + MissingAnyRole, + MissingPermissions, + MissingRequiredArgument, + NoPrivateMessage, + UserInputError, ) from sqlalchemy.exc import NoResultFound from typing import TypeVar @@ -50,13 +65,15 @@ async def on_ready(self) -> None: if self.http_session is None: logger.debug("Starting the HTTP session") self.http_session = ClientSession( - connector=TCPConnector(resolver=AsyncResolver(), family=socket.AF_INET), - trace_configs=[trace_config] + connector=TCPConnector(resolver=AsyncResolver(), family=socket.AF_INET), + trace_configs=[trace_config], ) name = f"{self.user} (ID: {self.user.id})" devlog_msg = f"Connected {constants.emojis.partying_face}" - self.loop.create_task(self.send_log(devlog_msg, colour=constants.colours.bright_green)) + self.loop.create_task( + self.send_log(devlog_msg, colour=constants.colours.bright_green) + ) logger.info(f"Started bot as {name}") print("Loading ScheduledTasks cog...") @@ -66,6 +83,17 @@ async def on_ready(self) -> None: except Exception as e: print(f"Failed to load ScheduledTasks cog: {e}") + await self._register_persistent_views() + + async def _register_persistent_views(self) -> None: + """Re-register persistent UI views so buttons survive bot restarts.""" + from src.views.bandecisionview import register_ban_views + + try: + await register_ban_views(self) + except Exception: + logger.exception("Failed to register persistent ban decision views") + async def on_application_command(self, ctx: ApplicationContext) -> None: """A global handler cog.""" logger.debug(f"Command '{ctx.command}' received.") @@ -77,14 +105,18 @@ async def on_application_command(self, ctx: ApplicationContext) -> None: embed.add_field(name="Channel", value=ctx.channel.name, inline=True) await ctx.guild.get_channel(settings.channels.BOT_LOGS).send(embed=embed) - async def on_application_command_error(self, ctx: ApplicationContext, error: DiscordException) -> None: + async def on_application_command_error( + self, ctx: ApplicationContext, error: DiscordException + ) -> None: """A global error handler cog.""" message = None if isinstance(error, CommandNotFound): return if isinstance(error, MissingRequiredArgument): - message = f"Parameter '{error.param.name}' is required, but missing. Type `{ctx.clean_prefix}help " \ - f"{ctx.invoked_with}` for help." + message = ( + f"Parameter '{error.param.name}' is required, but missing. Type `{ctx.clean_prefix}help " + f"{ctx.invoked_with}` for help." + ) elif isinstance(error, MissingPermissions): message = "You are missing the required permissions to run this command." elif isinstance(error, MissingAnyRole): @@ -120,7 +152,9 @@ def add_cog(self, cog: Cog, *, override: bool = False) -> None: super().add_cog(cog, override=override) logger.debug(f"Cog loaded: {cog.qualified_name}") - async def send_log(self, description: str = None, colour: int = None, embed: Embed = None) -> None: + async def send_log( + self, description: str = None, colour: int = None, embed: Embed = None + ) -> None: """Send an embed message to the devlog channel.""" devlog = self.get_channel(settings.channels.DEVLOG) @@ -159,27 +193,46 @@ async def get_member_or_user(self, guild: Guild, id_: int) -> Member | User | No try: return await guild.fetch_member(id_) except Forbidden as exc: - logger.warning(f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc) + logger.warning( + f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc + ) except NotFound as exc: logger.warning(f"Could not find guild member with id: {id_}", exc_info=exc) try: return await self.get_or_fetch_user(id_) except Forbidden as exc: - logger.warning(f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc) + logger.warning( + f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc + ) except NotFound as exc: - logger.warning(f"Could not find guild member with id: {id_}", exc_info=exc) + logger.warning( + f"Could not find guild member with id: {id_}", exc_info=exc + ) except HTTPException as exc: - logger.error(f"Discord error while fetching guild member with id: {id_}", exc_info=exc) + logger.error( + f"Discord error while fetching guild member with id: {id_}", + exc_info=exc, + ) except HTTPException as exc: - logger.error(f"Discord error while fetching guild member with id: {id_}", exc_info=exc) + logger.error( + f"Discord error while fetching guild member with id: {id_}", + exc_info=exc, + ) try: return await self.get_or_fetch_user(id_) except Forbidden as exc: - logger.warning(f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc) + logger.warning( + f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc + ) except NotFound as exc: - logger.warning(f"Could not find guild member with id: {id_}", exc_info=exc) + logger.warning( + f"Could not find guild member with id: {id_}", exc_info=exc + ) except HTTPException as exc: - logger.error(f"Discord error while fetching guild member with id: {id_}", exc_info=exc) + logger.error( + f"Discord error while fetching guild member with id: {id_}", + exc_info=exc, + ) return None diff --git a/src/helpers/ban.py b/src/helpers/ban.py index 92fe22e..763a0dc 100644 --- a/src/helpers/ban.py +++ b/src/helpers/ban.py @@ -39,7 +39,10 @@ class BanCodes(Enum): async def _check_member( - bot: Bot, guild: Guild, member: Member | User, author: Member | ClientUser | None = None + bot: Bot, + guild: Guild, + member: Member | User, + author: Member | ClientUser | None = None, ) -> SimpleResponse | None: if isinstance(member, Member): if member_is_staff(member): @@ -70,7 +73,9 @@ async def get_ban(member: Member | User) -> Ban | None: async def update_ban(ban: Ban) -> None: - logger.info(f"Updating ban {ban.id} for user {ban.user_id} with expiration {ban.unban_time}") + logger.info( + f"Updating ban {ban.id} for user {ban.user_id} with expiration {ban.unban_time}" + ) async with AsyncSessionLocal() as session: session.add(ban) await session.commit() @@ -156,7 +161,7 @@ async def handle_platform_ban_or_update( extra_log_data: dict | None = None, ) -> dict: """Handle platform ban by either creating new ban, updating existing ban, or taking no action. - + Args: bot: The Discord bot instance guild: The guild to ban the member from @@ -169,30 +174,44 @@ async def handle_platform_ban_or_update( log_channel_id: Channel ID for logging ban actions logger: Logger instance for recording events extra_log_data: Additional data to include in log entries - + Returns: dict with 'action' key indicating what was done: 'unbanned', 'extended', 'no_action', 'updated', 'created' """ if extra_log_data is None: extra_log_data = {} - + expires_dt = datetime.fromtimestamp(expires_timestamp) - + existing_ban = await get_ban(member) if not existing_ban: # No existing ban, create new one await ban_member_with_epoch( - bot, guild, member, expires_timestamp, reason, evidence, needs_approval=False + bot, + guild, + member, + expires_timestamp, + reason, + evidence, + needs_approval=False, ) await _send_ban_notice( - guild, member, reason, author_name, expires_at_str, guild.get_channel(log_channel_id) # type: ignore + guild, + member, + reason, + author_name, + expires_at_str, + guild.get_channel(log_channel_id), # type: ignore + ) + logger.info( + f"Created new platform ban for user {member.id} until {expires_at_str}", + extra=extra_log_data, ) - logger.info(f"Created new platform ban for user {member.id} until {expires_at_str}", extra=extra_log_data) return {"action": "created"} - + # Existing ban found - determine what to do based on ban type and timing is_platform_ban = existing_ban.reason.startswith("Platform Ban") - + if is_platform_ban: # Platform bans have authority over other platform bans if expires_dt < datetime.now(): @@ -202,7 +221,7 @@ async def handle_platform_ban_or_update( await guild.get_channel(log_channel_id).send(msg) # type: ignore logger.info(msg, extra=extra_log_data) return {"action": "unbanned"} - + if existing_ban.unban_time < expires_timestamp: # Extend the existing platform ban existing_ban.unban_time = expires_timestamp @@ -226,11 +245,17 @@ async def handle_platform_ban_or_update( existing_ban.unban_time = expires_timestamp existing_ban.reason = f"Platform Ban: {reason}" # Update reason to indicate platform authority await update_ban(existing_ban) - logger.info(f"Updated existing ban for user {member.id} until {expires_at_str}.", extra=extra_log_data) + logger.info( + f"Updated existing ban for user {member.id} until {expires_at_str}.", + extra=extra_log_data, + ) return {"action": "updated"} - + # Default case (shouldn't reach here, but for safety) - logger.warning(f"Unexpected case in platform ban handling for user {member.id}", extra=extra_log_data) + logger.warning( + f"Unexpected case in platform ban handling for user {member.id}", + extra=extra_log_data, + ) return {"action": "no_action"} @@ -245,7 +270,7 @@ async def ban_member_with_epoch( needs_approval: bool = True, ) -> SimpleResponse: """Ban a member from the guild until a specific epoch time. - + Args: bot: The Discord bot instance guild: The guild to ban the member from @@ -255,7 +280,7 @@ async def ban_member_with_epoch( evidence: Evidence supporting the ban author: The member issuing the ban (defaults to bot user) needs_approval: Whether the ban requires approval - + Returns: SimpleResponse with the result of the ban operation, or None if no response needed """ @@ -273,8 +298,7 @@ async def ban_member_with_epoch( current_time = datetime.now(tz=timezone.utc).timestamp() if unban_epoch_time <= current_time: return SimpleResponse( - message="Unban time must be in the future", - delete_after=15 + message="Unban time must be in the future", delete_after=15 ) end_date: str = datetime.fromtimestamp(unban_epoch_time, tz=timezone.utc).strftime( @@ -283,7 +307,7 @@ async def ban_member_with_epoch( if author is None: author = bot.user - + # Author should never be None at this point if author is None: raise ValueError("Author cannot be None") @@ -374,7 +398,7 @@ async def ban_member_with_epoch( f"Evidence: {evidence}", ) embed.set_thumbnail(url=f"{settings.HTB_URL}/images/logo600.png") - view = BanDecisionView(ban_id, bot, guild, member, end_date, reason) + view = BanDecisionView(ban_id, bot) await guild.get_channel(settings.channels.SR_MOD).send(embed=embed, view=view) # type: ignore return await _create_ban_response( @@ -393,7 +417,7 @@ async def ban_member( needs_approval: bool = True, ) -> SimpleResponse: """Ban a member from the guild using a duration. - + Args: bot: The Discord bot instance guild: The guild to ban the member from @@ -403,7 +427,7 @@ async def ban_member( evidence: Evidence supporting the ban author: The member issuing the ban (defaults to bot user) needs_approval: Whether the ban requires approval - + Returns: SimpleResponse with the result of the ban operation, or None if no response needed """ @@ -516,7 +540,7 @@ async def mute_member( if author is None: author = bot.user - + # Author should never be None at this point if author is None: raise ValueError("Author cannot be None") diff --git a/src/views/bandecisionview.py b/src/views/bandecisionview.py index d42f51f..d1c0c02 100644 --- a/src/views/bandecisionview.py +++ b/src/views/bandecisionview.py @@ -1,7 +1,8 @@ +import logging from datetime import datetime import discord -from discord import Guild, Interaction, Member, User +from discord import Interaction from discord.ui import Button, InputText, Modal, View from sqlalchemy import select @@ -12,138 +13,224 @@ from src.helpers.duration import validate_duration from src.helpers.schedule import schedule +logger = logging.getLogger(__name__) + class BanDecisionView(View): - """View for making decisions on a ban duration.""" + """Persistent view for making decisions on a ban duration. + + Encodes the ban_id into each button's custom_id so the view can be + reconstructed and re-registered after a bot restart. + """ - def __init__(self, ban_id: int, bot: Bot, guild: Guild, member: Member | User, end_date: str, reason: str): + def __init__(self, ban_id: int, bot: Bot): super().__init__(timeout=None) self.ban_id = ban_id self.bot = bot - self.guild = guild - self.member = member - self.end_date = end_date - self.reason = reason - - async def update_message(self, interaction: Interaction, decision: str) -> None: - """Update the message to reflect the decision.""" - admin_name = interaction.user.display_name - decision_message = f"{admin_name} has made a decision: **{decision}** for {self.member.display_name}." - await interaction.message.edit(content=decision_message, view=self) - - async def disable_all_buttons(self) -> None: - """Disable all buttons in the view.""" + + approve_btn = Button( + label="Approve duration", + style=discord.ButtonStyle.success, + custom_id=f"ban_approve:{ban_id}", + ) + approve_btn.callback = self._approve + self.add_item(approve_btn) + + deny_btn = Button( + label="Deny and unban", + style=discord.ButtonStyle.danger, + custom_id=f"ban_deny:{ban_id}", + ) + deny_btn.callback = self._deny + self.add_item(deny_btn) + + dispute_btn = Button( + label="Dispute", + style=discord.ButtonStyle.primary, + custom_id=f"ban_dispute:{ban_id}", + ) + dispute_btn.callback = self._dispute + self.add_item(dispute_btn) + + async def _resolve_member_name(self, guild: discord.Guild, user_id: int) -> str: + """Resolve a display name for the banned user, falling back to the raw ID.""" + member = await self.bot.get_member_or_user(guild, user_id) + return member.display_name if member else str(user_id) + + def _disable_all(self) -> None: + """Disable every button in the view.""" for item in self.children: if isinstance(item, Button): item.disabled = True - async def update_buttons(self, clicked_button_id: str) -> None: - """Disable the clicked button and enable all others.""" + def _disable_one(self, custom_id: str) -> None: + """Disable only the button matching *custom_id*.""" for item in self.children: if isinstance(item, Button): - item.disabled = item.custom_id == clicked_button_id + item.disabled = item.custom_id == custom_id + + # ------------------------------------------------------------------ + # Button callbacks + # ------------------------------------------------------------------ - @discord.ui.button(label="Approve duration", style=discord.ButtonStyle.success, custom_id="approve_button") - async def approve_button(self, button: Button, interaction: Interaction) -> None: + async def _approve(self, interaction: Interaction) -> None: """Approve the ban duration.""" - await interaction.response.send_message( - f"Ban duration for {self.member.display_name} has been approved.", ephemeral=True - ) + await interaction.response.defer(ephemeral=True) + async with AsyncSessionLocal() as session: - stmt = select(Ban).filter(Ban.id == self.ban_id) - result = await session.scalars(stmt) - ban = result.first() - if ban: - ban.approved = True - await session.commit() - await self.guild.get_channel(settings.channels.SR_MOD).send( - f"Ban duration for {self.member.display_name} has been approved by {interaction.user.display_name}." + ban = await session.get(Ban, self.ban_id) + if not ban: + await interaction.followup.send("Ban record not found.", ephemeral=True) + return + ban.approved = True + user_id = ban.user_id + await session.commit() + + member_name = await self._resolve_member_name(interaction.guild, user_id) + + await interaction.followup.send( + f"Ban duration for {member_name} has been approved.", ephemeral=True ) - # Disable the clicked button and enable others - await self.update_buttons("approve_button") - await interaction.message.edit(view=self) - await self.update_message(interaction, "Approved Duration") - - @discord.ui.button(label="Deny and unban", style=discord.ButtonStyle.danger, custom_id="deny_button") - async def deny_button(self, button: Button, interaction: Interaction) -> None: - """Deny the ban duration and unban the member.""" + + channel = interaction.guild.get_channel(settings.channels.SR_MOD) + if channel: + await channel.send( + f"Ban duration for {member_name} has been approved by {interaction.user.display_name}." + ) + + self._disable_one(f"ban_approve:{self.ban_id}") + await interaction.message.edit( + content=f"{interaction.user.display_name} has made a decision: **Approved Duration** for {member_name}.", + view=self, + ) + + async def _deny(self, interaction: Interaction) -> None: + """Deny the ban and unban the member.""" from src.helpers.ban import unban_member - await interaction.response.send_message( - f"Ban for {self.member.display_name} has been denied and the member will be unbanned.", ephemeral=True + + await interaction.response.defer(ephemeral=True) + + async with AsyncSessionLocal() as session: + ban = await session.get(Ban, self.ban_id) + if not ban: + await interaction.followup.send("Ban record not found.", ephemeral=True) + return + user_id = ban.user_id + + member = await self.bot.get_member_or_user(interaction.guild, user_id) + member_name = member.display_name if member else str(user_id) + + if member: + await unban_member(interaction.guild, member) + + await interaction.followup.send( + f"Ban for {member_name} has been denied and the member will be unbanned.", + ephemeral=True, ) - await unban_member(self.guild, self.member) - await self.guild.get_channel(settings.channels.SR_MOD).send( - f"Ban for {self.member.display_name} has been denied by {interaction.user.display_name} and the member has been unbanned." + + channel = interaction.guild.get_channel(settings.channels.SR_MOD) + if channel: + await channel.send( + f"Ban for {member_name} has been denied by {interaction.user.display_name} " + f"and the member has been unbanned." + ) + + self._disable_all() + await interaction.message.edit( + content=f"{interaction.user.display_name} has made a decision: **Denied and Unbanned** for {member_name}.", + view=self, ) - # Disable all buttons after denial - await self.disable_all_buttons() - await interaction.message.edit(view=self) - await self.update_message(interaction, "Denied and Unbanned") - - @discord.ui.button(label="Dispute", style=discord.ButtonStyle.primary, custom_id="dispute_button") - async def dispute_button(self, button: Button, interaction: Interaction) -> None: - """Dispute the ban duration.""" - modal = DisputeModal(self.ban_id, self.bot, self.guild, self.member, self.end_date, self.reason, self) + + async def _dispute(self, interaction: Interaction) -> None: + """Open the dispute modal.""" + modal = DisputeModal(self.ban_id, self.bot, self) await interaction.response.send_modal(modal) class DisputeModal(Modal): """Modal for disputing a ban duration.""" - def __init__(self, ban_id: int, bot: Bot, guild: Guild, member: Member | User, end_date: str, reason: str, parent_view: BanDecisionView): + def __init__(self, ban_id: int, bot: Bot, parent_view: BanDecisionView): super().__init__(title="Dispute Ban Duration") self.ban_id = ban_id self.bot = bot - self.guild = guild - self.member = member - self.end_date = end_date - self.reason = reason - self.parent_view = parent_view # Store the parent view + self.parent_view = parent_view - # Add InputText for duration self.add_item( - InputText(label="New Duration", placeholder="Enter new duration (e.g., 10s, 5m, 2h, 1d)", required=True) + InputText( + label="New Duration", + placeholder="Enter new duration (e.g., 10s, 5m, 2h, 1d)", + required=True, + ) ) async def callback(self, interaction: Interaction) -> None: - """Handle the dispute duration callback.""" + """Handle the dispute duration submission.""" from src.helpers.ban import unban_member - new_duration_str = self.children[0].value - # Validate duration using `validate_duration` + new_duration_str = self.children[0].value dur, dur_exc = validate_duration(new_duration_str) if dur_exc: - # Send an ephemeral message if the duration is invalid await interaction.response.send_message(dur_exc, ephemeral=True) return - # Proceed with updating the ban record if the duration is valid async with AsyncSessionLocal() as session: ban = await session.get(Ban, self.ban_id) - if not ban or not ban.timestamp: - await interaction.response.send_message(f"Cannot dispute ban {self.ban_id}: record not found.", ephemeral=True) + await interaction.response.send_message( + f"Cannot dispute ban {self.ban_id}: record not found.", + ephemeral=True, + ) return - # Update the ban's unban time and approve the dispute ban.unban_time = dur ban.approved = True + user_id = ban.user_id await session.commit() - # Schedule the unban based on the new duration new_unban_at = datetime.fromtimestamp(dur) - member = await self.bot.get_member_or_user(self.guild, ban.user_id) + member = await self.bot.get_member_or_user(interaction.guild, user_id) + member_name = member.display_name if member else str(user_id) + if member: - self.bot.loop.create_task(schedule(unban_member(self.guild, member), run_at=new_unban_at)) + self.bot.loop.create_task( + schedule(unban_member(interaction.guild, member), run_at=new_unban_at) + ) - # Notify the user and moderators of the updated ban duration await interaction.response.send_message( - f"Ban duration updated to {new_duration_str}. The member will be unbanned on {new_unban_at.strftime('%B %d, %Y')} UTC.", - ephemeral=True - ) - await self.guild.get_channel(settings.channels.SR_MOD).send( - f"Ban duration for {self.member.display_name} updated to {new_duration_str}. Unban scheduled for {new_unban_at.strftime('%B %d, %Y')} UTC." + f"Ban duration updated to {new_duration_str}. " + f"The member will be unbanned on {new_unban_at.strftime('%B %d, %Y')} UTC.", + ephemeral=True, ) - # Disable buttons and update message on the parent view after dispute - await self.parent_view.update_message(interaction, "Disputed Duration") + channel = interaction.guild.get_channel(settings.channels.SR_MOD) + if channel: + await channel.send( + f"Ban duration for {member_name} updated to {new_duration_str}. " + f"Unban scheduled for {new_unban_at.strftime('%B %d, %Y')} UTC." + ) + + self.parent_view._disable_all() + if interaction.message: + await interaction.message.edit( + content=f"{interaction.user.display_name} has made a decision: **Disputed Duration** for {member_name}.", + view=self.parent_view, + ) + + +async def register_ban_views(bot: Bot) -> None: + """Re-register persistent BanDecisionView instances for all unapproved active bans. + + Call this once during bot startup (e.g. in on_ready) so that buttons on + existing ban-decision messages continue to work after a restart. + """ + async with AsyncSessionLocal() as session: + stmt = select(Ban).filter(Ban.approved.is_(False), Ban.unbanned.is_(False)) + result = await session.scalars(stmt) + bans = result.all() + + for ban in bans: + bot.add_view(BanDecisionView(ban.id, bot)) + + if bans: + logger.info("Registered %d persistent ban decision view(s).", len(bans)) From 1e896758555653d0c989c83e90be29615b905019 Mon Sep 17 00:00:00 2001 From: 0xry4n Date: Fri, 20 Mar 2026 17:47:59 +0000 Subject: [PATCH 2/2] =?UTF-8?q?=E2=9C=85=20Add=20coverage=20for=20ban=20de?= =?UTF-8?q?cision=20interaction=20changes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/src/views/test_bandecisionview.py | 557 ++++++++++++++++++++++++ 1 file changed, 557 insertions(+) create mode 100644 tests/src/views/test_bandecisionview.py diff --git a/tests/src/views/test_bandecisionview.py b/tests/src/views/test_bandecisionview.py new file mode 100644 index 0000000..4643623 --- /dev/null +++ b/tests/src/views/test_bandecisionview.py @@ -0,0 +1,557 @@ +from datetime import datetime +from unittest import mock +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from discord.ui import Button + +from src.views.bandecisionview import BanDecisionView, DisputeModal, register_ban_views +from tests import helpers + + +def _make_interaction( + guild: helpers.MockGuild | None = None, + user: helpers.MockMember | None = None, +) -> MagicMock: + """Build a lightweight mock Interaction with the attributes our callbacks use.""" + interaction = MagicMock() + interaction.guild = guild or helpers.MockGuild() + interaction.user = user or helpers.MockMember(name="Admin") + interaction.user.display_name = interaction.user.name + + interaction.response = MagicMock() + interaction.response.defer = AsyncMock() + interaction.response.send_message = AsyncMock() + interaction.response.send_modal = AsyncMock() + + interaction.followup = MagicMock() + interaction.followup.send = AsyncMock() + + interaction.message = MagicMock() + interaction.message.edit = AsyncMock() + + return interaction + + +def _make_ban( + ban_id: int = 1, + user_id: int = 42, + approved: bool = False, + unbanned: bool = False, + timestamp: object = True, + unban_time: int = 9999999999, +) -> MagicMock: + """Build a mock Ban model instance.""" + ban = MagicMock() + ban.id = ban_id + ban.user_id = user_id + ban.approved = approved + ban.unbanned = unbanned + ban.timestamp = timestamp + ban.unban_time = unban_time + return ban + + +def _session_ctx(session_mock: AsyncMock) -> MagicMock: + """Wrap an AsyncMock session so it works as ``async with AsyncSessionLocal() as s``.""" + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=session_mock) + ctx.__aexit__ = AsyncMock(return_value=False) + return ctx + + +class TestBanDecisionViewInit: + @pytest.mark.asyncio + async def test_buttons_have_unique_custom_ids(self, bot): + view = BanDecisionView(ban_id=7, bot=bot) + buttons = [c for c in view.children if isinstance(c, Button)] + + assert len(buttons) == 3 + ids = {b.custom_id for b in buttons} + assert ids == {"ban_approve:7", "ban_deny:7", "ban_dispute:7"} + + @pytest.mark.asyncio + async def test_timeout_is_none(self, bot): + view = BanDecisionView(ban_id=1, bot=bot) + assert view.timeout is None + + @pytest.mark.asyncio + async def test_buttons_are_not_disabled_by_default(self, bot): + view = BanDecisionView(ban_id=1, bot=bot) + for child in view.children: + if isinstance(child, Button): + assert not child.disabled + + +class TestDisableHelpers: + @pytest.mark.asyncio + async def test_disable_all(self, bot): + view = BanDecisionView(ban_id=1, bot=bot) + view._disable_all() + for child in view.children: + if isinstance(child, Button): + assert child.disabled + + @pytest.mark.asyncio + async def test_disable_one_only_targets_matching_button(self, bot): + view = BanDecisionView(ban_id=5, bot=bot) + view._disable_one("ban_approve:5") + + for child in view.children: + if isinstance(child, Button): + if child.custom_id == "ban_approve:5": + assert child.disabled + else: + assert not child.disabled + + +class TestResolveMemberName: + @pytest.mark.asyncio + async def test_returns_display_name_when_member_exists(self, bot, guild): + member = helpers.MockMember(name="BannedUser") + member.display_name = "BannedUser" + bot.get_member_or_user = AsyncMock(return_value=member) + + view = BanDecisionView(ban_id=1, bot=bot) + result = await view._resolve_member_name(guild, 42) + + assert result == "BannedUser" + bot.get_member_or_user.assert_awaited_once_with(guild, 42) + + @pytest.mark.asyncio + async def test_falls_back_to_str_id_when_member_missing(self, bot, guild): + bot.get_member_or_user = AsyncMock(return_value=None) + + view = BanDecisionView(ban_id=1, bot=bot) + result = await view._resolve_member_name(guild, 42) + + assert result == "42" + + +class TestApproveCallback: + @pytest.mark.asyncio + async def test_approve_happy_path(self, bot, guild): + ban = _make_ban(ban_id=1, user_id=42) + session = AsyncMock() + session.get = AsyncMock(return_value=ban) + session.commit = AsyncMock() + + channel = helpers.MockTextChannel() + guild.get_channel = MagicMock(return_value=channel) + + member = helpers.MockMember(name="BannedUser") + member.display_name = "BannedUser" + bot.get_member_or_user = AsyncMock(return_value=member) + + interaction = _make_interaction( + guild=guild, user=helpers.MockMember(name="Admin") + ) + + view = BanDecisionView(ban_id=1, bot=bot) + + with patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ): + await view._approve(interaction) + + interaction.response.defer.assert_awaited_once_with(ephemeral=True) + assert ban.approved is True + session.commit.assert_awaited_once() + + interaction.followup.send.assert_awaited_once() + assert "approved" in interaction.followup.send.call_args[0][0].lower() + + channel.send.assert_awaited_once() + interaction.message.edit.assert_awaited_once() + assert "Approved Duration" in interaction.message.edit.call_args[1]["content"] + + @pytest.mark.asyncio + async def test_approve_ban_not_found(self, bot, guild): + session = AsyncMock() + session.get = AsyncMock(return_value=None) + + interaction = _make_interaction(guild=guild) + view = BanDecisionView(ban_id=999, bot=bot) + + with patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ): + await view._approve(interaction) + + interaction.response.defer.assert_awaited_once_with(ephemeral=True) + interaction.followup.send.assert_awaited_once() + assert "not found" in interaction.followup.send.call_args[0][0].lower() + interaction.message.edit.assert_not_awaited() + + @pytest.mark.asyncio + async def test_approve_disables_approve_button(self, bot, guild): + ban = _make_ban(ban_id=3, user_id=42) + session = AsyncMock() + session.get = AsyncMock(return_value=ban) + session.commit = AsyncMock() + + channel = helpers.MockTextChannel() + guild.get_channel = MagicMock(return_value=channel) + bot.get_member_or_user = AsyncMock(return_value=helpers.MockMember(name="User")) + + interaction = _make_interaction(guild=guild) + view = BanDecisionView(ban_id=3, bot=bot) + + with patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ): + await view._approve(interaction) + + for child in view.children: + if isinstance(child, Button): + if child.custom_id == "ban_approve:3": + assert child.disabled + else: + assert not child.disabled + + @pytest.mark.asyncio + async def test_approve_skips_channel_send_when_channel_is_none(self, bot, guild): + ban = _make_ban(ban_id=1, user_id=42) + session = AsyncMock() + session.get = AsyncMock(return_value=ban) + session.commit = AsyncMock() + + guild.get_channel = MagicMock(return_value=None) + bot.get_member_or_user = AsyncMock(return_value=helpers.MockMember(name="User")) + + interaction = _make_interaction(guild=guild) + view = BanDecisionView(ban_id=1, bot=bot) + + with patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ): + await view._approve(interaction) + + interaction.followup.send.assert_awaited_once() + interaction.message.edit.assert_awaited_once() + + +class TestDenyCallback: + @pytest.mark.asyncio + async def test_deny_happy_path(self, bot, guild): + ban = _make_ban(ban_id=2, user_id=42) + session = AsyncMock() + session.get = AsyncMock(return_value=ban) + + channel = helpers.MockTextChannel() + guild.get_channel = MagicMock(return_value=channel) + + member = helpers.MockMember(name="BannedUser") + member.display_name = "BannedUser" + bot.get_member_or_user = AsyncMock(return_value=member) + + interaction = _make_interaction( + guild=guild, user=helpers.MockMember(name="Admin") + ) + view = BanDecisionView(ban_id=2, bot=bot) + + with ( + patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ), + patch("src.helpers.ban.unban_member", new_callable=AsyncMock) as mock_unban, + ): + await view._deny(interaction) + + interaction.response.defer.assert_awaited_once_with(ephemeral=True) + mock_unban.assert_awaited_once_with(guild, member) + + interaction.followup.send.assert_awaited_once() + assert "denied" in interaction.followup.send.call_args[0][0].lower() + + channel.send.assert_awaited_once() + interaction.message.edit.assert_awaited_once() + assert "Denied and Unbanned" in interaction.message.edit.call_args[1]["content"] + + @pytest.mark.asyncio + async def test_deny_ban_not_found(self, bot, guild): + session = AsyncMock() + session.get = AsyncMock(return_value=None) + + interaction = _make_interaction(guild=guild) + view = BanDecisionView(ban_id=999, bot=bot) + + with patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ): + await view._deny(interaction) + + interaction.followup.send.assert_awaited_once() + assert "not found" in interaction.followup.send.call_args[0][0].lower() + interaction.message.edit.assert_not_awaited() + + @pytest.mark.asyncio + async def test_deny_disables_all_buttons(self, bot, guild): + ban = _make_ban(ban_id=4, user_id=42) + session = AsyncMock() + session.get = AsyncMock(return_value=ban) + + guild.get_channel = MagicMock(return_value=helpers.MockTextChannel()) + bot.get_member_or_user = AsyncMock(return_value=helpers.MockMember(name="User")) + + interaction = _make_interaction(guild=guild) + view = BanDecisionView(ban_id=4, bot=bot) + + with ( + patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ), + patch("src.helpers.ban.unban_member", new_callable=AsyncMock), + ): + await view._deny(interaction) + + for child in view.children: + if isinstance(child, Button): + assert child.disabled + + @pytest.mark.asyncio + async def test_deny_member_not_found_skips_unban(self, bot, guild): + ban = _make_ban(ban_id=5, user_id=42) + session = AsyncMock() + session.get = AsyncMock(return_value=ban) + + guild.get_channel = MagicMock(return_value=helpers.MockTextChannel()) + bot.get_member_or_user = AsyncMock(return_value=None) + + interaction = _make_interaction(guild=guild) + view = BanDecisionView(ban_id=5, bot=bot) + + with ( + patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ), + patch("src.helpers.ban.unban_member", new_callable=AsyncMock) as mock_unban, + ): + await view._deny(interaction) + + mock_unban.assert_not_awaited() + assert "42" in interaction.followup.send.call_args[0][0] + + +class TestDisputeCallback: + @pytest.mark.asyncio + async def test_dispute_sends_modal(self, bot, guild): + interaction = _make_interaction(guild=guild) + view = BanDecisionView(ban_id=10, bot=bot) + + await view._dispute(interaction) + + interaction.response.send_modal.assert_awaited_once() + modal = interaction.response.send_modal.call_args[0][0] + assert isinstance(modal, DisputeModal) + assert modal.ban_id == 10 + assert modal.parent_view is view + + +class TestDisputeModalCallback: + @pytest.mark.asyncio + async def test_dispute_modal_invalid_duration(self, bot, guild): + view = BanDecisionView(ban_id=1, bot=bot) + modal = DisputeModal(ban_id=1, bot=bot, parent_view=view) + + interaction = _make_interaction(guild=guild) + modal.children[0].value = "garbage" + + with patch( + "src.views.bandecisionview.validate_duration", + return_value=(0, "Invalid duration"), + ): + await modal.callback(interaction) + + interaction.response.send_message.assert_awaited_once_with( + "Invalid duration", ephemeral=True + ) + + @pytest.mark.asyncio + async def test_dispute_modal_ban_not_found(self, bot, guild): + view = BanDecisionView(ban_id=999, bot=bot) + modal = DisputeModal(ban_id=999, bot=bot, parent_view=view) + + session = AsyncMock() + session.get = AsyncMock(return_value=None) + + interaction = _make_interaction(guild=guild) + modal.children[0].value = "1d" + + future_ts = int(datetime.now().timestamp()) + 86400 + with ( + patch( + "src.views.bandecisionview.validate_duration", + return_value=(future_ts, ""), + ), + patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ), + ): + await modal.callback(interaction) + + interaction.response.send_message.assert_awaited_once() + assert "not found" in interaction.response.send_message.call_args[0][0].lower() + + @pytest.mark.asyncio + async def test_dispute_modal_happy_path(self, bot, guild): + ban = _make_ban(ban_id=6, user_id=42) + session = AsyncMock() + session.get = AsyncMock(return_value=ban) + session.commit = AsyncMock() + + channel = helpers.MockTextChannel() + guild.get_channel = MagicMock(return_value=channel) + + member = helpers.MockMember(name="BannedUser") + member.display_name = "BannedUser" + bot.get_member_or_user = AsyncMock(return_value=member) + + view = BanDecisionView(ban_id=6, bot=bot) + modal = DisputeModal(ban_id=6, bot=bot, parent_view=view) + + interaction = _make_interaction(guild=guild) + modal.children[0].value = "2d" + + future_ts = int(datetime.now().timestamp()) + 172800 + with ( + patch( + "src.views.bandecisionview.validate_duration", + return_value=(future_ts, ""), + ), + patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ), + patch("src.views.bandecisionview.schedule", new_callable=AsyncMock), + ): + await modal.callback(interaction) + + assert ban.unban_time == future_ts + assert ban.approved is True + session.commit.assert_awaited_once() + + interaction.response.send_message.assert_awaited_once() + assert "updated" in interaction.response.send_message.call_args[0][0].lower() + channel.send.assert_awaited_once() + + @pytest.mark.asyncio + async def test_dispute_modal_disables_all_buttons(self, bot, guild): + ban = _make_ban(ban_id=7, user_id=42) + session = AsyncMock() + session.get = AsyncMock(return_value=ban) + session.commit = AsyncMock() + + guild.get_channel = MagicMock(return_value=helpers.MockTextChannel()) + bot.get_member_or_user = AsyncMock(return_value=helpers.MockMember(name="User")) + + view = BanDecisionView(ban_id=7, bot=bot) + modal = DisputeModal(ban_id=7, bot=bot, parent_view=view) + + interaction = _make_interaction(guild=guild) + modal.children[0].value = "1h" + + future_ts = int(datetime.now().timestamp()) + 3600 + with ( + patch( + "src.views.bandecisionview.validate_duration", + return_value=(future_ts, ""), + ), + patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ), + patch("src.views.bandecisionview.schedule", new_callable=AsyncMock), + ): + await modal.callback(interaction) + + for child in view.children: + if isinstance(child, Button): + assert child.disabled + + @pytest.mark.asyncio + async def test_dispute_modal_skips_message_edit_when_message_is_none( + self, bot, guild + ): + ban = _make_ban(ban_id=8, user_id=42) + session = AsyncMock() + session.get = AsyncMock(return_value=ban) + session.commit = AsyncMock() + + guild.get_channel = MagicMock(return_value=helpers.MockTextChannel()) + bot.get_member_or_user = AsyncMock(return_value=helpers.MockMember(name="User")) + + view = BanDecisionView(ban_id=8, bot=bot) + modal = DisputeModal(ban_id=8, bot=bot, parent_view=view) + + interaction = _make_interaction(guild=guild) + interaction.message = None + modal.children[0].value = "1d" + + future_ts = int(datetime.now().timestamp()) + 86400 + with ( + patch( + "src.views.bandecisionview.validate_duration", + return_value=(future_ts, ""), + ), + patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ), + patch("src.views.bandecisionview.schedule", new_callable=AsyncMock), + ): + await modal.callback(interaction) + + interaction.response.send_message.assert_awaited_once() + + +class TestRegisterBanViews: + @pytest.mark.asyncio + async def test_registers_views_for_unapproved_bans(self, bot): + ban_a = _make_ban(ban_id=10) + ban_b = _make_ban(ban_id=20) + + scalars_result = MagicMock() + scalars_result.all.return_value = [ban_a, ban_b] + + session = AsyncMock() + session.scalars = AsyncMock(return_value=scalars_result) + + bot.add_view = MagicMock() + + with patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ): + await register_ban_views(bot) + + assert bot.add_view.call_count == 2 + registered_ids = {call.args[0].ban_id for call in bot.add_view.call_args_list} + assert registered_ids == {10, 20} + + @pytest.mark.asyncio + async def test_no_views_registered_when_no_pending_bans(self, bot): + scalars_result = MagicMock() + scalars_result.all.return_value = [] + + session = AsyncMock() + session.scalars = AsyncMock(return_value=scalars_result) + + bot.add_view = MagicMock() + + with patch( + "src.views.bandecisionview.AsyncSessionLocal", + return_value=_session_ctx(session), + ): + await register_ban_views(bot) + + bot.add_view.assert_not_called()