Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 57 additions & 28 deletions src/opengradient/client/model_hub.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Model Hub for creating, versioning, and uploading ML models."""

import os
import time
from typing import Dict, List, Optional

import firebase # type: ignore[import-untyped]
Expand All @@ -19,6 +20,9 @@
"databaseURL": os.getenv("FIREBASE_DATABASE_URL", ""),
}

# Firebase idTokens expire after 3600 seconds; refresh 60 s before expiry
_TOKEN_REFRESH_MARGIN_SEC = 60


class ModelHub:
"""
Expand All @@ -34,15 +38,49 @@ class ModelHub:
"""

def __init__(self, email: Optional[str] = None, password: Optional[str] = None):
self._hub_user = self._login(email, password) if email is not None else None
self._firebase_app = None
self._hub_user = None
self._token_expiry: float = 0.0

if email is not None:
self._firebase_app, self._hub_user = self._login(email, password)
expires_in = int(self._hub_user.get("expiresIn", 3600))
self._token_expiry = time.time() + expires_in

@staticmethod
def _login(email: str, password: Optional[str]):
if not _FIREBASE_CONFIG.get("apiKey"):
raise ValueError("Firebase API Key is missing in environment variables")

firebase_app = firebase.initialize_app(_FIREBASE_CONFIG)
return firebase_app.auth().sign_in_with_email_and_password(email, password)
user = firebase_app.auth().sign_in_with_email_and_password(email, password)
return firebase_app, user

def _get_auth_token(self) -> str:
"""Return a valid Firebase idToken, refreshing it if it has expired or is
about to expire within ``_TOKEN_REFRESH_MARGIN_SEC`` seconds.

Raises:
ValueError: If the user is not authenticated.
"""
if not self._hub_user:
raise ValueError("User not authenticated")

if time.time() >= self._token_expiry - _TOKEN_REFRESH_MARGIN_SEC:
# Refresh the token using the stored refresh token
refresh_token = self._hub_user.get("refreshToken")
if not refresh_token or self._firebase_app is None:
raise ValueError(
"Cannot refresh Firebase token: missing refresh token or Firebase app. "
"Please re-authenticate by creating a new ModelHub instance."
)
refreshed = self._firebase_app.auth().refresh(refresh_token)
self._hub_user["idToken"] = refreshed["idToken"]
self._hub_user["refreshToken"] = refreshed.get("refreshToken", refresh_token)
expires_in = int(refreshed.get("expiresIn", 3600))
self._token_expiry = time.time() + expires_in

return str(self._hub_user["idToken"]) # cast Any->str for mypy [no-any-return]

def create_model(self, model_name: str, model_desc: str, version: str = "1.00") -> ModelRepository:
"""
Expand All @@ -51,19 +89,17 @@ def create_model(self, model_name: str, model_desc: str, version: str = "1.00")
Args:
model_name (str): The name of the model.
model_desc (str): The description of the model.
version (str): The version identifier (default is "1.00").
version (str): A label used in the initial version notes (default is "1.00").
Note: the actual version string is assigned by the server.

Returns:
dict: The server response containing model details.
ModelRepository: Object containing the model name and server-assigned version string.

Raises:
CreateModelError: If the model creation fails.
RuntimeError: If the model creation fails.
"""
if not self._hub_user:
raise ValueError("User not authenticated")

url = "https://api.opengradient.ai/api/v0/models/"
headers = {"Authorization": f"Bearer {self._hub_user['idToken']}", "Content-Type": "application/json"}
headers = {"Authorization": f"Bearer {self._get_auth_token()}", "Content-Type": "application/json"}
payload = {"name": model_name, "description": model_desc}

try:
Expand All @@ -74,14 +110,17 @@ def create_model(self, model_name: str, model_desc: str, version: str = "1.00")
raise RuntimeError(f"Model creation failed: {error_details}") from e

json_response = response.json()
model_name = json_response.get("name")
if not model_name:
created_name = json_response.get("name")
if not created_name:
raise Exception(f"Model creation response missing 'name'. Full response: {json_response}")

# Create the specified version for the newly created model
version_response = self.create_version(model_name, version)
# Create the initial version for the newly created model.
# Pass `version` as release notes (e.g. "1.00") since the server assigns
# its own version string — previously `version` was incorrectly passed as
# the positional `notes` argument, producing the same result but confusingly.
version_response = self.create_version(created_name, notes=f"Initial version {version}")

return ModelRepository(model_name, version_response["versionString"])
return ModelRepository(created_name, version_response["versionString"])

def create_version(self, model_name: str, notes: str = "", is_major: bool = False) -> dict:
"""
Expand All @@ -98,11 +137,8 @@ def create_version(self, model_name: str, notes: str = "", is_major: bool = Fals
Raises:
Exception: If the version creation fails.
"""
if not self._hub_user:
raise ValueError("User not authenticated")

url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions"
headers = {"Authorization": f"Bearer {self._hub_user['idToken']}", "Content-Type": "application/json"}
headers = {"Authorization": f"Bearer {self._get_auth_token()}", "Content-Type": "application/json"}
payload = {"notes": notes, "is_major": is_major}

try:
Expand Down Expand Up @@ -136,20 +172,16 @@ def upload(self, model_path: str, model_name: str, version: str) -> FileUploadRe
version (str): The version identifier for the model.

Returns:
dict: The processed result.
FileUploadResult: The processed result.

Raises:
RuntimeError: If the upload fails.
"""

if not self._hub_user:
raise ValueError("User not authenticated")

if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found: {model_path}")

url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions/{version}/files"
headers = {"Authorization": f"Bearer {self._hub_user['idToken']}"}
headers = {"Authorization": f"Bearer {self._get_auth_token()}"}

try:
with open(model_path, "rb") as file:
Expand Down Expand Up @@ -191,11 +223,8 @@ def list_files(self, model_name: str, version: str) -> List[Dict]:
Raises:
RuntimeError: If the file listing fails.
"""
if not self._hub_user:
raise ValueError("User not authenticated")

url = f"https://api.opengradient.ai/api/v0/models/{model_name}/versions/{version}/files"
headers = {"Authorization": f"Bearer {self._hub_user['idToken']}"}
headers = {"Authorization": f"Bearer {self._get_auth_token()}"}

try:
response = requests.get(url, headers=headers)
Expand Down
7 changes: 4 additions & 3 deletions src/opengradient/client/opg_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,10 @@ def ensure_opg_approval(wallet_account: LocalAccount, opg_amount: float) -> Perm

allowance_before = token.functions.allowance(owner, spender).call()

# Only approve if the allowance is less than 10% of the requested amount
if allowance_before >= amount_base * 0.1:
# Only skip approval if the existing allowance fully covers the requested amount.
# Previously this used 0.1 * amount_base (10%), which was insufficient and caused
# downstream x402 payment failures when the allowance was between 10% and 100%.
if allowance_before >= amount_base:
return Permit2ApprovalResult(
allowance_before=allowance_before,
allowance_after=allowance_before,
Expand Down Expand Up @@ -124,7 +126,6 @@ def ensure_opg_approval(wallet_account: LocalAccount, opg_amount: float) -> Perm
)
time.sleep(ALLOWANCE_POLL_INTERVAL)


return Permit2ApprovalResult(
allowance_before=allowance_before,
allowance_after=allowance_after,
Expand Down
Loading