Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 89 additions & 23 deletions src/opengradient/client/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import logging
import ssl
from dataclasses import dataclass
from typing import AsyncGenerator, Dict, List, Optional, Union
from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union
import httpx

from eth_account import Account
from eth_account.account import LocalAccount
Expand All @@ -19,6 +20,7 @@
from .tee_registry import TEERegistry, build_ssl_context_from_der

logger = logging.getLogger(__name__)
T = TypeVar("T")

DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai"
DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248"
Expand Down Expand Up @@ -94,32 +96,44 @@ def __init__(
llm_server_url: Optional[str] = None,
):
self._wallet_account: LocalAccount = Account.from_key(private_key)
self._rpc_url = rpc_url
self._tee_registry_address = tee_registry_address
self._llm_server_url = llm_server_url

# x402 payment stack (created once, reused across TEE refreshes)
signer = EthAccountSigner(self._wallet_account)
self._x402_client = x402Client()
register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])
register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])

self._connect_tee()

# ── TEE resolution and connection ───────────────────────────────────────────

def _connect_tee(self) -> None:
"""Resolve TEE from registry and create a secure HTTP client for it."""
endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee(
llm_server_url,
rpc_url,
tee_registry_address,
self._llm_server_url,
self._rpc_url,
self._tee_registry_address,
)

self._tee_id = tee_id
self._tee_endpoint = endpoint
self._tee_payment_address = tee_payment_address

ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None
# When connecting directly via llm_server_url, skip cert verification —
# self-hosted TEE servers commonly use self-signed certificates.
verify_ssl = llm_server_url is None
self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else verify_ssl

# x402 client and signer
signer = EthAccountSigner(self._wallet_account)
self._x402_client = x402Client()
register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])
register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])
# httpx.AsyncClient subclass - construction is sync, connections open lazily
self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None)
self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify)

# ── TEE resolution ──────────────────────────────────────────────────
async def _refresh_tee(self) -> None:
"""Re-resolve TEE from the registry and rebuild the HTTP client."""
old_http_client = self._http_client
self._connect_tee()
try:
await old_http_client.aclose()
except Exception:
logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True)


@staticmethod
def _resolve_tee(
Expand Down Expand Up @@ -188,6 +202,29 @@ def _tee_metadata(self) -> Dict:
tee_payment_address=self._tee_payment_address,
)

async def _call_with_tee_retry(
self,
operation_name: str,
call: Callable[[], Awaitable[T]],
) -> T:
"""Execute *call*; on connection failure, pick a new TEE and retry once.

Only retries when the request never reached the server (no HTTP response).
Server-side errors (4xx/5xx) are not retried.
"""
try:
return await call()
except httpx.HTTPStatusError:
raise
except Exception as exc:
logger.warning(
"Connection failure during %s; refreshing TEE and retrying once: %s",
operation_name,
exc,
)
await self._refresh_tee()
return await call()

# ── Public API ──────────────────────────────────────────────────────

def ensure_opg_approval(self, opg_amount: float) -> Permit2ApprovalResult:
Expand Down Expand Up @@ -248,7 +285,6 @@ async def completion(
RuntimeError: If the inference fails.
"""
model_id = model.split("/")[1]
headers = self._headers(x402_settlement_mode)
payload: Dict = {
"model": model_id,
"prompt": prompt,
Expand All @@ -258,11 +294,11 @@ async def completion(
if stop_sequence:
payload["stop"] = stop_sequence

try:
async def _request() -> TextGenerationOutput:
response = await self._http_client.post(
self._tee_endpoint + _COMPLETION_ENDPOINT,
json=payload,
headers=headers,
headers=self._headers(x402_settlement_mode),
timeout=_REQUEST_TIMEOUT,
)
response.raise_for_status()
Expand All @@ -275,6 +311,9 @@ async def completion(
tee_timestamp=result.get("tee_timestamp"),
**self._tee_metadata(),
)

try:
return await self._call_with_tee_retry("completion", _request)
except RuntimeError:
raise
except Exception as e:
Expand Down Expand Up @@ -342,14 +381,13 @@ async def chat(

async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> TextGenerationOutput:
"""Non-streaming chat request."""
headers = self._headers(params.x402_settlement_mode)
payload = self._chat_payload(params, messages)

try:
async def _request() -> TextGenerationOutput:
response = await self._http_client.post(
self._tee_endpoint + _CHAT_ENDPOINT,
json=payload,
headers=headers,
headers=self._headers(params.x402_settlement_mode),
timeout=_REQUEST_TIMEOUT,
)
response.raise_for_status()
Expand All @@ -375,6 +413,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text
tee_timestamp=result.get("tee_timestamp"),
**self._tee_metadata(),
)

try:
return await self._call_with_tee_retry("chat", _request)
except RuntimeError:
raise
except Exception as e:
Expand Down Expand Up @@ -410,6 +451,31 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async
headers = self._headers(params.x402_settlement_mode)
payload = self._chat_payload(params, messages, stream=True)

chunks_yielded = False
try:
async with self._http_client.stream(
"POST",
self._tee_endpoint + _CHAT_ENDPOINT,
json=payload,
headers=headers,
timeout=_REQUEST_TIMEOUT,
) as response:
async for chunk in self._parse_sse_response(response):
chunks_yielded = True
yield chunk
return
except httpx.HTTPStatusError:
raise
except Exception as exc:
if chunks_yielded:
raise
logger.warning(
"Connection failure during stream setup; refreshing TEE and retrying once: %s",
exc,
)

await self._refresh_tee()
headers = self._headers(params.x402_settlement_mode)
async with self._http_client.stream(
"POST",
self._tee_endpoint + _CHAT_ENDPOINT,
Expand Down
Loading
Loading