diff --git a/src/tux/database/controllers/__init__.py b/src/tux/database/controllers/__init__.py index f7e512ca..f27639a4 100644 --- a/src/tux/database/controllers/__init__.py +++ b/src/tux/database/controllers/__init__.py @@ -29,6 +29,7 @@ "SnippetController", "StarboardController", "StarboardMessageController", + "VerificationController", ] from tux.database.controllers.afk import AfkController @@ -50,6 +51,7 @@ StarboardController, StarboardMessageController, ) +from tux.database.controllers.verification import VerificationController class DatabaseCoordinator: @@ -113,6 +115,7 @@ def __init__( self._starboard: StarboardController | None = None self._starboard_message: StarboardMessageController | None = None self._reminder: ReminderController | None = None + self._verification: VerificationController | None = None @property def guild(self) -> GuildController: @@ -209,3 +212,10 @@ def command_permissions(self) -> PermissionCommandController: cache_backend=getattr(self, "_cache_backend", None), ) return self._permission_commands + + @property + def verification(self) -> VerificationController: + """Get the verification controller.""" + if self._verification is None: + self._verification = VerificationController(self.db) + return self._verification diff --git a/src/tux/database/controllers/verification.py b/src/tux/database/controllers/verification.py new file mode 100644 index 00000000..5b0f8c44 --- /dev/null +++ b/src/tux/database/controllers/verification.py @@ -0,0 +1,68 @@ +"""Verification controller for Roblox OAuth2 account linking.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from tux.database.controllers.base import BaseController +from tux.database.models import Verification + +if TYPE_CHECKING: + from tux.database.service import DatabaseService + + +class VerificationController(BaseController[Verification]): + """Controller for verification-related database operations.""" + + def __init__(self, db: DatabaseService | None = None) -> None: + """Initialize the verification controller. + + Parameters + ---------- + db : DatabaseService | None, optional + The database service instance. + """ + super().__init__(Verification, db) + + async def get_by_discord_id(self, discord_id: int) -> Verification | None: + """Get a verification record by Discord ID.""" + return await self.find_one(filters=Verification.discord_id == discord_id) + + async def get_by_roblox_id(self, roblox_id: int) -> Verification | None: + """Get a verification record by Roblox ID.""" + return await self.find_one(filters=Verification.roblox_id == roblox_id) + + async def upsert_verification( + self, + discord_id: int, + roblox_id: int, + roblox_username: str | None = None, + ) -> Verification: + """Create or update a verification record. + + Parameters + ---------- + discord_id : int + Discord user ID. + roblox_id : int + Roblox user ID. + roblox_username : str | None, optional + Roblox username. + + Returns + ------- + Verification + The created or updated verification record. + """ + result, _ = await self.upsert( + filters={"discord_id": discord_id}, + roblox_id=roblox_id, + roblox_username=roblox_username, + verified_at=datetime.now(UTC).replace(tzinfo=None), + ) + return result + + async def delete(self, discord_id: int) -> bool: + """Delete a verification record by Discord ID.""" + return await self.delete_by_id(discord_id) diff --git a/src/tux/database/migrations/versions/2026_02_26_0444-e938fbc52236_dev_migration.py b/src/tux/database/migrations/versions/2026_02_26_0444-e938fbc52236_dev_migration.py new file mode 100644 index 00000000..f440fe2e --- /dev/null +++ b/src/tux/database/migrations/versions/2026_02_26_0444-e938fbc52236_dev_migration.py @@ -0,0 +1,71 @@ +""" +Revision ID: e938fbc52236 +Revises: e7c5ed41fae0 +Create Date: 2026-02-26 04:44:00.430858+00:00 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = "e938fbc52236" +down_revision: Union[str, None] = "e7c5ed41fae0" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "verification", + sa.Column( + "created_at", + sa.DateTime(), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=True, + ), + sa.Column( + "updated_at", + sa.DateTime(), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=True, + ), + sa.Column("discord_id", sa.BigInteger(), nullable=False), + sa.Column("roblox_id", sa.BigInteger(), nullable=False), + sa.Column( + "roblox_username", + sqlmodel.sql.sqltypes.AutoString(length=100), + nullable=True, + ), + sa.Column("verified_at", sa.DateTime(), nullable=False), + sa.CheckConstraint( + "discord_id > 0", name="check_verification_discord_id_valid" + ), + sa.CheckConstraint("roblox_id > 0", name="check_verification_roblox_id_valid"), + sa.PrimaryKeyConstraint("discord_id"), + ) + with op.batch_alter_table("verification", schema=None) as batch_op: + batch_op.create_index("idx_verification_roblox", ["roblox_id"], unique=False) + + with op.batch_alter_table("levels", schema=None) as batch_op: + batch_op.drop_table_comment(existing_comment="f") + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("levels", schema=None) as batch_op: + batch_op.create_table_comment("f", existing_comment=None) + + with op.batch_alter_table("verification", schema=None) as batch_op: + batch_op.drop_index("idx_verification_roblox") + + op.drop_table("verification") + # ### end Alembic commands ### diff --git a/src/tux/database/models/__init__.py b/src/tux/database/models/__init__.py index 9e902723..12fb74b6 100644 --- a/src/tux/database/models/__init__.py +++ b/src/tux/database/models/__init__.py @@ -24,6 +24,7 @@ Snippet, Starboard, StarboardMessage, + Verification, ) __all__ = [ @@ -51,4 +52,5 @@ # Starboard system "Starboard", "StarboardMessage", + "Verification", ] diff --git a/src/tux/database/models/models.py b/src/tux/database/models/models.py index b3d5f3c9..9fe00818 100644 --- a/src/tux/database/models/models.py +++ b/src/tux/database/models/models.py @@ -1272,3 +1272,55 @@ class StarboardMessage(BaseModel, table=True): def __repr__(self) -> str: """Return string representation showing guild, original message and user.""" return f"" + + +# ============================================================================= +# VERIFICATION MODELS +# ============================================================================= + + +class Verification(BaseModel, table=True): + """Linkage between Discord and Roblox accounts. + + Stores the verified Roblox account information for each Discord user. + + Attributes + ---------- + discord_id : int + Discord user ID (primary key). + roblox_id : int + Roblox user ID. + roblox_username : str, optional + Roblox username at the time of verification. + verified_at : datetime + Timestamp when the user was verified. + """ + + discord_id: int = Field( + primary_key=True, + sa_type=BigInteger, + description="Discord user ID", + ) + roblox_id: int = Field( + sa_type=BigInteger, + description="Roblox user ID", + ) + roblox_username: str | None = Field( + default=None, + max_length=100, + description="Roblox username", + ) + verified_at: datetime = Field( + default_factory=lambda: datetime.now(UTC).replace(tzinfo=None), + description="Timestamp when verification occurred", + ) + + __table_args__ = ( + CheckConstraint("discord_id > 0", name="check_verification_discord_id_valid"), + CheckConstraint("roblox_id > 0", name="check_verification_roblox_id_valid"), + Index("idx_verification_roblox", "roblox_id"), + ) + + def __repr__(self) -> str: + """Return string representation showing Discord and Roblox IDs.""" + return f"" diff --git a/src/tux/plugins/spectacle/verify.py b/src/tux/plugins/spectacle/verify.py new file mode 100644 index 00000000..be820679 --- /dev/null +++ b/src/tux/plugins/spectacle/verify.py @@ -0,0 +1,364 @@ +"""Spectacle Studios Discord Servers - Verification Plugin.""" + +import base64 +import hashlib +import secrets +from urllib.parse import quote + +import discord +import httpx +from aiohttp import web +from discord import app_commands +from discord.ext import commands +from loguru import logger + +from tux.core.base_cog import BaseCog +from tux.core.bot import Tux +from tux.services.http_client import http_client +from tux.shared.config import CONFIG + + +class Verify(BaseCog): + """Manage and expose the verification functionality. + + This cog provides commands that allow users to verify their Discord accounts + using Roblox OAuth2 and PKCE flow. It also runs a small web server to + handle the OAuth2 callback. + """ + + def __init__(self, bot: Tux) -> None: + super().__init__(bot) + + # In-memory fallback if cache_service is not available + self._pending_verifications: dict[str, dict[str, int]] = {} + self._verifiers: dict[str, str] = {} + + self.site_app = web.Application() + self.site_app.add_routes([web.get("/callback", self.handle_callback)]) + self.site_runner: web.AppRunner | None = None + + async def cog_load(self) -> None: + """Initialize the callback web server when the cog is loaded.""" + self.site_runner = web.AppRunner(self.site_app) + await self.site_runner.setup() + + # Listen on 0.0.0.0 and the configured WEB_PORT + site = web.TCPSite(self.site_runner, "0.0.0.0", CONFIG.WEB_PORT) + await site.start() + logger.info(f"Verification callback server listening on port {CONFIG.WEB_PORT}") + + async def cog_unload(self) -> None: + """Clean up the web server when the cog is unloaded.""" + if self.site_runner: + await self.site_runner.cleanup() + logger.info("Verification callback server stopped") + + def generate_challenge(self) -> dict[str, str]: + """Generate a random challenge and verifier for Roblox PKCE. + + Follows the S256 code challenge method. + """ + # Create a URL-safe verifier (between 43 and 128 chars) + verifier = secrets.token_urlsafe(64) + + # SHA256 hash the verifier + hash_digest = hashlib.sha256(verifier.encode()).digest() + + # Base64URL encode the digest and remove padding (=) + challenge = base64.urlsafe_b64encode(hash_digest).decode().replace("=", "") + + return {"verifier": verifier, "challenge": challenge} + + @commands.hybrid_command(name="link", aliases=["v", "verify"]) + @commands.guild_only() + @app_commands.guild_only() + async def link(self, ctx: commands.Context[Tux]) -> None: + """Link your Discord account with Roblox to gain an XP boost.""" + # Check if already verified + result = await self.bot.db.verification.get_by_discord_id(ctx.author.id) + if result: + await ctx.send( + f"You are already verified as Roblox user **{result.roblox_username or result.roblox_id}**.", + ephemeral=True, + ) + return + + state = secrets.token_urlsafe(16) + params = self.generate_challenge() + + # Store verification session data (valid for 10 minutes) + cache = self.bot.cache_service.get_client() if self.bot.cache_service else None + + if cache: + await cache.set(f"verify:discord_id:{state}", str(ctx.author.id), ex=600) + if ctx.guild: + await cache.set(f"verify:guild_id:{state}", str(ctx.guild.id), ex=600) + await cache.set(f"verify:verifier:{state}", params["verifier"], ex=600) + else: + self._pending_verifications[state] = { + "discord_id": ctx.author.id, + "guild_id": ctx.guild.id if ctx.guild else 0, + } + self._verifiers[state] = params["verifier"] + + client_id = CONFIG.OAUTH2_CLIENTID + redirect_uri = "https://verify.spst.dev/callback" + scopes = "openid profile group:read" + + auth_url = ( + "https://apis.roblox.com/oauth/v1/authorize" + f"?client_id={client_id}" + f"&code_challenge={quote(params['challenge'])}" + f"&code_challenge_method=S256" + f"&redirect_uri={quote(redirect_uri)}" + f"&scope={quote(scopes)}" + f"&response_type=code" + f"&state={quote(state)}" + ) + + await ctx.send( + "To verify your Roblox account, please click the button below and follow the instructions.\n" + "This link will expire in 10 minutes.", + view=VerificationView(auth_url), + ephemeral=True, + ) + + @commands.hybrid_command(name="unlink") + @commands.guild_only() + @app_commands.guild_only() + async def unlink(self, ctx: commands.Context[Tux]) -> None: + """Unlink your Discord account from Roblox.""" + result = await self.bot.db.verification.get_by_discord_id(ctx.author.id) + if not result: + await ctx.send( + "You are not verified.", + ephemeral=True, + ) + return + await self.bot.db.verification.delete(result.discord_id) + await ctx.send( + "Your Discord account has been unlinked from Roblox.", + ephemeral=True, + ) + + async def handle_callback(self, request: web.Request) -> web.Response: + """Handle the incoming OAuth2 callback from Roblox.""" + code = request.query.get("code") + state = request.query.get("state") + + if not code or not state: + return web.Response( + text="Invalid callback: missing code or state.", + status=400, + ) + + discord_id = None + code_verifier = None + + # Retrieve session data + cache = self.bot.cache_service.get_client() if self.bot.cache_service else None + if cache: + discord_id_val = await cache.get(f"verify:discord_id:{state}") + discord_id = int(discord_id_val) if discord_id_val else None + guild_id_val = await cache.get(f"verify:guild_id:{state}") + guild_id = int(guild_id_val) if guild_id_val else None + code_verifier = await cache.get(f"verify:verifier:{state}") + + # Clean up cache + await cache.delete(f"verify:discord_id:{state}") + await cache.delete(f"verify:guild_id:{state}") + await cache.delete(f"verify:verifier:{state}") + else: + session = self._pending_verifications.pop(state, {}) + discord_id = session.get("discord_id") + guild_id = session.get("guild_id") + code_verifier = self._verifiers.pop(state, None) + + if not discord_id or not code_verifier or not guild_id: + return web.Response( + text="Verification session expired or invalid. Please run /verify again in Discord.", + status=400, + ) + + try: + # Exchange code for access token + token_response = await http_client.post( + "https://apis.roblox.com/oauth/v1/token", + data={ + "client_id": CONFIG.OAUTH2_CLIENTID, + "client_secret": CONFIG.OAUTH2_SECRET, + "grant_type": "authorization_code", + "code": code, + "code_verifier": code_verifier, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + token_data = token_response.json() + access_token = token_data.get("access_token") + + if not access_token: + logger.error(f"Failed to obtain Roblox access token: {token_data}") + return web.Response( + text="Failed to authenticate with Roblox.", + status=500, + ) + + # Fetch Roblox user info + user_response = await http_client.get( + "https://apis.roblox.com/oauth/v1/userinfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + user_data = user_response.json() + + roblox_id = int(user_data["sub"]) + roblox_username = ( + user_data.get("preferred_username") + or user_data.get("nickname") + or str(roblox_id) + ) + + # Store the linkage in our database using the controller + await self.bot.db.verification.upsert_verification( + discord_id=discord_id, + roblox_id=roblox_id, + roblox_username=roblox_username, + ) + + logger.info( + f"Verified Discord user {discord_id} as Roblox user {roblox_username} ({roblox_id})", + ) + + # Apply roles in all configured guilds + await self._apply_roles(discord_id, roblox_id, access_token) + + user = self.bot.get_user(discord_id) + name = user.name if user else "Unknown" + + return web.HTTPFound( + f"https://spst.dev/verify?success=true&rbx={roblox_username}&dc={name}", + ) + + except Exception: + logger.exception("Error during Roblox OAuth2 callback processing") + return web.HTTPFound( + "https://spst.dev/verify?success=false", + ) + + async def _apply_roles( # noqa: PLR0912 + self, + discord_id: int, + roblox_id: int, + access_token: str, + ) -> None: # sourcery skip: low-code-quality + """Apply verification roles to a user in all configured guilds.""" + group_check_cache: dict[int, bool] = {} + + for g_id, g_config in CONFIG.VERIFICATION.GUILDS.items(): + if g_id == 0: + continue + + try: + guild = self.bot.get_guild(g_id) + if not guild: + continue + + try: + member = await guild.fetch_member(discord_id) + except discord.NotFound: + continue + except Exception as me: + logger.warning( + f"Failed to fetch member {discord_id} in guild {g_id}: {me}", + ) + continue + + target_group_id = g_config.ROBLOX_GROUP_ID + if target_group_id not in group_check_cache: + is_member = False + try: + resp = await http_client.get( + f"https://apis.roblox.com/cloud/v2/groups/{target_group_id}/memberships/users%2F{roblox_id}", + headers={"Authorization": f"Bearer {access_token}"}, + ) + if resp.status_code == 200: + is_member = True + except httpx.HTTPStatusError as ge: + if ge.response.status_code == 404: + logger.debug( + f"Roblox user {roblox_id} is not in group {target_group_id} (404).", + ) + else: + logger.warning( + f"Roblox API error checking group {target_group_id}: {ge}", + ) + except Exception as ge: + logger.warning( + f"Failed to check group membership for {target_group_id}: {ge}", + ) + group_check_cache[target_group_id] = is_member + + is_group_member = group_check_cache.get(target_group_id, False) + roles_to_add: list[discord.Role] = [] + + if (v_role_id := g_config.VERIFIED_ROLE_ID) and ( + v_role := guild.get_role(v_role_id) + ): + roles_to_add.append(v_role) + + if ( + is_group_member + and (g_role_id := g_config.GROUP_MEMBER_ROLE_ID) + and (g_role := guild.get_role(g_role_id)) + ): + roles_to_add.append(g_role) + + if roles_to_add: # noqa: SIM102 + if to_add := [r for r in roles_to_add if r not in member.roles]: + await member.add_roles(*to_add, reason="Roblox Verification") + logger.info(f"Added roles {to_add} to {member} in {guild.name}") + + except Exception as e: + logger.warning(f"Error processing roles for guild {g_id}: {e}") + + def _get_success_html(self, username: str) -> str: + """Return a simple HTML success page.""" + return f""" + + + Verification Successful + + + +
+

Verification Successful!

+

You have been verified as {username}.

+

You can now close this window and return to Discord.

+
+ + + """ + + +class VerificationView(discord.ui.View): + """Simple view with a button that links to the Roblox OAuth2 URL.""" + + def __init__(self, url: str) -> None: + super().__init__() + self.add_item( + discord.ui.Button( + label="Verify with Roblox", + url=url, + style=discord.ButtonStyle.link, + ), + ) + + +async def setup(bot: Tux) -> None: + """Register the verify cog with the bot.""" + await bot.add_cog(Verify(bot)) diff --git a/src/tux/shared/config/models.py b/src/tux/shared/config/models.py index 013ebac9..8895fed6 100644 --- a/src/tux/shared/config/models.py +++ b/src/tux/shared/config/models.py @@ -376,6 +376,35 @@ class Moderation(BaseModel): ] +class VerificationConfig(BaseModel): + """Per-guild verification configuration.""" + + VERIFIED_ROLE_ID: int = Field( + default=0, + description="Role ID to give to verified users", + ) + GROUP_MEMBER_ROLE_ID: int = Field( + default=0, + description="Role ID to give to users in the specified Roblox group", + ) + ROBLOX_GROUP_ID: int = Field( + default=16185131, + description="Roblox group ID to check for membership", + ) + + +class Verification(BaseModel): + """Verification configuration.""" + + GUILDS: Annotated[ + dict[int, VerificationConfig], + Field( + default_factory=lambda: {0: VerificationConfig()}, + description="Per-server verification settings (use guild ID 0 for defaults)", + ), + ] + + class ExternalServices(BaseModel): """External services configuration.""" diff --git a/src/tux/shared/config/settings.py b/src/tux/shared/config/settings.py index ce4c5cbf..a073758c 100644 --- a/src/tux/shared/config/settings.py +++ b/src/tux/shared/config/settings.py @@ -44,6 +44,7 @@ StatusRoles, TempVC, UserIds, + Verification, ) @@ -301,6 +302,29 @@ class Config(BaseSettings): SNIPPETS: Snippets = Field(default_factory=Snippets) # type: ignore[arg-type] IRC_CONFIG: IRC = Field(default_factory=IRC) # type: ignore[arg-type] MODERATION: Moderation = Field(default_factory=Moderation) # type: ignore[arg-type] + OAUTH2_CLIENTID: Annotated[ + str, + Field( + default="", + description="Roblox OAuth2 Client ID", + ), + ] + OAUTH2_SECRET: Annotated[ + str, + Field( + default="", + description="Roblox OAuth2 Client Secret", + ), + ] + WEB_PORT: Annotated[ + int, + Field( + default=8080, + description="Port for the bot's web server", + ), + ] + + VERIFICATION: Verification = Field(default_factory=Verification) # type: ignore[arg-type] # External services EXTERNAL_SERVICES: ExternalServices = Field(default_factory=ExternalServices) # type: ignore[arg-type]