Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 71 additions & 18 deletions src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,27 @@
import discord
from aiohttp import AsyncResolver, ClientSession, TCPConnector
from discord import (
ApplicationContext, Cog, DiscordException, Embed, Forbidden, Guild, HTTPException, Member, NotFound, User,
ApplicationContext,
Cog,
DiscordException,
Embed,
Forbidden,
Guild,
HTTPException,
Member,
NotFound,
User,
)
from discord.ext.commands import Bot as DiscordBot
from discord.ext.commands import (
CommandNotFound, CommandOnCooldown, DefaultHelpCommand, MissingAnyRole, MissingPermissions,
MissingRequiredArgument, NoPrivateMessage, UserInputError,
CommandNotFound,
CommandOnCooldown,
DefaultHelpCommand,
MissingAnyRole,
MissingPermissions,
MissingRequiredArgument,
NoPrivateMessage,
UserInputError,
)
from sqlalchemy.exc import NoResultFound
from typing import TypeVar
Expand Down Expand Up @@ -50,13 +65,15 @@ async def on_ready(self) -> None:
if self.http_session is None:
logger.debug("Starting the HTTP session")
self.http_session = ClientSession(
connector=TCPConnector(resolver=AsyncResolver(), family=socket.AF_INET),
trace_configs=[trace_config]
connector=TCPConnector(resolver=AsyncResolver(), family=socket.AF_INET),
trace_configs=[trace_config],
)

name = f"{self.user} (ID: {self.user.id})"
devlog_msg = f"Connected {constants.emojis.partying_face}"
self.loop.create_task(self.send_log(devlog_msg, colour=constants.colours.bright_green))
self.loop.create_task(
self.send_log(devlog_msg, colour=constants.colours.bright_green)
)

logger.info(f"Started bot as {name}")
print("Loading ScheduledTasks cog...")
Expand All @@ -66,6 +83,17 @@ async def on_ready(self) -> None:
except Exception as e:
print(f"Failed to load ScheduledTasks cog: {e}")

await self._register_persistent_views()

async def _register_persistent_views(self) -> None:
"""Re-register persistent UI views so buttons survive bot restarts."""
from src.views.bandecisionview import register_ban_views

try:
await register_ban_views(self)
except Exception:
logger.exception("Failed to register persistent ban decision views")

async def on_application_command(self, ctx: ApplicationContext) -> None:
"""A global handler cog."""
logger.debug(f"Command '{ctx.command}' received.")
Expand All @@ -77,14 +105,18 @@ async def on_application_command(self, ctx: ApplicationContext) -> None:
embed.add_field(name="Channel", value=ctx.channel.name, inline=True)
await ctx.guild.get_channel(settings.channels.BOT_LOGS).send(embed=embed)

async def on_application_command_error(self, ctx: ApplicationContext, error: DiscordException) -> None:
async def on_application_command_error(
self, ctx: ApplicationContext, error: DiscordException
) -> None:
"""A global error handler cog."""
message = None
if isinstance(error, CommandNotFound):
return
if isinstance(error, MissingRequiredArgument):
message = f"Parameter '{error.param.name}' is required, but missing. Type `{ctx.clean_prefix}help " \
f"{ctx.invoked_with}` for help."
message = (
f"Parameter '{error.param.name}' is required, but missing. Type `{ctx.clean_prefix}help "
f"{ctx.invoked_with}` for help."
)
elif isinstance(error, MissingPermissions):
message = "You are missing the required permissions to run this command."
elif isinstance(error, MissingAnyRole):
Expand Down Expand Up @@ -120,7 +152,9 @@ def add_cog(self, cog: Cog, *, override: bool = False) -> None:
super().add_cog(cog, override=override)
logger.debug(f"Cog loaded: {cog.qualified_name}")

async def send_log(self, description: str = None, colour: int = None, embed: Embed = None) -> None:
async def send_log(
self, description: str = None, colour: int = None, embed: Embed = None
) -> None:
"""Send an embed message to the devlog channel."""
devlog = self.get_channel(settings.channels.DEVLOG)

Expand Down Expand Up @@ -159,27 +193,46 @@ async def get_member_or_user(self, guild: Guild, id_: int) -> Member | User | No
try:
return await guild.fetch_member(id_)
except Forbidden as exc:
logger.warning(f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc)
logger.warning(
f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc
)
except NotFound as exc:
logger.warning(f"Could not find guild member with id: {id_}", exc_info=exc)
try:
return await self.get_or_fetch_user(id_)
except Forbidden as exc:
logger.warning(f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc)
logger.warning(
f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc
)
except NotFound as exc:
logger.warning(f"Could not find guild member with id: {id_}", exc_info=exc)
logger.warning(
f"Could not find guild member with id: {id_}", exc_info=exc
)
except HTTPException as exc:
logger.error(f"Discord error while fetching guild member with id: {id_}", exc_info=exc)
logger.error(
f"Discord error while fetching guild member with id: {id_}",
exc_info=exc,
)
except HTTPException as exc:
logger.error(f"Discord error while fetching guild member with id: {id_}", exc_info=exc)
logger.error(
f"Discord error while fetching guild member with id: {id_}",
exc_info=exc,
)
try:
return await self.get_or_fetch_user(id_)
except Forbidden as exc:
logger.warning(f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc)
logger.warning(
f"Unauthorized attempt to fetch member with id: {id_}", exc_info=exc
)
except NotFound as exc:
logger.warning(f"Could not find guild member with id: {id_}", exc_info=exc)
logger.warning(
f"Could not find guild member with id: {id_}", exc_info=exc
)
except HTTPException as exc:
logger.error(f"Discord error while fetching guild member with id: {id_}", exc_info=exc)
logger.error(
f"Discord error while fetching guild member with id: {id_}",
exc_info=exc,
)

return None

Expand Down
72 changes: 48 additions & 24 deletions src/helpers/ban.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ class BanCodes(Enum):


async def _check_member(
bot: Bot, guild: Guild, member: Member | User, author: Member | ClientUser | None = None
bot: Bot,
guild: Guild,
member: Member | User,
author: Member | ClientUser | None = None,
) -> SimpleResponse | None:
if isinstance(member, Member):
if member_is_staff(member):
Expand Down Expand Up @@ -70,7 +73,9 @@ async def get_ban(member: Member | User) -> Ban | None:


async def update_ban(ban: Ban) -> None:
logger.info(f"Updating ban {ban.id} for user {ban.user_id} with expiration {ban.unban_time}")
logger.info(
f"Updating ban {ban.id} for user {ban.user_id} with expiration {ban.unban_time}"
)
async with AsyncSessionLocal() as session:
session.add(ban)
await session.commit()
Expand Down Expand Up @@ -156,7 +161,7 @@ async def handle_platform_ban_or_update(
extra_log_data: dict | None = None,
) -> dict:
"""Handle platform ban by either creating new ban, updating existing ban, or taking no action.

Args:
bot: The Discord bot instance
guild: The guild to ban the member from
Expand All @@ -169,30 +174,44 @@ async def handle_platform_ban_or_update(
log_channel_id: Channel ID for logging ban actions
logger: Logger instance for recording events
extra_log_data: Additional data to include in log entries

Returns:
dict with 'action' key indicating what was done: 'unbanned', 'extended', 'no_action', 'updated', 'created'
"""
if extra_log_data is None:
extra_log_data = {}

expires_dt = datetime.fromtimestamp(expires_timestamp)

existing_ban = await get_ban(member)
if not existing_ban:
# No existing ban, create new one
await ban_member_with_epoch(
bot, guild, member, expires_timestamp, reason, evidence, needs_approval=False
bot,
guild,
member,
expires_timestamp,
reason,
evidence,
needs_approval=False,
)
await _send_ban_notice(
guild, member, reason, author_name, expires_at_str, guild.get_channel(log_channel_id) # type: ignore
guild,
member,
reason,
author_name,
expires_at_str,
guild.get_channel(log_channel_id), # type: ignore
)
logger.info(
f"Created new platform ban for user {member.id} until {expires_at_str}",
extra=extra_log_data,
)
logger.info(f"Created new platform ban for user {member.id} until {expires_at_str}", extra=extra_log_data)
return {"action": "created"}

# Existing ban found - determine what to do based on ban type and timing
is_platform_ban = existing_ban.reason.startswith("Platform Ban")

if is_platform_ban:
# Platform bans have authority over other platform bans
if expires_dt < datetime.now():
Expand All @@ -202,7 +221,7 @@ async def handle_platform_ban_or_update(
await guild.get_channel(log_channel_id).send(msg) # type: ignore
logger.info(msg, extra=extra_log_data)
return {"action": "unbanned"}

if existing_ban.unban_time < expires_timestamp:
# Extend the existing platform ban
existing_ban.unban_time = expires_timestamp
Expand All @@ -226,11 +245,17 @@ async def handle_platform_ban_or_update(
existing_ban.unban_time = expires_timestamp
existing_ban.reason = f"Platform Ban: {reason}" # Update reason to indicate platform authority
await update_ban(existing_ban)
logger.info(f"Updated existing ban for user {member.id} until {expires_at_str}.", extra=extra_log_data)
logger.info(
f"Updated existing ban for user {member.id} until {expires_at_str}.",
extra=extra_log_data,
)
return {"action": "updated"}

# Default case (shouldn't reach here, but for safety)
logger.warning(f"Unexpected case in platform ban handling for user {member.id}", extra=extra_log_data)
logger.warning(
f"Unexpected case in platform ban handling for user {member.id}",
extra=extra_log_data,
)
return {"action": "no_action"}


Expand All @@ -245,7 +270,7 @@ async def ban_member_with_epoch(
needs_approval: bool = True,
) -> SimpleResponse:
"""Ban a member from the guild until a specific epoch time.

Args:
bot: The Discord bot instance
guild: The guild to ban the member from
Expand All @@ -255,7 +280,7 @@ async def ban_member_with_epoch(
evidence: Evidence supporting the ban
author: The member issuing the ban (defaults to bot user)
needs_approval: Whether the ban requires approval

Returns:
SimpleResponse with the result of the ban operation, or None if no response needed
"""
Expand All @@ -273,8 +298,7 @@ async def ban_member_with_epoch(
current_time = datetime.now(tz=timezone.utc).timestamp()
if unban_epoch_time <= current_time:
return SimpleResponse(
message="Unban time must be in the future",
delete_after=15
message="Unban time must be in the future", delete_after=15
)

end_date: str = datetime.fromtimestamp(unban_epoch_time, tz=timezone.utc).strftime(
Expand All @@ -283,7 +307,7 @@ async def ban_member_with_epoch(

if author is None:
author = bot.user

# Author should never be None at this point
if author is None:
raise ValueError("Author cannot be None")
Expand Down Expand Up @@ -374,7 +398,7 @@ async def ban_member_with_epoch(
f"Evidence: {evidence}",
)
embed.set_thumbnail(url=f"{settings.HTB_URL}/images/logo600.png")
view = BanDecisionView(ban_id, bot, guild, member, end_date, reason)
view = BanDecisionView(ban_id, bot)
await guild.get_channel(settings.channels.SR_MOD).send(embed=embed, view=view) # type: ignore

return await _create_ban_response(
Expand All @@ -393,7 +417,7 @@ async def ban_member(
needs_approval: bool = True,
) -> SimpleResponse:
"""Ban a member from the guild using a duration.

Args:
bot: The Discord bot instance
guild: The guild to ban the member from
Expand All @@ -403,7 +427,7 @@ async def ban_member(
evidence: Evidence supporting the ban
author: The member issuing the ban (defaults to bot user)
needs_approval: Whether the ban requires approval

Returns:
SimpleResponse with the result of the ban operation, or None if no response needed
"""
Expand Down Expand Up @@ -516,7 +540,7 @@ async def mute_member(

if author is None:
author = bot.user

# Author should never be None at this point
if author is None:
raise ValueError("Author cannot be None")
Expand Down
Loading
Loading