From f680d728cc59a3ed9496505702bb548438390c92 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Wed, 25 Mar 2026 19:39:43 +0530 Subject: [PATCH 1/9] cleanup --- examples/llm_chat.py | 2 +- examples/llm_chat_streaming.py | 2 +- src/opengradient/client/llm.py | 118 +++++++++++++++++++++++++++------ 3 files changed, 98 insertions(+), 24 deletions(-) diff --git a/examples/llm_chat.py b/examples/llm_chat.py index a2c599cb..15c4d6f4 100644 --- a/examples/llm_chat.py +++ b/examples/llm_chat.py @@ -10,7 +10,7 @@ async def main(): llm = og.LLM(private_key=os.environ.get("OG_PRIVATE_KEY")) - llm.ensure_opg_approval(opg_amount=0.1) + llm.ensure_opg_approval(opg_amount=1) messages = [ {"role": "user", "content": "What is the capital of France?"}, diff --git a/examples/llm_chat_streaming.py b/examples/llm_chat_streaming.py index 17db4774..8e34d550 100644 --- a/examples/llm_chat_streaming.py +++ b/examples/llm_chat_streaming.py @@ -6,7 +6,7 @@ async def main(): llm = og.LLM(private_key=os.environ.get("OG_PRIVATE_KEY")) - llm.ensure_opg_approval(opg_amount=0.1) + llm.ensure_opg_approval(opg_amount=1) messages = [ {"role": "user", "content": "What is Python?"}, diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index eecfa148..513f428d 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -3,8 +3,9 @@ import json import logging import ssl +import threading from dataclasses import dataclass -from typing import AsyncGenerator, Dict, List, Optional, Union +from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union from eth_account import Account from eth_account.account import LocalAccount @@ -19,6 +20,7 @@ from .tee_registry import TEERegistry, build_ssl_context_from_der logger = logging.getLogger(__name__) +T = TypeVar("T") DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai" DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248" @@ -94,31 +96,76 @@ def __init__( llm_server_url: Optional[str] = None, ): self._wallet_account: LocalAccount = Account.from_key(private_key) - - endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee( - llm_server_url, - rpc_url, - tee_registry_address, + self._rpc_url = rpc_url + self._tee_registry_address = tee_registry_address + self._llm_server_url = llm_server_url + self._reset_lock = threading.Lock() + + self._refresh_tee_config() + self._init_x402_stack() + + def _refresh_tee_config(self) -> None: + """Resolve TEE metadata from the registry and update TLS config.""" + endpoint, tls_cert_der, tee_id, payment_addr = self._resolve_tee( + self._llm_server_url, self._rpc_url, self._tee_registry_address, ) - + ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None self._tee_id = tee_id self._tee_endpoint = endpoint - self._tee_payment_address = tee_payment_address - - ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None - # When connecting directly via llm_server_url, skip cert verification — - # self-hosted TEE servers commonly use self-signed certificates. - verify_ssl = llm_server_url is None - self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else verify_ssl + self._tee_payment_address = payment_addr + self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None) - # x402 client and signer + def _init_x402_stack(self) -> None: + """Initialize x402 signer/client/http stack.""" signer = EthAccountSigner(self._wallet_account) self._x402_client = x402Client() register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - # httpx.AsyncClient subclass - construction is sync, connections open lazily self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) + @staticmethod + def _has_ssl_cause(exc: BaseException) -> bool: + """Return true when the exception chain contains an SSL error.""" + visited: set[int] = set() + current: Optional[BaseException] = exc + while current is not None and id(current) not in visited: + visited.add(id(current)) + if isinstance(current, ssl.SSLError): + return True + current = current.__cause__ or current.__context__ + return False + + async def _refresh_tee_and_reset(self) -> None: + """Re-resolve TEE and rebuild the HTTP client with fresh TLS config.""" + with self._reset_lock: + old_http_client = self._http_client + self._refresh_tee_config() + self._init_x402_stack() + + try: + await old_http_client.aclose() + except Exception: + logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) + + async def _call_with_ssl_retry( + self, + operation_name: str, + call: Callable[[], Awaitable[T]], + ) -> T: + """Retry once with fresh TEE/TLS state when the failure is SSL-related.""" + try: + return await call() + except Exception as exc: + if not self._has_ssl_cause(exc): + raise + logger.warning( + "SSL failure during %s; refreshing TEE and retrying once: %s", + operation_name, + exc, + ) + await self._refresh_tee_and_reset() + return await call() + # ── TEE resolution ────────────────────────────────────────────────── @staticmethod @@ -248,7 +295,6 @@ async def completion( RuntimeError: If the inference fails. """ model_id = model.split("/")[1] - headers = self._headers(x402_settlement_mode) payload: Dict = { "model": model_id, "prompt": prompt, @@ -258,11 +304,11 @@ async def completion( if stop_sequence: payload["stop"] = stop_sequence - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _COMPLETION_ENDPOINT, json=payload, - headers=headers, + headers=self._headers(x402_settlement_mode), timeout=_REQUEST_TIMEOUT, ) response.raise_for_status() @@ -275,6 +321,9 @@ async def completion( tee_timestamp=result.get("tee_timestamp"), **self._tee_metadata(), ) + + try: + return await self._call_with_ssl_retry("completion", _request) except RuntimeError: raise except Exception as e: @@ -342,14 +391,13 @@ async def chat( async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> TextGenerationOutput: """Non-streaming chat request.""" - headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages) - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _CHAT_ENDPOINT, json=payload, - headers=headers, + headers=self._headers(params.x402_settlement_mode), timeout=_REQUEST_TIMEOUT, ) response.raise_for_status() @@ -375,6 +423,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text tee_timestamp=result.get("tee_timestamp"), **self._tee_metadata(), ) + + try: + return await self._call_with_ssl_retry("chat", _request) except RuntimeError: raise except Exception as e: @@ -410,6 +461,29 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages, stream=True) + chunks_yielded = False + try: + async with self._http_client.stream( + "POST", + self._tee_endpoint + _CHAT_ENDPOINT, + json=payload, + headers=headers, + timeout=_REQUEST_TIMEOUT, + ) as response: + async for chunk in self._parse_sse_response(response): + chunks_yielded = True + yield chunk + return + except Exception as exc: + if chunks_yielded or not self._has_ssl_cause(exc): + raise + logger.warning( + "SSL failure during stream setup; refreshing TEE and retrying once: %s", + exc, + ) + + await self._refresh_tee_and_reset() + headers = self._headers(params.x402_settlement_mode) async with self._http_client.stream( "POST", self._tee_endpoint + _CHAT_ENDPOINT, From 5e1e10a59ff0013c6035f286319941331dd463ef Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Wed, 25 Mar 2026 20:55:21 +0530 Subject: [PATCH 2/9] tests --- tests/llm_test.py | 264 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 263 insertions(+), 1 deletion(-) diff --git a/tests/llm_test.py b/tests/llm_test.py index bb845a75..06a536c1 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -5,9 +5,10 @@ """ import json +import ssl from contextlib import asynccontextmanager from typing import List -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -31,6 +32,8 @@ def __init__(self, *_args, **_kwargs): self._response_body: bytes = b"{}" self._post_calls: List[dict] = [] self._stream_response = None + self._error_on_next: BaseException | None = None + self._stream_error_on_next: BaseException | None = None def set_response(self, status_code: int, body: dict) -> None: self._response_status = status_code @@ -43,8 +46,19 @@ def set_stream_response(self, status_code: int, chunks: List[bytes]) -> None: def post_calls(self) -> List[dict]: return self._post_calls + def fail_next_post(self, exc: BaseException) -> None: + """Make the next post() call raise *exc*, then revert to normal.""" + self._error_on_next = exc + + def fail_next_stream(self, exc: BaseException) -> None: + """Make the next stream() call raise *exc*, then revert to normal.""" + self._stream_error_on_next = exc + async def post(self, url: str, *, json=None, headers=None, timeout=None) -> "_FakeResponse": self._post_calls.append({"url": url, "json": json, "headers": headers, "timeout": timeout}) + if self._error_on_next is not None: + exc, self._error_on_next = self._error_on_next, None + raise exc resp = _FakeResponse(self._response_status, self._response_body) if self._response_status >= 400: resp.raise_for_status = MagicMock(side_effect=httpx.HTTPStatusError("error", request=MagicMock(), response=MagicMock())) @@ -53,6 +67,9 @@ async def post(self, url: str, *, json=None, headers=None, timeout=None) -> "_Fa @asynccontextmanager async def stream(self, method: str, url: str, *, json=None, headers=None, timeout=None): self._post_calls.append({"method": method, "url": url, "json": json, "headers": headers, "timeout": timeout}) + if self._stream_error_on_next is not None: + exc, self._stream_error_on_next = self._stream_error_on_next, None + raise exc yield self._stream_response async def aclose(self): @@ -535,3 +552,248 @@ def test_registry_success(self): assert cert == b"cert-bytes" assert tee_id == "tee-42" assert pay_addr == "0xPay" + + +# ── SSL cause detection tests ──────────────────────────────────────── + + +class TestHasSSLCause: + def test_direct_ssl_error(self): + assert LLM._has_ssl_cause(ssl.SSLError("cert expired")) is True + + def test_wrapped_ssl_error_via_cause(self): + root = ssl.SSLError("handshake failed") + wrapper = RuntimeError("connection error") + wrapper.__cause__ = root + assert LLM._has_ssl_cause(wrapper) is True + + def test_wrapped_ssl_error_via_context(self): + root = ssl.SSLError("cert verify failed") + wrapper = OSError("transport error") + wrapper.__context__ = root + assert LLM._has_ssl_cause(wrapper) is True + + def test_deeply_nested_ssl(self): + root = ssl.SSLError("deep") + mid = OSError("mid") + mid.__cause__ = root + top = RuntimeError("top") + top.__cause__ = mid + assert LLM._has_ssl_cause(top) is True + + def test_non_ssl_error(self): + assert LLM._has_ssl_cause(ValueError("not ssl")) is False + + def test_non_ssl_chain(self): + root = TimeoutError("timed out") + wrapper = RuntimeError("oops") + wrapper.__cause__ = root + assert LLM._has_ssl_cause(wrapper) is False + + def test_cycle_detection(self): + """Self-referencing chain should not loop forever.""" + exc = RuntimeError("cycle") + exc.__cause__ = exc + assert LLM._has_ssl_cause(exc) is False + + +# ── SSL retry tests (non-streaming) ────────────────────────────────── + + +def _make_ssl_error(msg: str = "certificate verify failed") -> Exception: + """Create a RuntimeError wrapping an SSLError, mimicking httpx behaviour.""" + root = ssl.SSLError(msg) + wrapper = RuntimeError(f"connection failed: {msg}") + wrapper.__cause__ = root + return wrapper + + +@pytest.mark.asyncio +class TestSSLRetryCompletion: + async def test_retries_on_ssl_and_succeeds(self, fake_http): + """First call hits SSL error → refresh → second call succeeds.""" + fake_http.set_response(200, {"completion": "retried ok", "tee_signature": "s", "tee_timestamp": "t"}) + fake_http.fail_next_post(_make_ssl_error()) + llm = _make_llm() + + result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + + assert result.completion_output == "retried ok" + # Two post calls: the failed one + the retry + assert len(fake_http.post_calls) == 2 + + async def test_non_ssl_error_not_retried(self, fake_http): + """A non-SSL error should propagate immediately, no retry.""" + fake_http.fail_next_post(ConnectionError("DNS failed")) + llm = _make_llm() + + with pytest.raises(RuntimeError, match="TEE LLM completion failed"): + await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + assert len(fake_http.post_calls) == 1 + + async def test_second_ssl_failure_propagates(self, fake_http): + """If the retry also hits SSL, the error should propagate.""" + fake_http.set_response(200, {"completion": "ok"}) + + call_count = 0 + + async def always_ssl(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise _make_ssl_error() + + fake_http.post = always_ssl + llm = _make_llm() + + # The second SSL error bubbles out of _call_with_ssl_retry as a + # RuntimeError (our wrapper), then caught by completion()'s outer + # handler which re-wraps it. + with pytest.raises(RuntimeError): + await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + assert call_count == 2 # original + retry + + +@pytest.mark.asyncio +class TestSSLRetryChat: + async def test_retries_on_ssl_and_succeeds(self, fake_http): + fake_http.set_response( + 200, + {"choices": [{"message": {"role": "assistant", "content": "retry ok"}, "finish_reason": "stop"}]}, + ) + fake_http.fail_next_post(_make_ssl_error()) + llm = _make_llm() + + result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + assert result.chat_output["content"] == "retry ok" + assert len(fake_http.post_calls) == 2 + + async def test_non_ssl_error_not_retried(self, fake_http): + fake_http.fail_next_post(TimeoutError("timed out")) + llm = _make_llm() + + with pytest.raises(RuntimeError, match="TEE LLM chat failed"): + await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + assert len(fake_http.post_calls) == 1 + + +# ── SSL retry tests (streaming) ────────────────────────────────────── + + +@pytest.mark.asyncio +class TestSSLRetryStreaming: + async def test_retries_stream_on_ssl_before_chunks(self, fake_http): + """SSL failure during stream setup (no chunks yielded) → retry succeeds.""" + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + ) + fake_http.fail_next_stream(_make_ssl_error()) + llm = _make_llm() + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [c async for c in gen] + + assert len(chunks) == 1 + assert chunks[0].choices[0].delta.content == "ok" + # Two stream attempts: failed + retry + assert len(fake_http.post_calls) == 2 + + async def test_no_retry_after_chunks_yielded(self, fake_http): + """SSL failure AFTER chunks were yielded must raise, not retry.""" + + class _FailMidStream: + """Yields one chunk then raises SSL.""" + + def __init__(self): + self.status_code = 200 + + async def aiter_raw(self): + yield b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"partial"},"finish_reason":null}]}\n\n' + raise _make_ssl_error() + + async def aread(self) -> bytes: + return b"" + + fake_http._stream_response = _FailMidStream() + llm = _make_llm() + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + + with pytest.raises(RuntimeError): + _ = [c async for c in gen] + + # Only one stream call — no retry after partial output + assert len(fake_http.post_calls) == 1 + + async def test_non_ssl_stream_error_not_retried(self, fake_http): + fake_http.fail_next_stream(ConnectionError("reset")) + llm = _make_llm() + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + + with pytest.raises(ConnectionError): + _ = [c async for c in gen] + assert len(fake_http.post_calls) == 1 + + +# ── _refresh_tee_and_reset tests ───────────────────────────────────── + + +@pytest.mark.asyncio +class TestRefreshTeeAndReset: + async def test_replaces_http_client(self): + """After refresh, the http client should be a new instance.""" + clients_created = [] + + def make_client(*args, **kwargs): + c = FakeHTTPClient() + clients_created.append(c) + return c + + with ( + patch(_PATCHES["x402_httpx"], side_effect=make_client), + patch(_PATCHES["x402_client"]), + patch(_PATCHES["signer"]), + patch(_PATCHES["register_exact"]), + patch(_PATCHES["register_upto"]), + ): + llm = _make_llm() + old_client = llm._http_client + + await llm._refresh_tee_and_reset() + + assert llm._http_client is not old_client + assert len(clients_created) == 2 # init + refresh + + async def test_closes_old_client(self, fake_http): + llm = _make_llm() + old_client = llm._http_client + old_client.aclose = AsyncMock() + + await llm._refresh_tee_and_reset() + + old_client.aclose.assert_awaited_once() + + async def test_close_failure_is_swallowed(self, fake_http): + llm = _make_llm() + old_client = llm._http_client + old_client.aclose = AsyncMock(side_effect=OSError("already closed")) + + # Should not raise + await llm._refresh_tee_and_reset() From 1030563265a86c822b19aff210ea2aa5ece1fcb9 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Wed, 25 Mar 2026 22:00:37 +0530 Subject: [PATCH 3/9] cleanup --- src/opengradient/client/llm.py | 12 +++++++++--- tests/llm_test.py | 20 ++++++++++++++++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 513f428d..e2097c75 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -126,13 +126,19 @@ def _init_x402_stack(self) -> None: @staticmethod def _has_ssl_cause(exc: BaseException) -> bool: """Return true when the exception chain contains an SSL error.""" + stack: list[BaseException] = [exc] visited: set[int] = set() - current: Optional[BaseException] = exc - while current is not None and id(current) not in visited: + while stack: + current = stack.pop() + if id(current) in visited: + continue visited.add(id(current)) if isinstance(current, ssl.SSLError): return True - current = current.__cause__ or current.__context__ + if current.__cause__ is not None: + stack.append(current.__cause__) + if current.__context__ is not None: + stack.append(current.__context__) return False async def _refresh_tee_and_reset(self) -> None: diff --git a/tests/llm_test.py b/tests/llm_test.py index 06a536c1..c8eb3df6 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -590,6 +590,20 @@ def test_non_ssl_chain(self): wrapper.__cause__ = root assert LLM._has_ssl_cause(wrapper) is False + def test_ssl_on_context_when_cause_is_dead_end(self): + """SSL sits on __context__ while __cause__ leads to a non-SSL chain.""" + ssl_root = ssl.SSLError("cert expired") + context_mid = OSError("transport") + context_mid.__cause__ = ssl_root # __context__ branch has SSL + + cause_dead_end = ValueError("unrelated") # __cause__ branch, no SSL + + top = RuntimeError("top") + top.__cause__ = cause_dead_end + top.__context__ = context_mid + + assert LLM._has_ssl_cause(top) is True + def test_cycle_detection(self): """Self-referencing chain should not loop forever.""" exc = RuntimeError("cycle") @@ -645,10 +659,8 @@ async def always_ssl(*args, **kwargs): fake_http.post = always_ssl llm = _make_llm() - # The second SSL error bubbles out of _call_with_ssl_retry as a - # RuntimeError (our wrapper), then caught by completion()'s outer - # handler which re-wraps it. - with pytest.raises(RuntimeError): + # The retry's RuntimeError is re-raised as-is (no double-wrapping). + with pytest.raises(RuntimeError, match="connection failed"): await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") assert call_count == 2 # original + retry From 451ee6970df947c56ae04fa9382954ff5ce2d0ac Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Wed, 25 Mar 2026 22:14:17 +0530 Subject: [PATCH 4/9] cleanup --- src/opengradient/client/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index e2097c75..5d81cad5 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -146,7 +146,7 @@ async def _refresh_tee_and_reset(self) -> None: with self._reset_lock: old_http_client = self._http_client self._refresh_tee_config() - self._init_x402_stack() + self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) try: await old_http_client.aclose() From e6b4491347c9df9d9bfcbd2216b04c6f02b6a3d0 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Wed, 25 Mar 2026 23:41:24 +0530 Subject: [PATCH 5/9] cleanup --- src/opengradient/client/llm.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 5d81cad5..da3172b9 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -3,7 +3,6 @@ import json import logging import ssl -import threading from dataclasses import dataclass from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union @@ -99,8 +98,6 @@ def __init__( self._rpc_url = rpc_url self._tee_registry_address = tee_registry_address self._llm_server_url = llm_server_url - self._reset_lock = threading.Lock() - self._refresh_tee_config() self._init_x402_stack() @@ -143,10 +140,9 @@ def _has_ssl_cause(exc: BaseException) -> bool: async def _refresh_tee_and_reset(self) -> None: """Re-resolve TEE and rebuild the HTTP client with fresh TLS config.""" - with self._reset_lock: - old_http_client = self._http_client - self._refresh_tee_config() - self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) + old_http_client = self._http_client + self._refresh_tee_config() + self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) try: await old_http_client.aclose() From a81269efdbaadea4a12e02c1d96fd0c7c24b74da Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Wed, 25 Mar 2026 21:10:17 -0400 Subject: [PATCH 6/9] lower limits --- examples/llm_chat.py | 2 +- examples/llm_chat_streaming.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/llm_chat.py b/examples/llm_chat.py index 15c4d6f4..a2c599cb 100644 --- a/examples/llm_chat.py +++ b/examples/llm_chat.py @@ -10,7 +10,7 @@ async def main(): llm = og.LLM(private_key=os.environ.get("OG_PRIVATE_KEY")) - llm.ensure_opg_approval(opg_amount=1) + llm.ensure_opg_approval(opg_amount=0.1) messages = [ {"role": "user", "content": "What is the capital of France?"}, diff --git a/examples/llm_chat_streaming.py b/examples/llm_chat_streaming.py index 8e34d550..17db4774 100644 --- a/examples/llm_chat_streaming.py +++ b/examples/llm_chat_streaming.py @@ -6,7 +6,7 @@ async def main(): llm = og.LLM(private_key=os.environ.get("OG_PRIVATE_KEY")) - llm.ensure_opg_approval(opg_amount=1) + llm.ensure_opg_approval(opg_amount=0.1) messages = [ {"role": "user", "content": "What is Python?"}, From 3bf9bbf0776cf50d0527a2266301027805453006 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Wed, 25 Mar 2026 21:29:54 -0400 Subject: [PATCH 7/9] test --- src/opengradient/client/llm.py | 90 +++++++------- tests/llm_test.py | 215 +++++++++++++++------------------ 2 files changed, 137 insertions(+), 168 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index da3172b9..9fa0aef6 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -6,6 +6,10 @@ from dataclasses import dataclass from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union +import httpx + +import httpx + from eth_account import Account from eth_account.account import LocalAccount from x402 import x402Client @@ -98,74 +102,60 @@ def __init__( self._rpc_url = rpc_url self._tee_registry_address = tee_registry_address self._llm_server_url = llm_server_url - self._refresh_tee_config() - self._init_x402_stack() - def _refresh_tee_config(self) -> None: - """Resolve TEE metadata from the registry and update TLS config.""" - endpoint, tls_cert_der, tee_id, payment_addr = self._resolve_tee( - self._llm_server_url, self._rpc_url, self._tee_registry_address, - ) - ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None - self._tee_id = tee_id - self._tee_endpoint = endpoint - self._tee_payment_address = payment_addr - self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None) - - def _init_x402_stack(self) -> None: - """Initialize x402 signer/client/http stack.""" + # x402 payment stack (created once, reused across TEE refreshes) signer = EthAccountSigner(self._wallet_account) self._x402_client = x402Client() register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) - @staticmethod - def _has_ssl_cause(exc: BaseException) -> bool: - """Return true when the exception chain contains an SSL error.""" - stack: list[BaseException] = [exc] - visited: set[int] = set() - while stack: - current = stack.pop() - if id(current) in visited: - continue - visited.add(id(current)) - if isinstance(current, ssl.SSLError): - return True - if current.__cause__ is not None: - stack.append(current.__cause__) - if current.__context__ is not None: - stack.append(current.__context__) - return False - - async def _refresh_tee_and_reset(self) -> None: - """Re-resolve TEE and rebuild the HTTP client with fresh TLS config.""" - old_http_client = self._http_client - self._refresh_tee_config() + self._connect_tee() + + def _connect_tee(self) -> None: + """Resolve TEE from registry and create a secure HTTP client for it.""" + endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee( + self._llm_server_url, + self._rpc_url, + self._tee_registry_address, + ) + self._tee_id = tee_id + self._tee_endpoint = endpoint + self._tee_payment_address = tee_payment_address + + ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None + self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None) self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) + async def _refresh_tee(self) -> None: + """Re-resolve TEE from the registry and rebuild the HTTP client.""" + old_http_client = self._http_client + self._connect_tee() try: await old_http_client.aclose() except Exception: logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) - async def _call_with_ssl_retry( + async def _call_with_tee_retry( self, operation_name: str, call: Callable[[], Awaitable[T]], ) -> T: - """Retry once with fresh TEE/TLS state when the failure is SSL-related.""" + """Execute *call*; on connection failure, pick a new TEE and retry once. + + Only retries when the request never reached the server (no HTTP response). + Server-side errors (4xx/5xx) are not retried. + """ try: return await call() + except httpx.HTTPStatusError: + raise except Exception as exc: - if not self._has_ssl_cause(exc): - raise logger.warning( - "SSL failure during %s; refreshing TEE and retrying once: %s", + "Connection failure during %s; refreshing TEE and retrying once: %s", operation_name, exc, ) - await self._refresh_tee_and_reset() + await self._refresh_tee() return await call() # ── TEE resolution ────────────────────────────────────────────────── @@ -325,7 +315,7 @@ async def _request() -> TextGenerationOutput: ) try: - return await self._call_with_ssl_retry("completion", _request) + return await self._call_with_tee_retry("completion", _request) except RuntimeError: raise except Exception as e: @@ -427,7 +417,7 @@ async def _request() -> TextGenerationOutput: ) try: - return await self._call_with_ssl_retry("chat", _request) + return await self._call_with_tee_retry("chat", _request) except RuntimeError: raise except Exception as e: @@ -476,15 +466,17 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async chunks_yielded = True yield chunk return + except httpx.HTTPStatusError: + raise except Exception as exc: - if chunks_yielded or not self._has_ssl_cause(exc): + if chunks_yielded: raise logger.warning( - "SSL failure during stream setup; refreshing TEE and retrying once: %s", + "Connection failure during stream setup; refreshing TEE and retrying once: %s", exc, ) - await self._refresh_tee_and_reset() + await self._refresh_tee() headers = self._headers(params.x402_settlement_mode) async with self._http_client.stream( "POST", diff --git a/tests/llm_test.py b/tests/llm_test.py index c8eb3df6..44ee58a8 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -554,125 +554,56 @@ def test_registry_success(self): assert pay_addr == "0xPay" -# ── SSL cause detection tests ──────────────────────────────────────── - - -class TestHasSSLCause: - def test_direct_ssl_error(self): - assert LLM._has_ssl_cause(ssl.SSLError("cert expired")) is True - - def test_wrapped_ssl_error_via_cause(self): - root = ssl.SSLError("handshake failed") - wrapper = RuntimeError("connection error") - wrapper.__cause__ = root - assert LLM._has_ssl_cause(wrapper) is True - - def test_wrapped_ssl_error_via_context(self): - root = ssl.SSLError("cert verify failed") - wrapper = OSError("transport error") - wrapper.__context__ = root - assert LLM._has_ssl_cause(wrapper) is True - - def test_deeply_nested_ssl(self): - root = ssl.SSLError("deep") - mid = OSError("mid") - mid.__cause__ = root - top = RuntimeError("top") - top.__cause__ = mid - assert LLM._has_ssl_cause(top) is True - - def test_non_ssl_error(self): - assert LLM._has_ssl_cause(ValueError("not ssl")) is False - - def test_non_ssl_chain(self): - root = TimeoutError("timed out") - wrapper = RuntimeError("oops") - wrapper.__cause__ = root - assert LLM._has_ssl_cause(wrapper) is False - - def test_ssl_on_context_when_cause_is_dead_end(self): - """SSL sits on __context__ while __cause__ leads to a non-SSL chain.""" - ssl_root = ssl.SSLError("cert expired") - context_mid = OSError("transport") - context_mid.__cause__ = ssl_root # __context__ branch has SSL - - cause_dead_end = ValueError("unrelated") # __cause__ branch, no SSL - - top = RuntimeError("top") - top.__cause__ = cause_dead_end - top.__context__ = context_mid - - assert LLM._has_ssl_cause(top) is True - - def test_cycle_detection(self): - """Self-referencing chain should not loop forever.""" - exc = RuntimeError("cycle") - exc.__cause__ = exc - assert LLM._has_ssl_cause(exc) is False - - -# ── SSL retry tests (non-streaming) ────────────────────────────────── - - -def _make_ssl_error(msg: str = "certificate verify failed") -> Exception: - """Create a RuntimeError wrapping an SSLError, mimicking httpx behaviour.""" - root = ssl.SSLError(msg) - wrapper = RuntimeError(f"connection failed: {msg}") - wrapper.__cause__ = root - return wrapper +# ── TEE retry tests (non-streaming) ────────────────────────────────── @pytest.mark.asyncio -class TestSSLRetryCompletion: - async def test_retries_on_ssl_and_succeeds(self, fake_http): - """First call hits SSL error → refresh → second call succeeds.""" +class TestTeeRetryCompletion: + async def test_retries_on_connection_error_and_succeeds(self, fake_http): + """First call hits connection error → refresh TEE → second call succeeds.""" fake_http.set_response(200, {"completion": "retried ok", "tee_signature": "s", "tee_timestamp": "t"}) - fake_http.fail_next_post(_make_ssl_error()) + fake_http.fail_next_post(ConnectionError("connection refused")) llm = _make_llm() result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") assert result.completion_output == "retried ok" - # Two post calls: the failed one + the retry assert len(fake_http.post_calls) == 2 - async def test_non_ssl_error_not_retried(self, fake_http): - """A non-SSL error should propagate immediately, no retry.""" - fake_http.fail_next_post(ConnectionError("DNS failed")) + async def test_http_status_error_not_retried(self, fake_http): + """A server-side error (HTTP 500) should not trigger a TEE retry.""" + fake_http.set_response(500, {"error": "boom"}) llm = _make_llm() with pytest.raises(RuntimeError, match="TEE LLM completion failed"): await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") assert len(fake_http.post_calls) == 1 - async def test_second_ssl_failure_propagates(self, fake_http): - """If the retry also hits SSL, the error should propagate.""" - fake_http.set_response(200, {"completion": "ok"}) - + async def test_second_failure_propagates(self, fake_http): + """If the retry also fails, the error should propagate.""" call_count = 0 - async def always_ssl(*args, **kwargs): + async def always_fail(*args, **kwargs): nonlocal call_count call_count += 1 - raise _make_ssl_error() + raise ConnectionError("still broken") - fake_http.post = always_ssl + fake_http.post = always_fail llm = _make_llm() - # The retry's RuntimeError is re-raised as-is (no double-wrapping). - with pytest.raises(RuntimeError, match="connection failed"): + with pytest.raises(RuntimeError, match="TEE LLM completion failed"): await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") - assert call_count == 2 # original + retry + assert call_count == 2 @pytest.mark.asyncio -class TestSSLRetryChat: - async def test_retries_on_ssl_and_succeeds(self, fake_http): +class TestTeeRetryChat: + async def test_retries_on_connection_error_and_succeeds(self, fake_http): fake_http.set_response( 200, {"choices": [{"message": {"role": "assistant", "content": "retry ok"}, "finish_reason": "stop"}]}, ) - fake_http.fail_next_post(_make_ssl_error()) + fake_http.fail_next_post(OSError("network unreachable")) llm = _make_llm() result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) @@ -680,8 +611,8 @@ async def test_retries_on_ssl_and_succeeds(self, fake_http): assert result.chat_output["content"] == "retry ok" assert len(fake_http.post_calls) == 2 - async def test_non_ssl_error_not_retried(self, fake_http): - fake_http.fail_next_post(TimeoutError("timed out")) + async def test_http_status_error_not_retried(self, fake_http): + fake_http.set_response(500, {"error": "internal"}) llm = _make_llm() with pytest.raises(RuntimeError, match="TEE LLM chat failed"): @@ -689,13 +620,13 @@ async def test_non_ssl_error_not_retried(self, fake_http): assert len(fake_http.post_calls) == 1 -# ── SSL retry tests (streaming) ────────────────────────────────────── +# ── TEE retry tests (streaming) ────────────────────────────────────── @pytest.mark.asyncio -class TestSSLRetryStreaming: - async def test_retries_stream_on_ssl_before_chunks(self, fake_http): - """SSL failure during stream setup (no chunks yielded) → retry succeeds.""" +class TestTeeRetryStreaming: + async def test_retries_stream_on_connection_error_before_chunks(self, fake_http): + """Connection failure during stream setup (no chunks yielded) → retry succeeds.""" fake_http.set_stream_response( 200, [ @@ -703,7 +634,7 @@ async def test_retries_stream_on_ssl_before_chunks(self, fake_http): b"data: [DONE]\n\n", ], ) - fake_http.fail_next_stream(_make_ssl_error()) + fake_http.fail_next_stream(ConnectionError("reset by peer")) llm = _make_llm() gen = await llm.chat( @@ -715,21 +646,18 @@ async def test_retries_stream_on_ssl_before_chunks(self, fake_http): assert len(chunks) == 1 assert chunks[0].choices[0].delta.content == "ok" - # Two stream attempts: failed + retry assert len(fake_http.post_calls) == 2 async def test_no_retry_after_chunks_yielded(self, fake_http): - """SSL failure AFTER chunks were yielded must raise, not retry.""" + """Failure AFTER chunks were yielded must raise, not retry.""" class _FailMidStream: - """Yields one chunk then raises SSL.""" - def __init__(self): self.status_code = 200 async def aiter_raw(self): yield b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"partial"},"finish_reason":null}]}\n\n' - raise _make_ssl_error() + raise ConnectionError("mid-stream disconnect") async def aread(self) -> bytes: return b"" @@ -743,28 +671,13 @@ async def aread(self) -> bytes: stream=True, ) - with pytest.raises(RuntimeError): - _ = [c async for c in gen] - - # Only one stream call — no retry after partial output - assert len(fake_http.post_calls) == 1 - - async def test_non_ssl_stream_error_not_retried(self, fake_http): - fake_http.fail_next_stream(ConnectionError("reset")) - llm = _make_llm() - - gen = await llm.chat( - model=TEE_LLM.GPT_5, - messages=[{"role": "user", "content": "Hi"}], - stream=True, - ) - with pytest.raises(ConnectionError): _ = [c async for c in gen] + assert len(fake_http.post_calls) == 1 -# ── _refresh_tee_and_reset tests ───────────────────────────────────── +# ── _refresh_tee tests ───────────────────────────────────── @pytest.mark.asyncio @@ -788,7 +701,7 @@ def make_client(*args, **kwargs): llm = _make_llm() old_client = llm._http_client - await llm._refresh_tee_and_reset() + await llm._refresh_tee() assert llm._http_client is not old_client assert len(clients_created) == 2 # init + refresh @@ -798,7 +711,7 @@ async def test_closes_old_client(self, fake_http): old_client = llm._http_client old_client.aclose = AsyncMock() - await llm._refresh_tee_and_reset() + await llm._refresh_tee() old_client.aclose.assert_awaited_once() @@ -808,4 +721,68 @@ async def test_close_failure_is_swallowed(self, fake_http): old_client.aclose = AsyncMock(side_effect=OSError("already closed")) # Should not raise - await llm._refresh_tee_and_reset() + await llm._refresh_tee() + + +# ── TEE cert rotation (crash + re-register) tests ──────────────────── + + +@pytest.mark.asyncio +class TestTeeCertRotation: + """Simulate a TEE crashing and a new one registering at the same IP + with a different ephemeral TLS certificate. The old cert is now + invalid, so the first request fails with SSLCertVerificationError. + The retry should re-resolve from the registry (getting the new cert) + and succeed.""" + + async def test_ssl_verification_failure_triggers_tee_refresh_completion(self, fake_http): + fake_http.set_response(200, {"completion": "ok after refresh", "tee_signature": "s", "tee_timestamp": "t"}) + fake_http.fail_next_post(ssl.SSLCertVerificationError("certificate verify failed")) + llm = _make_llm() + + with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + + # _connect_tee was called once during the retry (refresh) + spy.assert_called_once() + assert result.completion_output == "ok after refresh" + assert len(fake_http.post_calls) == 2 + + async def test_ssl_verification_failure_triggers_tee_refresh_chat(self, fake_http): + fake_http.set_response( + 200, + {"choices": [{"message": {"role": "assistant", "content": "ok after refresh"}, "finish_reason": "stop"}]}, + ) + fake_http.fail_next_post(ssl.SSLCertVerificationError("certificate verify failed")) + llm = _make_llm() + + with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + spy.assert_called_once() + assert result.chat_output["content"] == "ok after refresh" + assert len(fake_http.post_calls) == 2 + + async def test_ssl_verification_failure_triggers_tee_refresh_streaming(self, fake_http): + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + ) + fake_http.fail_next_stream(ssl.SSLCertVerificationError("certificate verify failed")) + llm = _make_llm() + + with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [c async for c in gen] + + spy.assert_called_once() + assert len(chunks) == 1 + assert chunks[0].choices[0].delta.content == "ok" + assert len(fake_http.post_calls) == 2 From 0eb13c81ee2c8d993569bf719b95727fc2b77ea2 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Wed, 25 Mar 2026 21:30:37 -0400 Subject: [PATCH 8/9] import --- src/opengradient/client/llm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 9fa0aef6..b590f8e8 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -5,9 +5,6 @@ import ssl from dataclasses import dataclass from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union - -import httpx - import httpx from eth_account import Account From 2c974ea8a4a7304104f3acb712c26f0854bed71f Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Wed, 25 Mar 2026 21:33:42 -0400 Subject: [PATCH 9/9] move --- src/opengradient/client/llm.py | 49 +++++++++++++++++----------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index b590f8e8..f21912d0 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -108,6 +108,8 @@ def __init__( self._connect_tee() + # ── TEE resolution and connection ─────────────────────────────────────────── + def _connect_tee(self) -> None: """Resolve TEE from registry and create a secure HTTP client for it.""" endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee( @@ -132,30 +134,6 @@ async def _refresh_tee(self) -> None: except Exception: logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) - async def _call_with_tee_retry( - self, - operation_name: str, - call: Callable[[], Awaitable[T]], - ) -> T: - """Execute *call*; on connection failure, pick a new TEE and retry once. - - Only retries when the request never reached the server (no HTTP response). - Server-side errors (4xx/5xx) are not retried. - """ - try: - return await call() - except httpx.HTTPStatusError: - raise - except Exception as exc: - logger.warning( - "Connection failure during %s; refreshing TEE and retrying once: %s", - operation_name, - exc, - ) - await self._refresh_tee() - return await call() - - # ── TEE resolution ────────────────────────────────────────────────── @staticmethod def _resolve_tee( @@ -224,6 +202,29 @@ def _tee_metadata(self) -> Dict: tee_payment_address=self._tee_payment_address, ) + async def _call_with_tee_retry( + self, + operation_name: str, + call: Callable[[], Awaitable[T]], + ) -> T: + """Execute *call*; on connection failure, pick a new TEE and retry once. + + Only retries when the request never reached the server (no HTTP response). + Server-side errors (4xx/5xx) are not retried. + """ + try: + return await call() + except httpx.HTTPStatusError: + raise + except Exception as exc: + logger.warning( + "Connection failure during %s; refreshing TEE and retrying once: %s", + operation_name, + exc, + ) + await self._refresh_tee() + return await call() + # ── Public API ────────────────────────────────────────────────────── def ensure_opg_approval(self, opg_amount: float) -> Permit2ApprovalResult: