diff --git a/pyproject.toml b/pyproject.toml index a9daa644..884a1074 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "langchain>=0.3.7", "openai>=1.58.1", "pydantic>=2.9.2", - "og-x402==0.0.1.dev2" + "og-x402==0.0.1.dev4" ] [project.optional-dependencies] diff --git a/src/opengradient/agents/__init__.py b/src/opengradient/agents/__init__.py index 082f7064..aa6a35db 100644 --- a/src/opengradient/agents/__init__.py +++ b/src/opengradient/agents/__init__.py @@ -6,15 +6,22 @@ into existing applications and agent frameworks. """ +from ..client.llm import LLM from ..types import TEE_LLM, x402SettlementMode from .og_langchain import * def langchain_adapter( - private_key: str, - model_cid: TEE_LLM, + private_key: str | None = None, + model_cid: TEE_LLM | str | None = None, + model: TEE_LLM | str | None = None, max_tokens: int = 300, + temperature: float = 0.0, x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + client: LLM | None = None, + rpc_url: str | None = None, + tee_registry_address: str | None = None, + llm_server_url: str | None = None, ) -> OpenGradientChatModel: """ Returns an OpenGradient LLM that implements LangChain's LLM interface @@ -22,9 +29,14 @@ def langchain_adapter( """ return OpenGradientChatModel( private_key=private_key, - model_cid=model_cid, + client=client, + model_cid=model_cid or model, max_tokens=max_tokens, + temperature=temperature, x402_settlement_mode=x402_settlement_mode, + rpc_url=rpc_url, + tee_registry_address=tee_registry_address, + llm_server_url=llm_server_url, ) diff --git a/src/opengradient/agents/og_langchain.py b/src/opengradient/agents/og_langchain.py index 4f238a57..bcb02f47 100644 --- a/src/opengradient/agents/og_langchain.py +++ b/src/opengradient/agents/og_langchain.py @@ -1,29 +1,34 @@ # mypy: ignore-errors import asyncio import json -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from enum import Enum +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Optional, Sequence, Union, cast -from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun from langchain_core.language_models.base import LanguageModelInput from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( AIMessage, + AIMessageChunk, BaseMessage, + ChatMessage, HumanMessage, SystemMessage, ToolCall, ) -from langchain_core.messages.tool import ToolMessage +from langchain_core.messages.tool import ToolCallChunk, ToolMessage from langchain_core.outputs import ( ChatGeneration, + ChatGenerationChunk, ChatResult, ) from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_tool from pydantic import PrivateAttr from ..client.llm import LLM -from ..types import TEE_LLM, x402SettlementMode +from ..types import StreamChunk, TEE_LLM, TextGenerationOutput, x402SettlementMode __all__ = ["OpenGradientChatModel"] @@ -47,7 +52,29 @@ def _extract_content(content: Any) -> str: return str(content) if content else "" -def _parse_tool_call(tool_call: Dict) -> ToolCall: +def _parse_tool_args(raw_args: Any) -> Dict[str, Any]: + if isinstance(raw_args, dict): + return raw_args + if raw_args is None or raw_args == "": + return {} + if isinstance(raw_args, str): + try: + parsed = json.loads(raw_args) + return parsed if isinstance(parsed, dict) else {} + except json.JSONDecodeError: + return {} + return {} + + +def _serialize_tool_args(raw_args: Any) -> str: + if raw_args is None: + return "{}" + if isinstance(raw_args, str): + return raw_args + return json.dumps(raw_args) + + +def _parse_tool_call(tool_call: Dict[str, Any]) -> ToolCall: """Parse a tool call from the API response. Handles both flat format {"id", "name", "arguments"} and @@ -58,86 +85,191 @@ def _parse_tool_call(tool_call: Dict) -> ToolCall: return ToolCall( id=tool_call.get("id", ""), name=func["name"], - args=json.loads(func.get("arguments", "{}")), + args=_parse_tool_args(func.get("arguments")), ) return ToolCall( id=tool_call.get("id", ""), name=tool_call["name"], - args=json.loads(tool_call.get("arguments", "{}")), + args=_parse_tool_args(tool_call.get("arguments")), + ) + + +def _parse_tool_call_chunk(tool_call: Dict[str, Any], default_index: int) -> ToolCallChunk: + if "function" in tool_call: + func = tool_call.get("function", {}) + name = func.get("name") + raw_args = func.get("arguments") + else: + name = tool_call.get("name") + raw_args = tool_call.get("arguments") + + args: Optional[str] + if raw_args is None: + args = None + elif isinstance(raw_args, str): + args = raw_args + else: + args = json.dumps(raw_args) + + return ToolCallChunk( + id=tool_call.get("id"), + index=tool_call.get("index", default_index), + name=name, + args=args, ) +def _run_coro_sync(coro_factory: Callable[[], Awaitable[Any]]) -> Any: + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro_factory()) + + raise RuntimeError( + "Synchronous LangChain calls cannot run inside an active event loop for this adapter. " + "Use `ainvoke`/`astream` instead of `invoke`/`stream`." + ) + + +def _validate_model_string(model: Union[TEE_LLM, str]) -> Union[TEE_LLM, str]: + if isinstance(model, Enum): + model_str = str(model.value) + else: + model_str = str(model) + if "/" not in model_str: + raise ValueError( + f"Unsupported model value '{model_str}'. " + "Expected provider/model format (for example: 'openai/gpt-5')." + ) + return model + + class OpenGradientChatModel(BaseChatModel): """OpenGradient adapter class for LangChain chat model""" - model_cid: str + model_cid: Union[TEE_LLM, str] max_tokens: int = 300 - x402_settlement_mode: Optional[str] = x402SettlementMode.BATCH_HASHED + temperature: float = 0.0 + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED _llm: LLM = PrivateAttr() + _owns_client: bool = PrivateAttr(default=False) _tools: List[Dict] = PrivateAttr(default_factory=list) + _tool_choice: Optional[str] = PrivateAttr(default=None) def __init__( self, - private_key: str, - model_cid: TEE_LLM, + private_key: Optional[str] = None, + model_cid: Optional[Union[TEE_LLM, str]] = None, + model: Optional[Union[TEE_LLM, str]] = None, max_tokens: int = 300, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.BATCH_HASHED, + temperature: float = 0.0, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + client: Optional[LLM] = None, + rpc_url: Optional[str] = None, + tee_registry_address: Optional[str] = None, + llm_server_url: Optional[str] = None, **kwargs, ): + resolved_model_cid = model_cid or model + if resolved_model_cid is None: + raise ValueError("model_cid (or model) is required.") + resolved_model_cid = _validate_model_string(resolved_model_cid) super().__init__( - model_cid=model_cid, + model_cid=resolved_model_cid, max_tokens=max_tokens, + temperature=temperature, x402_settlement_mode=x402_settlement_mode, **kwargs, ) - self._llm = LLM(private_key=private_key) + + if client is not None: + self._llm = client + self._owns_client = False + return + + if not private_key: + raise ValueError("private_key is required when client is not provided.") + + llm_kwargs: Dict[str, Any] = {} + if rpc_url is not None: + llm_kwargs["rpc_url"] = rpc_url + if tee_registry_address is not None: + llm_kwargs["tee_registry_address"] = tee_registry_address + if llm_server_url is not None: + llm_kwargs["llm_server_url"] = llm_server_url + + self._llm = LLM(private_key=private_key, **llm_kwargs) + self._owns_client = True @property def _llm_type(self) -> str: return "opengradient" + async def aclose(self) -> None: + if self._owns_client: + await self._llm.close() + + def close(self) -> None: + if self._owns_client: + _run_coro_sync(self._llm.close) + def bind_tools( self, tools: Sequence[ Union[Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 ], + *, + tool_choice: Optional[str] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tools to the model.""" - tool_dicts: List[Dict] = [] + strict = kwargs.get("strict") + self._tools = [convert_to_openai_tool(tool, strict=strict) for tool in tools] + self._tool_choice = tool_choice or kwargs.get("tool_choice") - for tool in tools: - if isinstance(tool, BaseTool): - tool_dicts.append( - { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": ( - tool.args_schema.model_json_schema() - if hasattr(tool, "args_schema") and tool.args_schema is not None - else {} - ), - }, - } - ) - else: - tool_dicts.append(tool) + return self - self._tools = tool_dicts + @staticmethod + def _stream_chunk_to_generation(chunk: StreamChunk) -> ChatGenerationChunk: + choice = chunk.choices[0] if chunk.choices else None + delta = choice.delta if choice else None - return self + usage = None + if chunk.usage is not None: + usage = { + "input_tokens": chunk.usage.prompt_tokens, + "output_tokens": chunk.usage.completion_tokens, + "total_tokens": chunk.usage.total_tokens, + } - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - sdk_messages = [] + tool_call_chunks: List[ToolCallChunk] = [] + if delta and delta.tool_calls: + for index, tool_call in enumerate(delta.tool_calls): + tool_call_chunks.append(_parse_tool_call_chunk(tool_call, index)) + + message_chunk = AIMessageChunk( + content=_extract_content(delta.content if delta else ""), + tool_call_chunks=tool_call_chunks, + usage_metadata=usage, + ) + + generation_info: Dict[str, Any] = {} + if choice and choice.finish_reason is not None: + generation_info["finish_reason"] = choice.finish_reason + + for key in ["tee_signature", "tee_timestamp", "tee_id", "tee_endpoint", "tee_payment_address"]: + value = getattr(chunk, key, None) + if value is not None: + generation_info[key] = value + + return ChatGenerationChunk( + message=message_chunk, + generation_info=generation_info or None, + ) + + def _convert_messages_to_sdk(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]: + sdk_messages: List[Dict[str, Any]] = [] for message in messages: if isinstance(message, SystemMessage): sdk_messages.append({"role": "system", "content": _extract_content(message.content)}) @@ -148,9 +280,12 @@ def _generate( if message.tool_calls: msg["tool_calls"] = [ { - "id": call["id"], + "id": call.get("id", ""), "type": "function", - "function": {"name": call["name"], "arguments": json.dumps(call["args"])}, + "function": { + "name": call["name"], + "arguments": _serialize_tool_args(call.get("args")), + }, } for call in message.tool_calls ] @@ -163,33 +298,125 @@ def _generate( "tool_call_id": message.tool_call_id, } ) + elif isinstance(message, ChatMessage): + sdk_messages.append({"role": message.role, "content": _extract_content(message.content)}) else: raise ValueError(f"Unexpected message type: {message}") + return sdk_messages - chat_output = asyncio.run( - self._llm.chat( - model=self.model_cid, - messages=sdk_messages, - stop_sequence=stop, - max_tokens=self.max_tokens, - tools=self._tools, - x402_settlement_mode=self.x402_settlement_mode, - ) - ) + def _build_chat_kwargs(self, sdk_messages: List[Dict[str, Any]], stop: Optional[List[str]], stream: bool, **kwargs: Any) -> Dict[str, Any]: + x402_settlement_mode = kwargs.get("x402_settlement_mode", self.x402_settlement_mode) + if isinstance(x402_settlement_mode, str): + x402_settlement_mode = x402SettlementMode(x402_settlement_mode) + model = kwargs.get("model", self.model_cid) + model = _validate_model_string(model) + return { + "model": model, + "messages": sdk_messages, + "stop_sequence": stop, + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + "temperature": kwargs.get("temperature", self.temperature), + "tools": kwargs.get("tools", self._tools), + "tool_choice": kwargs.get("tool_choice", self._tool_choice), + "x402_settlement_mode": x402_settlement_mode, + "stream": stream, + } + + @staticmethod + def _build_chat_result(chat_output: TextGenerationOutput) -> ChatResult: finish_reason = chat_output.finish_reason or "" chat_response = chat_output.chat_output or {} + response_content = _extract_content(chat_response.get("content", "")) if chat_response.get("tool_calls"): tool_calls = [_parse_tool_call(tc) for tc in chat_response["tool_calls"]] - ai_message = AIMessage(content="", tool_calls=tool_calls) + ai_message = AIMessage(content=response_content, tool_calls=tool_calls) + else: + ai_message = AIMessage(content=response_content) + + generation_info = {"finish_reason": finish_reason} if finish_reason else {} + return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info=generation_info)]) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=False, **kwargs) + chat_output = _run_coro_sync(lambda: self._llm.chat(**chat_kwargs)) + if not isinstance(chat_output, TextGenerationOutput): + raise RuntimeError("Expected non-streaming chat output but received streaming generator.") + return self._build_chat_result(chat_output) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=False, **kwargs) + chat_output = await self._llm.chat(**chat_kwargs) + if not isinstance(chat_output, TextGenerationOutput): + raise RuntimeError("Expected non-streaming chat output but received streaming generator.") + return self._build_chat_result(chat_output) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=True, **kwargs) + try: + asyncio.get_running_loop() + except RuntimeError: + pass else: - ai_message = AIMessage(content=_extract_content(chat_response.get("content", ""))) + raise RuntimeError( + "Synchronous stream cannot run inside an active event loop for this adapter. " + "Use `astream` instead." + ) + + loop = asyncio.new_event_loop() + try: + stream = loop.run_until_complete(self._llm.chat(**chat_kwargs)) + stream_iter = cast(AsyncIterator[StreamChunk], stream) + + while True: + try: + chunk = loop.run_until_complete(stream_iter.__anext__()) + except StopAsyncIteration: + break + yield self._stream_chunk_to_generation(chunk) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() - return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info={"finish_reason": finish_reason})]) + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=True, **kwargs) + stream = await self._llm.chat(**chat_kwargs) + async for chunk in cast(AsyncIterator[StreamChunk], stream): + yield self._stream_chunk_to_generation(chunk) @property def _identifying_params(self) -> Dict[str, Any]: return { "model_name": self.model_cid, + "temperature": self.temperature, + "max_tokens": self.max_tokens, } diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index eecfa148..0baf1315 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -3,8 +3,9 @@ import json import logging import ssl +import threading from dataclasses import dataclass -from typing import AsyncGenerator, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union from eth_account import Account from eth_account.account import LocalAccount @@ -19,6 +20,7 @@ from .tee_registry import TEERegistry, build_ssl_context_from_der logger = logging.getLogger(__name__) +T = TypeVar("T") DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai" DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248" @@ -30,6 +32,9 @@ _CHAT_ENDPOINT = "/v1/chat/completions" _COMPLETION_ENDPOINT = "/v1/completions" _REQUEST_TIMEOUT = 60 +_TEE_SIGNATURE_HEADER = "X-TEE-Signature" +_TEE_TIMESTAMP_HEADER = "X-TEE-Timestamp" +_TEE_ID_HEADER = "X-TEE-ID" @dataclass @@ -45,6 +50,16 @@ class _ChatParams: x402_settlement_mode: x402SettlementMode +@dataclass +class _ResolvedTEEState: + """Resolved TEE connection details used to configure the HTTP client.""" + + endpoint: str + tls_verify: Union[ssl.SSLContext, bool] + tee_id: Optional[str] + tee_payment_address: Optional[str] + + class LLM: """ LLM inference namespace. @@ -95,23 +110,20 @@ def __init__( ): self._wallet_account: LocalAccount = Account.from_key(private_key) - endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee( - llm_server_url, - rpc_url, - tee_registry_address, - ) + # Store registry params so we can re-resolve TEE on SSL failures + self._rpc_url = rpc_url + self._tee_registry_address = tee_registry_address + self._llm_server_url = llm_server_url - self._tee_id = tee_id - self._tee_endpoint = endpoint - self._tee_payment_address = tee_payment_address + state = self._resolve_tee_state() + self._apply_tee_state(state) + self._reset_lock = threading.Lock() - ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None - # When connecting directly via llm_server_url, skip cert verification — - # self-hosted TEE servers commonly use self-signed certificates. - verify_ssl = llm_server_url is None - self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else verify_ssl + # x402 client/signer/http stack + self._init_x402_stack() - # x402 client and signer + def _init_x402_stack(self) -> None: + """Initialize x402 signer/client/http stack.""" signer = EthAccountSigner(self._wallet_account) self._x402_client = x402Client() register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) @@ -119,6 +131,90 @@ def __init__( # httpx.AsyncClient subclass - construction is sync, connections open lazily self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) + def _resolve_tee_state(self) -> _ResolvedTEEState: + """Resolve current TEE metadata and derive the TLS verification config.""" + endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee( + self._llm_server_url, + self._rpc_url, + self._tee_registry_address, + ) + ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None + # When connecting directly via llm_server_url, skip cert verification — + # self-hosted TEE servers commonly use self-signed certificates. + verify_ssl = self._llm_server_url is None + return _ResolvedTEEState( + endpoint=endpoint, + tls_verify=ssl_ctx if ssl_ctx else verify_ssl, + tee_id=tee_id, + tee_payment_address=tee_payment_address, + ) + + def _apply_tee_state(self, state: _ResolvedTEEState) -> None: + """Apply resolved TEE metadata to the current client instance.""" + self._tee_id = state.tee_id + self._tee_endpoint = state.endpoint + self._tee_payment_address = state.tee_payment_address + self._tls_verify = state.tls_verify + + @staticmethod + def _is_ssl_error(exc: Exception) -> bool: + """Detect SSL/TLS errors that indicate a stale certificate or connection reset.""" + visited: set[int] = set() + current: Optional[BaseException] = exc + + while current is not None and id(current) not in visited: + visited.add(id(current)) + if isinstance(current, ssl.SSLError): + return True + msg = str(current).lower() + if any( + keyword in msg + for keyword in ( + "certificate verify failed", + "tlsv1 alert", + "ssl handshake", + "sslcertverificationerror", + "ssl: certificate_verify_failed", + "self-signed certificate", + ) + ): + return True + current = current.__cause__ or current.__context__ + return False + + async def _refresh_tee_and_reset(self) -> None: + """Re-resolve TEE from registry and rebuild the HTTP client with a fresh SSL context.""" + with self._reset_lock: + old_http_client = self._http_client + state = self._resolve_tee_state() + self._apply_tee_state(state) + self._init_x402_stack() + + try: + await old_http_client.aclose() + except Exception: + logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) + + async def _retry_once_on_recoverable_error( + self, + operation_name: str, + call: Callable[[], Awaitable[T]], + ) -> T: + """Retry once after refreshing TEE state for recoverable SSL errors.""" + try: + return await call() + except Exception as first_error: + if self._is_ssl_error(first_error): + logger.warning( + "SSL/TLS error during %s; re-resolving TEE from registry and retrying once: %s", + operation_name, + first_error, + ) + await self._refresh_tee_and_reset() + return await call() + + raise + # ── TEE resolution ────────────────────────────────────────────────── @staticmethod @@ -188,6 +284,16 @@ def _tee_metadata(self) -> Dict: tee_payment_address=self._tee_payment_address, ) + @staticmethod + def _extract_tee_headers(response: Any) -> Dict[str, Optional[str]]: + """Extract TEE proof metadata from HTTP headers.""" + headers = getattr(response, "headers", {}) or {} + return { + "tee_signature": headers.get(_TEE_SIGNATURE_HEADER), + "tee_timestamp": headers.get(_TEE_TIMESTAMP_HEADER), + "tee_id": headers.get(_TEE_ID_HEADER), + } + # ── Public API ────────────────────────────────────────────────────── def ensure_opg_approval(self, opg_amount: float) -> Permit2ApprovalResult: @@ -258,7 +364,7 @@ async def completion( if stop_sequence: payload["stop"] = stop_sequence - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _COMPLETION_ENDPOINT, json=payload, @@ -268,13 +374,20 @@ async def completion( response.raise_for_status() raw_body = await response.aread() result = json.loads(raw_body.decode()) + tee_headers = self._extract_tee_headers(response) + metadata = self._tee_metadata() + if tee_headers.get("tee_id"): + metadata["tee_id"] = tee_headers["tee_id"] return TextGenerationOutput( transaction_hash="external", completion_output=result.get("completion"), - tee_signature=result.get("tee_signature"), - tee_timestamp=result.get("tee_timestamp"), - **self._tee_metadata(), + tee_signature=result.get("tee_signature") or tee_headers.get("tee_signature"), + tee_timestamp=result.get("tee_timestamp") or tee_headers.get("tee_timestamp"), + **metadata, ) + + try: + return await self._retry_once_on_recoverable_error("completion", _request) except RuntimeError: raise except Exception as e: @@ -345,7 +458,7 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages) - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _CHAT_ENDPOINT, json=payload, @@ -355,6 +468,10 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text response.raise_for_status() raw_body = await response.aread() result = json.loads(raw_body.decode()) + tee_headers = self._extract_tee_headers(response) + metadata = self._tee_metadata() + if tee_headers.get("tee_id"): + metadata["tee_id"] = tee_headers["tee_id"] choices = result.get("choices") if not choices: @@ -371,10 +488,13 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text transaction_hash="external", finish_reason=choices[0].get("finish_reason"), chat_output=message, - tee_signature=result.get("tee_signature"), - tee_timestamp=result.get("tee_timestamp"), - **self._tee_metadata(), + tee_signature=result.get("tee_signature") or tee_headers.get("tee_signature"), + tee_timestamp=result.get("tee_timestamp") or tee_headers.get("tee_timestamp"), + **metadata, ) + + try: + return await self._retry_once_on_recoverable_error("chat", _request) except RuntimeError: raise except Exception as e: @@ -410,17 +530,41 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages, stream=True) - async with self._http_client.stream( - "POST", - self._tee_endpoint + _CHAT_ENDPOINT, - json=payload, - headers=headers, - timeout=_REQUEST_TIMEOUT, - ) as response: - async for chunk in self._parse_sse_response(response): - yield chunk - - async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, None]: + retried = False + while True: + chunks_yielded = False + try: + async with self._http_client.stream( + "POST", + self._tee_endpoint + _CHAT_ENDPOINT, + json=payload, + headers=headers, + timeout=_REQUEST_TIMEOUT, + ) as response: + tee_headers = self._extract_tee_headers(response) + async for chunk in self._parse_sse_response(response, tee_headers=tee_headers): + chunks_yielded = True + yield chunk + return + except Exception as e: + if not retried and not chunks_yielded: + if self._is_ssl_error(e): + retried = True + logger.warning( + "SSL/TLS error during stream; re-resolving TEE from registry and retrying once: %s", + e, + ) + await self._refresh_tee_and_reset() + # Re-read headers since endpoint may have changed + headers = self._headers(params.x402_settlement_mode) + continue + raise + + async def _parse_sse_response( + self, + response, + tee_headers: Optional[Dict[str, Optional[str]]] = None, + ) -> AsyncGenerator[StreamChunk, None]: """Parse an SSE response stream into StreamChunk objects.""" status_code = getattr(response, "status_code", None) if status_code is not None and status_code >= 400: @@ -461,4 +605,8 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non chunk.tee_id = self._tee_id chunk.tee_endpoint = self._tee_endpoint chunk.tee_payment_address = self._tee_payment_address + if tee_headers: + chunk.tee_signature = chunk.tee_signature or tee_headers.get("tee_signature") + chunk.tee_timestamp = chunk.tee_timestamp or tee_headers.get("tee_timestamp") + chunk.tee_id = tee_headers.get("tee_id") or chunk.tee_id yield chunk diff --git a/src/opengradient/types.py b/src/opengradient/types.py index 1f7ec75d..d55dc767 100644 --- a/src/opengradient/types.py +++ b/src/opengradient/types.py @@ -500,6 +500,7 @@ class TEE_LLM(str, Enum): GPT_5_2 = "openai/gpt-5.2" # Anthropic models via TEE + CLAUDE_SONNET_4_0 = "anthropic/claude-sonnet-4-0" CLAUDE_SONNET_4_5 = "anthropic/claude-sonnet-4-5" CLAUDE_SONNET_4_6 = "anthropic/claude-sonnet-4-6" CLAUDE_HAIKU_4_5 = "anthropic/claude-haiku-4-5" @@ -510,10 +511,11 @@ class TEE_LLM(str, Enum): GEMINI_2_5_FLASH = "google/gemini-2.5-flash" GEMINI_2_5_PRO = "google/gemini-2.5-pro" GEMINI_2_5_FLASH_LITE = "google/gemini-2.5-flash-lite" - GEMINI_3_PRO = "google/gemini-3-pro-preview" GEMINI_3_FLASH = "google/gemini-3-flash-preview" # xAI Grok models via TEE + GROK_3 = "x-ai/grok-3" + GROK_3_MINI = "x-ai/grok-3-mini" GROK_4 = "x-ai/grok-4" GROK_4_FAST = "x-ai/grok-4-fast" GROK_4_1_FAST = "x-ai/grok-4-1-fast" diff --git a/tests/langchain_adapter_test.py b/tests/langchain_adapter_test.py index e651ab49..1747c1d7 100644 --- a/tests/langchain_adapter_test.py +++ b/tests/langchain_adapter_test.py @@ -1,3 +1,4 @@ +import asyncio import json import os import sys @@ -11,7 +12,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from src.opengradient.agents.og_langchain import OpenGradientChatModel, _extract_content, _parse_tool_call -from src.opengradient.types import TEE_LLM, TextGenerationOutput, x402SettlementMode +from src.opengradient.types import StreamChoice, StreamChunk, StreamDelta, TEE_LLM, TextGenerationOutput, x402SettlementMode @pytest.fixture @@ -52,9 +53,24 @@ def test_initialization_custom_settlement_mode(self, mock_llm_client): ) assert model.x402_settlement_mode == x402SettlementMode.PRIVATE + def test_initialization_with_existing_client(self): + with patch("src.opengradient.agents.og_langchain.LLM") as MockLLM: + existing_client = MagicMock() + model = OpenGradientChatModel(private_key=None, client=existing_client, model_cid=TEE_LLM.GPT_5) + assert model._llm is existing_client + MockLLM.assert_not_called() + + def test_initialization_without_private_key_or_client_raises(self): + with pytest.raises(ValueError, match="private_key is required"): + OpenGradientChatModel(private_key=None, model_cid=TEE_LLM.GPT_5) + + def test_initialization_with_invalid_model_string_raises(self): + with pytest.raises(ValueError, match="provider/model format"): + OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid="gpt-5") + def test_identifying_params(self, model): """Test _identifying_params returns model name.""" - assert model._identifying_params == {"model_name": TEE_LLM.GPT_5} + assert model._identifying_params == {"model_name": TEE_LLM.GPT_5, "temperature": 0.0, "max_tokens": 300} class TestGenerate: @@ -156,6 +172,24 @@ def test_empty_chat_output(self, model, mock_llm_client): assert result.generations[0].message.content == "" + def test_generate_with_invalid_model_kwarg_raises(self, model): + with pytest.raises(ValueError, match="provider/model format"): + model._generate([HumanMessage(content="Hi")], model="gpt-5") + + def test_sync_generate_inside_running_loop_raises(self, model): + async def run_test(): + with pytest.raises(RuntimeError, match="Use `ainvoke`/`astream`"): + model._generate([HumanMessage(content="Hi")]) + + asyncio.run(run_test()) + + def test_sync_stream_inside_running_loop_raises(self, model): + async def run_test(): + with pytest.raises(RuntimeError, match="Use `astream`"): + next(model._stream([HumanMessage(content="Hi")])) + + asyncio.run(run_test()) + class TestMessageConversion: def test_converts_all_message_types(self, model, mock_llm_client): @@ -215,8 +249,11 @@ def test_passes_correct_params_to_client(self, model, mock_llm_client): messages=[{"role": "user", "content": "Hi"}], stop_sequence=["END"], max_tokens=300, + temperature=0.0, tools=[], + tool_choice=None, x402_settlement_mode=x402SettlementMode.BATCH_HASHED, + stream=False, ) @@ -306,3 +343,77 @@ def test_nested_function_format(self): assert tc["name"] == "bar" assert tc["args"] == {"y": 2} assert tc["id"] == "2" + + +class TestAsyncPaths: + def test_agenerate(self, model, mock_llm_client): + mock_llm_client.chat.return_value = TextGenerationOutput( + transaction_hash="external", + finish_reason="stop", + chat_output={"role": "assistant", "content": "Hello async!"}, + ) + + result = asyncio.run(model._agenerate([HumanMessage(content="Hi")])) + assert result.generations[0].message.content == "Hello async!" + + def test_ainvoke(self, model, mock_llm_client): + mock_llm_client.chat.return_value = TextGenerationOutput( + transaction_hash="external", + finish_reason="stop", + chat_output={"role": "assistant", "content": "pong"}, + ) + + message = asyncio.run(model.ainvoke([HumanMessage(content="ping")])) + assert message.content == "pong" + + def test_astream(self, model, mock_llm_client): + async def stream(): + yield StreamChunk( + choices=[StreamChoice(delta=StreamDelta(role="assistant", content="Hel"), index=0)], + model="gpt-5", + ) + yield StreamChunk( + choices=[StreamChoice(delta=StreamDelta(content="lo"), index=0, finish_reason="stop")], + model="gpt-5", + is_final=True, + ) + + mock_llm_client.chat.return_value = stream() + + async def collect_chunks(): + return [chunk async for chunk in model.astream([HumanMessage(content="Hi")])] + + chunks = asyncio.run(collect_chunks()) + output_text = "".join(chunk.content for chunk in chunks if chunk.content) + assert output_text == "Hello" + + def test_astream_tool_call_chunk(self, model, mock_llm_client): + async def stream(): + yield StreamChunk( + choices=[ + StreamChoice( + delta=StreamDelta( + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q":"test"}'}, + } + ] + ), + index=0, + finish_reason="tool_calls", + ) + ], + model="gpt-5", + is_final=True, + ) + + mock_llm_client.chat.return_value = stream() + + async def collect_chunks(): + return [chunk async for chunk in model.astream([HumanMessage(content="Hi")])] + + chunks = asyncio.run(collect_chunks()) + assert chunks[0].tool_call_chunks[0]["id"] == "call_1" + assert chunks[0].tool_call_chunks[0]["name"] == "search" diff --git a/tests/llm_test.py b/tests/llm_test.py index bb845a75..47031593 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -5,9 +5,10 @@ """ import json +import ssl from contextlib import asynccontextmanager from typing import List -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -29,15 +30,17 @@ class FakeHTTPClient: def __init__(self, *_args, **_kwargs): self._response_status: int = 200 self._response_body: bytes = b"{}" + self._response_headers: dict = {} self._post_calls: List[dict] = [] self._stream_response = None - def set_response(self, status_code: int, body: dict) -> None: + def set_response(self, status_code: int, body: dict, headers: dict | None = None) -> None: self._response_status = status_code self._response_body = json.dumps(body).encode() + self._response_headers = headers or {} - def set_stream_response(self, status_code: int, chunks: List[bytes]) -> None: - self._stream_response = _FakeStreamResponse(status_code, chunks) + def set_stream_response(self, status_code: int, chunks: List[bytes], headers: dict | None = None) -> None: + self._stream_response = _FakeStreamResponse(status_code, chunks, headers=headers) @property def post_calls(self) -> List[dict]: @@ -45,7 +48,7 @@ def post_calls(self) -> List[dict]: async def post(self, url: str, *, json=None, headers=None, timeout=None) -> "_FakeResponse": self._post_calls.append({"url": url, "json": json, "headers": headers, "timeout": timeout}) - resp = _FakeResponse(self._response_status, self._response_body) + resp = _FakeResponse(self._response_status, self._response_body, headers=self._response_headers) if self._response_status >= 400: resp.raise_for_status = MagicMock(side_effect=httpx.HTTPStatusError("error", request=MagicMock(), response=MagicMock())) return resp @@ -60,9 +63,10 @@ async def aclose(self): class _FakeResponse: - def __init__(self, status_code: int, body: bytes): + def __init__(self, status_code: int, body: bytes, headers: dict | None = None): self.status_code = status_code self._body = body + self.headers = headers or {} def raise_for_status(self): pass @@ -72,9 +76,10 @@ async def aread(self) -> bytes: class _FakeStreamResponse: - def __init__(self, status_code: int, chunks: List[bytes]): + def __init__(self, status_code: int, chunks: List[bytes], headers: dict | None = None): self.status_code = status_code self._chunks = chunks + self.headers = headers or {} async def aiter_raw(self): for chunk in self._chunks: @@ -152,6 +157,24 @@ async def test_returns_completion_output(self, fake_http): assert result.tee_id == "test-tee-id" assert result.tee_payment_address == "0xTestPayment" + async def test_tee_metadata_falls_back_to_headers(self, fake_http): + fake_http.set_response( + 200, + {"completion": "ok"}, + headers={ + "X-TEE-Signature": "sig-from-header", + "X-TEE-Timestamp": "2026-03-13T00:00:00Z", + "X-TEE-ID": "tee-id-from-header", + }, + ) + llm = _make_llm() + + result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + + assert result.tee_signature == "sig-from-header" + assert result.tee_timestamp == "2026-03-13T00:00:00Z" + assert result.tee_id == "tee-id-from-header" + async def test_sends_correct_payload(self, fake_http): fake_http.set_response(200, {"completion": "ok"}) llm = _make_llm() @@ -236,6 +259,29 @@ async def test_returns_chat_output(self, fake_http): assert result.finish_reason == "stop" assert result.tee_signature == "sig-xyz" + async def test_chat_tee_metadata_falls_back_to_headers(self, fake_http): + fake_http.set_response( + 200, + { + "choices": [{"message": {"role": "assistant", "content": "Hi there!"}, "finish_reason": "stop"}], + }, + headers={ + "X-TEE-Signature": "sig-from-header", + "X-TEE-Timestamp": "2026-03-13T00:00:00Z", + "X-TEE-ID": "tee-id-from-header", + }, + ) + llm = _make_llm() + + result = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hello"}], + ) + + assert result.tee_signature == "sig-from-header" + assert result.tee_timestamp == "2026-03-13T00:00:00Z" + assert result.tee_id == "tee-id-from-header" + async def test_flattens_content_blocks(self, fake_http): fake_http.set_response( 200, @@ -361,6 +407,68 @@ async def test_http_error_raises_opengradient_error(self, fake_http): with pytest.raises(RuntimeError, match="TEE LLM chat failed"): await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + async def test_invalid_payment_required_propagates(self, fake_http): + llm = _make_llm() + + async def flaky_post(*args, **kwargs): + raise RuntimeError("Failed to handle payment: Invalid payment required response") + + llm._http_client.post = flaky_post + + with pytest.raises(RuntimeError, match="Failed to handle payment: Invalid payment required response"): + await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + async def test_retries_once_on_ssl_error(self, fake_http): + """SSL cert failure triggers TEE re-resolution and retry.""" + fake_http.set_response( + 200, + { + "choices": [{"message": {"role": "assistant", "content": "recovered"}, "finish_reason": "stop"}], + }, + ) + llm = _make_llm() + llm._refresh_tee_and_reset = AsyncMock(return_value=None) + original_post = llm._http_client.post + attempts = {"count": 0} + + async def ssl_failing_post(*args, **kwargs): + attempts["count"] += 1 + if attempts["count"] == 1: + # Simulate the exact error chain from production logs: + # httpx.ConnectError wrapping ssl.SSLCertVerificationError + ssl_err = ssl.SSLCertVerificationError( + "[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate (_ssl.c:1010)" + ) + raise httpx.ConnectError( + "[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate (_ssl.c:1010)" + ) from ssl_err + return await original_post(*args, **kwargs) + + llm._http_client.post = ssl_failing_post + + result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + assert result.chat_output["content"] == "recovered" + assert attempts["count"] == 2 + llm._refresh_tee_and_reset.assert_awaited_once() + + async def test_ssl_error_not_retried_twice(self, fake_http): + """If the retry also fails with SSL, the error propagates.""" + llm = _make_llm() + llm._refresh_tee_and_reset = AsyncMock(return_value=None) + + async def always_ssl_fail(*args, **kwargs): + raise httpx.ConnectError( + "[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate" + ) + + llm._http_client.post = always_ssl_fail + + with pytest.raises(RuntimeError, match="TEE LLM chat failed"): + await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + llm._refresh_tee_and_reset.assert_awaited_once() + # ── Streaming tests ────────────────────────────────────────────────── @@ -427,6 +535,33 @@ async def test_stream_sets_tee_metadata_on_final_chunk(self, fake_http): assert final.tee_id == "test-tee-id" assert final.tee_payment_address == "0xTestPayment" + async def test_stream_tee_signature_timestamp_fallback_to_headers(self, fake_http): + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"gpt-5","choices":[{"index":0,"delta":{"content":"done"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + headers={ + "X-TEE-Signature": "sig-from-header", + "X-TEE-Timestamp": "2026-03-13T00:00:00Z", + "X-TEE-ID": "tee-id-from-header", + }, + ) + llm = _make_llm() + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [chunk async for chunk in gen] + + final = chunks[-1] + assert final.tee_signature == "sig-from-header" + assert final.tee_timestamp == "2026-03-13T00:00:00Z" + assert final.tee_id == "tee-id-from-header" + async def test_stream_error_raises(self, fake_http): fake_http.set_stream_response(500, [b"Internal Server Error"]) llm = _make_llm() @@ -469,6 +604,162 @@ async def test_tools_with_stream_falls_back_to_single_chunk(self, fake_http): assert chunks[0].choices[0].delta.tool_calls == [{"id": "tc1"}] assert chunks[0].choices[0].finish_reason == "tool_calls" + async def test_stream_invalid_payment_required_propagates(self, fake_http): + llm = _make_llm() + + @asynccontextmanager + async def flaky_stream(*args, **kwargs): + raise RuntimeError("Failed to handle payment: Invalid payment required response") + yield + + llm._http_client.stream = flaky_stream + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + + with pytest.raises(RuntimeError, match="Failed to handle payment: Invalid payment required response"): + _ = [chunk async for chunk in gen] + + async def test_stream_retries_once_on_ssl_error(self, fake_http): + """SSL cert failure during streaming triggers TEE re-resolution and retry.""" + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"gpt-5","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + ) + llm = _make_llm() + llm._refresh_tee_and_reset = AsyncMock(return_value=None) + original_stream = llm._http_client.stream + attempts = {"count": 0} + + @asynccontextmanager + async def ssl_failing_stream(*args, **kwargs): + attempts["count"] += 1 + if attempts["count"] == 1: + ssl_err = ssl.SSLCertVerificationError( + "[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate" + ) + raise httpx.ConnectError("SSL cert failed") from ssl_err + async with original_stream(*args, **kwargs) as response: + yield response + + llm._http_client.stream = ssl_failing_stream + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [chunk async for chunk in gen] + + assert attempts["count"] == 2 + assert chunks[-1].choices[0].delta.content == "ok" + llm._refresh_tee_and_reset.assert_awaited_once() + + +# ── _is_ssl_error detection tests ──────────────────────────────────── + + +class TestIsSSLError: + def test_detects_ssl_error_instance(self): + exc = ssl.SSLError("something went wrong") + assert LLM._is_ssl_error(exc) is True + + def test_detects_ssl_cert_verification_error(self): + exc = ssl.SSLCertVerificationError( + "[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate (_ssl.c:1010)" + ) + assert LLM._is_ssl_error(exc) is True + + def test_detects_httpx_connect_error_wrapping_ssl(self): + """Matches the exact production error chain.""" + ssl_err = ssl.SSLCertVerificationError( + "[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate (_ssl.c:1010)" + ) + httpx_err = httpx.ConnectError( + "[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate (_ssl.c:1010)" + ) + httpx_err.__cause__ = ssl_err + assert LLM._is_ssl_error(httpx_err) is True + + def test_detects_certificate_verify_failed_in_message(self): + exc = RuntimeError("certificate verify failed") + assert LLM._is_ssl_error(exc) is True + + def test_detects_self_signed_certificate_in_message(self): + exc = RuntimeError("self-signed certificate") + assert LLM._is_ssl_error(exc) is True + + def test_detects_tlsv1_alert_in_message(self): + exc = RuntimeError("tlsv1 alert decode error") + assert LLM._is_ssl_error(exc) is True + + def test_does_not_match_unrelated_error(self): + exc = RuntimeError("connection timeout") + assert LLM._is_ssl_error(exc) is False + + def test_does_not_match_url_containing_ssl(self): + """Bare 'ssl' keyword was removed to avoid false positives like this.""" + exc = RuntimeError("Failed to connect to ssl.example.com") + assert LLM._is_ssl_error(exc) is False + + def test_does_not_match_generic_connection_error(self): + exc = httpx.ConnectError("Connection refused") + assert LLM._is_ssl_error(exc) is False + + +# ── _refresh_tee_and_reset tests ───────────────────────────────────── + + +@pytest.mark.asyncio +class TestRefreshTeeAndReset: + async def test_re_resolves_registry_and_rebuilds_client(self, fake_http): + """Verifies that _refresh_tee_and_reset updates endpoint, cert, and HTTP client.""" + llm = _make_llm() + + mock_tee = MagicMock() + mock_tee.endpoint = "https://new.tee.server" + mock_tee.tls_cert_der = None # simplify: no cert pinning + mock_tee.tee_id = "new-tee-id" + mock_tee.payment_address = "0xNewPayment" + + # Override _llm_server_url to None so _resolve_tee hits the registry path + llm._llm_server_url = None + llm._rpc_url = "https://rpc" + llm._tee_registry_address = "0xRegistry" + + new_http = FakeHTTPClient() + with ( + patch("src.opengradient.client.llm.TEERegistry") as mock_reg, + patch("src.opengradient.client.llm.x402HttpxClient", return_value=new_http), + ): + mock_reg.return_value.get_llm_tee.return_value = mock_tee + await llm._refresh_tee_and_reset() + + assert llm._tee_endpoint == "https://new.tee.server" + assert llm._tee_id == "new-tee-id" + assert llm._tee_payment_address == "0xNewPayment" + assert llm._http_client is new_http + + async def test_direct_url_preserves_verify_false_after_refresh(self, fake_http): + """When llm_server_url is set, refresh must keep verify=False (not flip to True).""" + llm = _make_llm(endpoint="https://direct.tee.server") + + # Confirm initial state: direct URL means no cert verification + assert llm._tls_verify is False + + new_http = FakeHTTPClient() + with patch("src.opengradient.client.llm.x402HttpxClient", return_value=new_http): + await llm._refresh_tee_and_reset() + + # After refresh, verify=False must be preserved for direct endpoints + assert llm._tls_verify is False + # ── ensure_opg_approval tests ──────────────────────────────────────── diff --git a/uv.lock b/uv.lock index 94346b2c..df50a6ff 100644 --- a/uv.lock +++ b/uv.lock @@ -1835,15 +1835,15 @@ wheels = [ [[package]] name = "og-x402" -version = "0.0.1.dev2" +version = "0.0.1.dev4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1e/75/40c43cd44aa394e68acc98f8d5b8376f3a5e3b9eddf55b1c0c34616c340b/og_x402-0.0.1.dev2.tar.gz", hash = "sha256:bf5d4484ece5a371358a336fcc79fe5678be611044c55ade45c4be9d19d7691b", size = 899662, upload-time = "2026-03-17T06:35:36.587Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/5b/46a55d93d9da5535ff2bb28d48d5766c9108d9e16546cb9c7a65cde0fb11/og_x402-0.0.1.dev4.tar.gz", hash = "sha256:2d8a71b2f4222284e65d45e2d122faafe3bdb33c4fae77903f9665d29e517a97", size = 900109, upload-time = "2026-03-23T15:10:37.144Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4c/79/8c7543c2e647508e04ad0983e9a3a7b861f388ec591ccdc42c69a3128d42/og_x402-0.0.1.dev2-py3-none-any.whl", hash = "sha256:65e7d3bbb3c7f51e51dad974f6c405a230693816f72d874cf0d6d705a8eec271", size = 952331, upload-time = "2026-03-17T06:35:34.695Z" }, + { url = "https://files.pythonhosted.org/packages/45/da/5e0be4b8415a6c557a94991367c6124998df3ba014bceb76b595ef48c8c7/og_x402-0.0.1.dev4-py3-none-any.whl", hash = "sha256:c329ceb4fe7cc4195fa5bf9c769f5c571b61c8333b33fd0fe204a2ab377d8366", size = 952662, upload-time = "2026-03-23T15:10:35.21Z" }, ] [[package]] @@ -1907,7 +1907,7 @@ requires-dist = [ { name = "langgraph", marker = "extra == 'dev'" }, { name = "mypy", marker = "extra == 'dev'" }, { name = "numpy", specifier = ">=1.26.4" }, - { name = "og-x402", specifier = "==0.0.1.dev2" }, + { name = "og-x402", specifier = "==0.0.1.dev4" }, { name = "openai", specifier = ">=1.58.1" }, { name = "pdoc3", marker = "extra == 'dev'", specifier = "==0.10.0" }, { name = "pydantic", specifier = ">=2.9.2" },