diff --git a/src/opengradient/client/model_hub.py b/src/opengradient/client/model_hub.py index d8f5c92..c2d63c3 100644 --- a/src/opengradient/client/model_hub.py +++ b/src/opengradient/client/model_hub.py @@ -1,6 +1,7 @@ """Model Hub for creating, versioning, and uploading ML models.""" import os +import time from typing import Dict, List, Optional import firebase # type: ignore[import-untyped] @@ -19,6 +20,9 @@ "databaseURL": os.getenv("FIREBASE_DATABASE_URL", ""), } +# Firebase idTokens expire after 3600 seconds; refresh 60 s before expiry +_TOKEN_REFRESH_MARGIN_SEC = 60 + class ModelHub: """ @@ -34,7 +38,14 @@ class ModelHub: """ def __init__(self, email: Optional[str] = None, password: Optional[str] = None): - self._hub_user = self._login(email, password) if email is not None else None + self._firebase_app = None + self._hub_user = None + self._token_expiry: float = 0.0 + + if email is not None: + self._firebase_app, self._hub_user = self._login(email, password) + expires_in = int(self._hub_user.get("expiresIn", 3600)) + self._token_expiry = time.time() + expires_in @staticmethod def _login(email: str, password: Optional[str]): @@ -42,7 +53,34 @@ def _login(email: str, password: Optional[str]): raise ValueError("Firebase API Key is missing in environment variables") firebase_app = firebase.initialize_app(_FIREBASE_CONFIG) - return firebase_app.auth().sign_in_with_email_and_password(email, password) + user = firebase_app.auth().sign_in_with_email_and_password(email, password) + return firebase_app, user + + def _get_auth_token(self) -> str: + """Return a valid Firebase idToken, refreshing it if it has expired or is + about to expire within ``_TOKEN_REFRESH_MARGIN_SEC`` seconds. + + Raises: + ValueError: If the user is not authenticated. + """ + if not self._hub_user: + raise ValueError("User not authenticated") + + if time.time() >= self._token_expiry - _TOKEN_REFRESH_MARGIN_SEC: + # Refresh the token using the stored refresh token + refresh_token = self._hub_user.get("refreshToken") + if not refresh_token or self._firebase_app is None: + raise ValueError( + "Cannot refresh Firebase token: missing refresh token or Firebase app. " + "Please re-authenticate by creating a new ModelHub instance." + ) + refreshed = self._firebase_app.auth().refresh(refresh_token) + self._hub_user["idToken"] = refreshed["idToken"] + self._hub_user["refreshToken"] = refreshed.get("refreshToken", refresh_token) + expires_in = int(refreshed.get("expiresIn", 3600)) + self._token_expiry = time.time() + expires_in + + return str(self._hub_user["idToken"]) # cast Any->str for mypy [no-any-return] def create_model(self, model_name: str, model_desc: str, version: str = "1.00") -> ModelRepository: """ @@ -51,19 +89,17 @@ def create_model(self, model_name: str, model_desc: str, version: str = "1.00") Args: model_name (str): The name of the model. model_desc (str): The description of the model. - version (str): The version identifier (default is "1.00"). + version (str): A label used in the initial version notes (default is "1.00"). + Note: the actual version string is assigned by the server. Returns: - dict: The server response containing model details. + ModelRepository: Object containing the model name and server-assigned version string. Raises: - CreateModelError: If the model creation fails. + RuntimeError: If the model creation fails. """ - if not self._hub_user: - raise ValueError("User not authenticated") - url = "https://api.opengradient.ai/api/v0/models/" - headers = {"Authorization": f"Bearer {self._hub_user['idToken']}", "Content-Type": "application/json"} + headers = {"Authorization": f"Bearer {self._get_auth_token()}", "Content-Type": "application/json"} payload = {"name": model_name, "description": model_desc} try: @@ -74,14 +110,17 @@ def create_model(self, model_name: str, model_desc: str, version: str = "1.00") raise RuntimeError(f"Model creation failed: {error_details}") from e json_response = response.json() - model_name = json_response.get("name") - if not model_name: + created_name = json_response.get("name") + if not created_name: raise Exception(f"Model creation response missing 'name'. Full response: {json_response}") - # Create the specified version for the newly created model - version_response = self.create_version(model_name, version) + # Create the initial version for the newly created model. + # Pass `version` as release notes (e.g. "1.00") since the server assigns + # its own version string — previously `version` was incorrectly passed as + # the positional `notes` argument, producing the same result but confusingly. + version_response = self.create_version(created_name, notes=f"Initial version {version}") - return ModelRepository(model_name, version_response["versionString"]) + return ModelRepository(created_name, version_response["versionString"]) def create_version(self, model_name: str, notes: str = "", is_major: bool = False) -> dict: """ @@ -98,11 +137,8 @@ def create_version(self, model_name: str, notes: str = "", is_major: bool = Fals Raises: Exception: If the version creation fails. """ - if not self._hub_user: - raise ValueError("User not authenticated") - url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions" - headers = {"Authorization": f"Bearer {self._hub_user['idToken']}", "Content-Type": "application/json"} + headers = {"Authorization": f"Bearer {self._get_auth_token()}", "Content-Type": "application/json"} payload = {"notes": notes, "is_major": is_major} try: @@ -136,20 +172,16 @@ def upload(self, model_path: str, model_name: str, version: str) -> FileUploadRe version (str): The version identifier for the model. Returns: - dict: The processed result. + FileUploadResult: The processed result. Raises: RuntimeError: If the upload fails. """ - - if not self._hub_user: - raise ValueError("User not authenticated") - if not os.path.exists(model_path): raise FileNotFoundError(f"Model file not found: {model_path}") url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions/{version}/files" - headers = {"Authorization": f"Bearer {self._hub_user['idToken']}"} + headers = {"Authorization": f"Bearer {self._get_auth_token()}"} try: with open(model_path, "rb") as file: @@ -191,11 +223,8 @@ def list_files(self, model_name: str, version: str) -> List[Dict]: Raises: RuntimeError: If the file listing fails. """ - if not self._hub_user: - raise ValueError("User not authenticated") - url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions/{version}/files" - headers = {"Authorization": f"Bearer {self._hub_user['idToken']}"} + headers = {"Authorization": f"Bearer {self._get_auth_token()}"} try: response = requests.get(url, headers=headers) diff --git a/src/opengradient/client/opg_token.py b/src/opengradient/client/opg_token.py index d86d9de..22363b5 100644 --- a/src/opengradient/client/opg_token.py +++ b/src/opengradient/client/opg_token.py @@ -82,8 +82,10 @@ def ensure_opg_approval(wallet_account: LocalAccount, opg_amount: float) -> Perm allowance_before = token.functions.allowance(owner, spender).call() - # Only approve if the allowance is less than 10% of the requested amount - if allowance_before >= amount_base * 0.1: + # Only skip approval if the existing allowance fully covers the requested amount. + # Previously this used 0.1 * amount_base (10%), which was insufficient and caused + # downstream x402 payment failures when the allowance was between 10% and 100%. + if allowance_before >= amount_base: return Permit2ApprovalResult( allowance_before=allowance_before, allowance_after=allowance_before, @@ -124,7 +126,6 @@ def ensure_opg_approval(wallet_account: LocalAccount, opg_amount: float) -> Perm ) time.sleep(ALLOWANCE_POLL_INTERVAL) - return Permit2ApprovalResult( allowance_before=allowance_before, allowance_after=allowance_after,