diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index eecfa148..f21912d0 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -4,7 +4,8 @@ import logging import ssl from dataclasses import dataclass -from typing import AsyncGenerator, Dict, List, Optional, Union +from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union +import httpx from eth_account import Account from eth_account.account import LocalAccount @@ -19,6 +20,7 @@ from .tee_registry import TEERegistry, build_ssl_context_from_der logger = logging.getLogger(__name__) +T = TypeVar("T") DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai" DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248" @@ -94,32 +96,44 @@ def __init__( llm_server_url: Optional[str] = None, ): self._wallet_account: LocalAccount = Account.from_key(private_key) + self._rpc_url = rpc_url + self._tee_registry_address = tee_registry_address + self._llm_server_url = llm_server_url + # x402 payment stack (created once, reused across TEE refreshes) + signer = EthAccountSigner(self._wallet_account) + self._x402_client = x402Client() + register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + + self._connect_tee() + + # ── TEE resolution and connection ─────────────────────────────────────────── + + def _connect_tee(self) -> None: + """Resolve TEE from registry and create a secure HTTP client for it.""" endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee( - llm_server_url, - rpc_url, - tee_registry_address, + self._llm_server_url, + self._rpc_url, + self._tee_registry_address, ) - self._tee_id = tee_id self._tee_endpoint = endpoint self._tee_payment_address = tee_payment_address ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None - # When connecting directly via llm_server_url, skip cert verification — - # self-hosted TEE servers commonly use self-signed certificates. - verify_ssl = llm_server_url is None - self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else verify_ssl - - # x402 client and signer - signer = EthAccountSigner(self._wallet_account) - self._x402_client = x402Client() - register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - # httpx.AsyncClient subclass - construction is sync, connections open lazily + self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None) self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) - # ── TEE resolution ────────────────────────────────────────────────── + async def _refresh_tee(self) -> None: + """Re-resolve TEE from the registry and rebuild the HTTP client.""" + old_http_client = self._http_client + self._connect_tee() + try: + await old_http_client.aclose() + except Exception: + logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) + @staticmethod def _resolve_tee( @@ -188,6 +202,29 @@ def _tee_metadata(self) -> Dict: tee_payment_address=self._tee_payment_address, ) + async def _call_with_tee_retry( + self, + operation_name: str, + call: Callable[[], Awaitable[T]], + ) -> T: + """Execute *call*; on connection failure, pick a new TEE and retry once. + + Only retries when the request never reached the server (no HTTP response). + Server-side errors (4xx/5xx) are not retried. + """ + try: + return await call() + except httpx.HTTPStatusError: + raise + except Exception as exc: + logger.warning( + "Connection failure during %s; refreshing TEE and retrying once: %s", + operation_name, + exc, + ) + await self._refresh_tee() + return await call() + # ── Public API ────────────────────────────────────────────────────── def ensure_opg_approval(self, opg_amount: float) -> Permit2ApprovalResult: @@ -248,7 +285,6 @@ async def completion( RuntimeError: If the inference fails. """ model_id = model.split("/")[1] - headers = self._headers(x402_settlement_mode) payload: Dict = { "model": model_id, "prompt": prompt, @@ -258,11 +294,11 @@ async def completion( if stop_sequence: payload["stop"] = stop_sequence - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _COMPLETION_ENDPOINT, json=payload, - headers=headers, + headers=self._headers(x402_settlement_mode), timeout=_REQUEST_TIMEOUT, ) response.raise_for_status() @@ -275,6 +311,9 @@ async def completion( tee_timestamp=result.get("tee_timestamp"), **self._tee_metadata(), ) + + try: + return await self._call_with_tee_retry("completion", _request) except RuntimeError: raise except Exception as e: @@ -342,14 +381,13 @@ async def chat( async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> TextGenerationOutput: """Non-streaming chat request.""" - headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages) - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _CHAT_ENDPOINT, json=payload, - headers=headers, + headers=self._headers(params.x402_settlement_mode), timeout=_REQUEST_TIMEOUT, ) response.raise_for_status() @@ -375,6 +413,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text tee_timestamp=result.get("tee_timestamp"), **self._tee_metadata(), ) + + try: + return await self._call_with_tee_retry("chat", _request) except RuntimeError: raise except Exception as e: @@ -410,6 +451,31 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages, stream=True) + chunks_yielded = False + try: + async with self._http_client.stream( + "POST", + self._tee_endpoint + _CHAT_ENDPOINT, + json=payload, + headers=headers, + timeout=_REQUEST_TIMEOUT, + ) as response: + async for chunk in self._parse_sse_response(response): + chunks_yielded = True + yield chunk + return + except httpx.HTTPStatusError: + raise + except Exception as exc: + if chunks_yielded: + raise + logger.warning( + "Connection failure during stream setup; refreshing TEE and retrying once: %s", + exc, + ) + + await self._refresh_tee() + headers = self._headers(params.x402_settlement_mode) async with self._http_client.stream( "POST", self._tee_endpoint + _CHAT_ENDPOINT, diff --git a/tests/llm_test.py b/tests/llm_test.py index bb845a75..44ee58a8 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -5,9 +5,10 @@ """ import json +import ssl from contextlib import asynccontextmanager from typing import List -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -31,6 +32,8 @@ def __init__(self, *_args, **_kwargs): self._response_body: bytes = b"{}" self._post_calls: List[dict] = [] self._stream_response = None + self._error_on_next: BaseException | None = None + self._stream_error_on_next: BaseException | None = None def set_response(self, status_code: int, body: dict) -> None: self._response_status = status_code @@ -43,8 +46,19 @@ def set_stream_response(self, status_code: int, chunks: List[bytes]) -> None: def post_calls(self) -> List[dict]: return self._post_calls + def fail_next_post(self, exc: BaseException) -> None: + """Make the next post() call raise *exc*, then revert to normal.""" + self._error_on_next = exc + + def fail_next_stream(self, exc: BaseException) -> None: + """Make the next stream() call raise *exc*, then revert to normal.""" + self._stream_error_on_next = exc + async def post(self, url: str, *, json=None, headers=None, timeout=None) -> "_FakeResponse": self._post_calls.append({"url": url, "json": json, "headers": headers, "timeout": timeout}) + if self._error_on_next is not None: + exc, self._error_on_next = self._error_on_next, None + raise exc resp = _FakeResponse(self._response_status, self._response_body) if self._response_status >= 400: resp.raise_for_status = MagicMock(side_effect=httpx.HTTPStatusError("error", request=MagicMock(), response=MagicMock())) @@ -53,6 +67,9 @@ async def post(self, url: str, *, json=None, headers=None, timeout=None) -> "_Fa @asynccontextmanager async def stream(self, method: str, url: str, *, json=None, headers=None, timeout=None): self._post_calls.append({"method": method, "url": url, "json": json, "headers": headers, "timeout": timeout}) + if self._stream_error_on_next is not None: + exc, self._stream_error_on_next = self._stream_error_on_next, None + raise exc yield self._stream_response async def aclose(self): @@ -535,3 +552,237 @@ def test_registry_success(self): assert cert == b"cert-bytes" assert tee_id == "tee-42" assert pay_addr == "0xPay" + + +# ── TEE retry tests (non-streaming) ────────────────────────────────── + + +@pytest.mark.asyncio +class TestTeeRetryCompletion: + async def test_retries_on_connection_error_and_succeeds(self, fake_http): + """First call hits connection error → refresh TEE → second call succeeds.""" + fake_http.set_response(200, {"completion": "retried ok", "tee_signature": "s", "tee_timestamp": "t"}) + fake_http.fail_next_post(ConnectionError("connection refused")) + llm = _make_llm() + + result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + + assert result.completion_output == "retried ok" + assert len(fake_http.post_calls) == 2 + + async def test_http_status_error_not_retried(self, fake_http): + """A server-side error (HTTP 500) should not trigger a TEE retry.""" + fake_http.set_response(500, {"error": "boom"}) + llm = _make_llm() + + with pytest.raises(RuntimeError, match="TEE LLM completion failed"): + await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + assert len(fake_http.post_calls) == 1 + + async def test_second_failure_propagates(self, fake_http): + """If the retry also fails, the error should propagate.""" + call_count = 0 + + async def always_fail(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise ConnectionError("still broken") + + fake_http.post = always_fail + llm = _make_llm() + + with pytest.raises(RuntimeError, match="TEE LLM completion failed"): + await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + assert call_count == 2 + + +@pytest.mark.asyncio +class TestTeeRetryChat: + async def test_retries_on_connection_error_and_succeeds(self, fake_http): + fake_http.set_response( + 200, + {"choices": [{"message": {"role": "assistant", "content": "retry ok"}, "finish_reason": "stop"}]}, + ) + fake_http.fail_next_post(OSError("network unreachable")) + llm = _make_llm() + + result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + assert result.chat_output["content"] == "retry ok" + assert len(fake_http.post_calls) == 2 + + async def test_http_status_error_not_retried(self, fake_http): + fake_http.set_response(500, {"error": "internal"}) + llm = _make_llm() + + with pytest.raises(RuntimeError, match="TEE LLM chat failed"): + await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + assert len(fake_http.post_calls) == 1 + + +# ── TEE retry tests (streaming) ────────────────────────────────────── + + +@pytest.mark.asyncio +class TestTeeRetryStreaming: + async def test_retries_stream_on_connection_error_before_chunks(self, fake_http): + """Connection failure during stream setup (no chunks yielded) → retry succeeds.""" + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + ) + fake_http.fail_next_stream(ConnectionError("reset by peer")) + llm = _make_llm() + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [c async for c in gen] + + assert len(chunks) == 1 + assert chunks[0].choices[0].delta.content == "ok" + assert len(fake_http.post_calls) == 2 + + async def test_no_retry_after_chunks_yielded(self, fake_http): + """Failure AFTER chunks were yielded must raise, not retry.""" + + class _FailMidStream: + def __init__(self): + self.status_code = 200 + + async def aiter_raw(self): + yield b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"partial"},"finish_reason":null}]}\n\n' + raise ConnectionError("mid-stream disconnect") + + async def aread(self) -> bytes: + return b"" + + fake_http._stream_response = _FailMidStream() + llm = _make_llm() + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + + with pytest.raises(ConnectionError): + _ = [c async for c in gen] + + assert len(fake_http.post_calls) == 1 + + +# ── _refresh_tee tests ───────────────────────────────────── + + +@pytest.mark.asyncio +class TestRefreshTeeAndReset: + async def test_replaces_http_client(self): + """After refresh, the http client should be a new instance.""" + clients_created = [] + + def make_client(*args, **kwargs): + c = FakeHTTPClient() + clients_created.append(c) + return c + + with ( + patch(_PATCHES["x402_httpx"], side_effect=make_client), + patch(_PATCHES["x402_client"]), + patch(_PATCHES["signer"]), + patch(_PATCHES["register_exact"]), + patch(_PATCHES["register_upto"]), + ): + llm = _make_llm() + old_client = llm._http_client + + await llm._refresh_tee() + + assert llm._http_client is not old_client + assert len(clients_created) == 2 # init + refresh + + async def test_closes_old_client(self, fake_http): + llm = _make_llm() + old_client = llm._http_client + old_client.aclose = AsyncMock() + + await llm._refresh_tee() + + old_client.aclose.assert_awaited_once() + + async def test_close_failure_is_swallowed(self, fake_http): + llm = _make_llm() + old_client = llm._http_client + old_client.aclose = AsyncMock(side_effect=OSError("already closed")) + + # Should not raise + await llm._refresh_tee() + + +# ── TEE cert rotation (crash + re-register) tests ──────────────────── + + +@pytest.mark.asyncio +class TestTeeCertRotation: + """Simulate a TEE crashing and a new one registering at the same IP + with a different ephemeral TLS certificate. The old cert is now + invalid, so the first request fails with SSLCertVerificationError. + The retry should re-resolve from the registry (getting the new cert) + and succeed.""" + + async def test_ssl_verification_failure_triggers_tee_refresh_completion(self, fake_http): + fake_http.set_response(200, {"completion": "ok after refresh", "tee_signature": "s", "tee_timestamp": "t"}) + fake_http.fail_next_post(ssl.SSLCertVerificationError("certificate verify failed")) + llm = _make_llm() + + with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + + # _connect_tee was called once during the retry (refresh) + spy.assert_called_once() + assert result.completion_output == "ok after refresh" + assert len(fake_http.post_calls) == 2 + + async def test_ssl_verification_failure_triggers_tee_refresh_chat(self, fake_http): + fake_http.set_response( + 200, + {"choices": [{"message": {"role": "assistant", "content": "ok after refresh"}, "finish_reason": "stop"}]}, + ) + fake_http.fail_next_post(ssl.SSLCertVerificationError("certificate verify failed")) + llm = _make_llm() + + with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + spy.assert_called_once() + assert result.chat_output["content"] == "ok after refresh" + assert len(fake_http.post_calls) == 2 + + async def test_ssl_verification_failure_triggers_tee_refresh_streaming(self, fake_http): + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"m","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + ) + fake_http.fail_next_stream(ssl.SSLCertVerificationError("certificate verify failed")) + llm = _make_llm() + + with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [c async for c in gen] + + spy.assert_called_once() + assert len(chunks) == 1 + assert chunks[0].choices[0].delta.content == "ok" + assert len(fake_http.post_calls) == 2