diff --git a/pyhilo/__init__.py b/pyhilo/__init__.py index e19e0d4..b998e17 100644 --- a/pyhilo/__init__.py +++ b/pyhilo/__init__.py @@ -6,9 +6,9 @@ from pyhilo.device.switch import Switch from pyhilo.devices import Devices from pyhilo.event import Event -from pyhilo.exceptions import HiloError, InvalidCredentialsError, WebsocketError +from pyhilo.exceptions import HiloError, InvalidCredentialsError, SignalRError +from pyhilo.signalr import SignalREvent from pyhilo.util import from_utc_timestamp, time_diff -from pyhilo.websocket import WebsocketEvent __all__ = [ "API", @@ -17,10 +17,10 @@ "Event", "HiloError", "InvalidCredentialsError", - "WebsocketError", + "SignalRError", "from_utc_timestamp", "time_diff", - "WebsocketEvent", + "SignalREvent", "UNMONITORED_DEVICES", "Switch", ] diff --git a/pyhilo/api.py b/pyhilo/api.py index feb5098..f3051e0 100755 --- a/pyhilo/api.py +++ b/pyhilo/api.py @@ -45,14 +45,8 @@ ) from pyhilo.device import DeviceAttribute, HiloDevice, get_device_attributes from pyhilo.exceptions import InvalidCredentialsError, RequestError -from pyhilo.util.state import ( - StateDict, - WebsocketDict, - WebsocketTransportsDict, - get_state, - set_state, -) -from pyhilo.websocket import WebsocketClient, WebsocketManager +from pyhilo.signalr import SignalRHub, SignalRManager +from pyhilo.util.state import AndroidDeviceDict, StateDict, get_state, set_state class API: @@ -74,7 +68,6 @@ def __init__( ) -> None: """Initialize""" self._backoff_refresh_lock_api = asyncio.Lock() - self._backoff_refresh_lock_ws = asyncio.Lock() self._request_retries = request_retries self._state_yaml: str = DEFAULT_STATE_FILE self.state: StateDict = {} @@ -82,19 +75,12 @@ def __init__( self.device_attributes = get_device_attributes() self.session: ClientSession = session self._oauth_session = oauth_session - self.websocket_devices: WebsocketClient - # Backward compatibility during transition to websocket for challenges. Currently the HA Hilo integration - # uses the .websocket attribute. Re-added this attribute and point to the same object as websocket_devices. - # Should be removed once the transition to the challenge websocket is completed everywhere. - self.websocket: WebsocketClient - self.websocket_challenges: WebsocketClient + self.signalr_devices: SignalRHub + self.signalr_challenges: SignalRHub self.log_traces = log_traces self._get_device_callbacks: list[Callable[..., Any]] = [] - self.ws_url: str = "" - self.ws_token: str = "" - self.endpoint: str = "" self._urn: str | None = None - # Device cache from websocket DeviceListInitialValuesReceived + # Device cache from SignalR DeviceListInitialValuesReceived self._device_cache: list[dict[str, Any]] = [] self._device_cache_event: asyncio.Event = asyncio.Event() @@ -146,7 +132,7 @@ async def async_get_access_token(self) -> str: await self._oauth_session.async_ensure_token_valid() access_token = str(self._oauth_session.token["access_token"]) - LOG.debug("Websocket access token is %s", access_token) + LOG.debug("SignalR access token is %s", access_token) urn = self.urn LOG.debug("Extracted URN: %s", urn) @@ -356,19 +342,7 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None: err: ClientResponseError = err_info[1].with_traceback(err_info[2]) # type: ignore if err.status in (401, 403): - LOG.warning("Refreshing websocket token %s", err.request_info.url) - if ( - "client/negotiate" in str(err.request_info.url) - and err.request_info.method == "POST" - ): - LOG.info( - "401 detected on websocket, refreshing websocket token. Old url: {self.ws_url} Old Token: {self.ws_token}" - ) - LOG.info("401 detected on %s", err.request_info.url) - async with self._backoff_refresh_lock_ws: - await self.refresh_ws_token() - await self.get_websocket_params() - return + LOG.warning("Refreshing API token on %s", err.request_info.url) @staticmethod def _handle_on_giveup(_: dict[str, Any]) -> None: @@ -413,57 +387,13 @@ def enable_request_retries(self) -> None: async def _async_post_init(self) -> None: """Perform some post-init actions.""" - LOG.debug("Websocket _async_post_init running") + LOG.debug("SignalR _async_post_init running") await self._get_fid() await self._get_device_token() - # Initialize WebsocketManager ic-dev21 - self.websocket_manager = WebsocketManager( - self.session, self.async_request, self._state_yaml, set_state - ) - await self.websocket_manager.initialize_websockets() - - # Create both websocket clients - # ic-dev21 need to work on this as it can't lint as is, may need to - # instantiate differently - # TODO: fix type ignore after refactor - self.websocket_devices = WebsocketClient(self.websocket_manager.devicehub) # type: ignore - - # For backward compatibility during the transition to challengehub websocket - self.websocket = self.websocket_devices - self.websocket_challenges = WebsocketClient(self.websocket_manager.challengehub) # type: ignore - - async def refresh_ws_token(self) -> None: - """Refresh the websocket token.""" - await self.websocket_manager.refresh_token(self.websocket_manager.devicehub) - await self.websocket_manager.refresh_token(self.websocket_manager.challengehub) - - async def get_websocket_params(self) -> None: - """Retrieves and constructs WebSocket connection parameters from the negotiation endpoint.""" - uri = parse.urlparse(self.ws_url) - LOG.debug("Getting websocket params") - LOG.debug("Getting uri %s", uri) - resp: dict[str, Any] = await self.async_request( - "post", - f"{uri.path}negotiate?{uri.query}", - host=uri.netloc, - headers={ - "authorization": f"Bearer {self.ws_token}", - }, - ) - conn_id: str = resp.get("connectionId", "") - self.full_ws_url = f"{self.ws_url}&id={conn_id}&access_token={self.ws_token}" - LOG.debug("Getting full ws URL %s", self.full_ws_url) - transport_dict: list[WebsocketTransportsDict] = resp.get( - "availableTransports", [] - ) - websocket_dict: WebsocketDict = { - "connection_id": conn_id, - "available_transports": transport_dict, - "full_ws_url": self.full_ws_url, - } - LOG.debug("Calling set_state from get_websocket_params") - await set_state(self._state_yaml, "websocket", websocket_dict) + signalr_manager = SignalRManager(self.async_request) + self.signalr_devices = signalr_manager.build_hub("/DeviceHub") + self.signalr_challenges = signalr_manager.build_hub("/ChallengeHub") async def fb_install(self, fb_id: str) -> None: """Registers a Firebase installation and stores the authentication token state.""" @@ -535,9 +465,7 @@ async def android_register(self) -> None: await set_state( self._state_yaml, "android", - { - "token": token, - }, + cast(AndroidDeviceDict, {"token": token}), ) async def get_location_ids(self) -> tuple[int, str]: @@ -548,26 +476,26 @@ async def get_location_ids(self) -> tuple[int, str]: return (req[0]["id"], req[0]["locationHiloId"]) def set_device_cache(self, devices: list[dict[str, Any]]) -> None: - """Store devices received from websocket DeviceListInitialValuesReceived. + """Store devices received from SignalR DeviceListInitialValuesReceived. - This replaces the old REST API get_devices call. The websocket sends + This replaces the old REST API get_devices call. SignalR sends device data with list-type attributes (supportedAttributesList, etc.) which need to be converted to comma-separated strings to match the format that HiloDevice.update() expects. """ self._device_cache = [self._convert_ws_device(device) for device in devices] LOG.debug( - "Device cache populated with %d devices from websocket", + "Device cache populated with %d devices from SignalR", len(self._device_cache), ) self._device_cache_event.set() @staticmethod def _convert_ws_device(ws_device: dict[str, Any]) -> dict[str, Any]: - """Convert a websocket device dict to the format generate_device expects. + """Convert a SignalR device dict to the format generate_device expects. The REST API returned supportedAttributes/settableAttributes as - comma-separated strings. The websocket returns supportedAttributesList/ + comma-separated strings. SignalR returns supportedAttributesList/ settableAttributesList/supportedParametersList as Python lists. We convert to the old format so HiloDevice.update() works unchanged. """ @@ -590,25 +518,25 @@ def _convert_ws_device(ws_device: dict[str, Any]) -> dict[str, Any]: return device async def wait_for_device_cache(self, timeout: float = 30.0) -> None: - """Wait for the device cache to be populated from websocket. + """Wait for the device cache to be populated from SignalR. :param timeout: Maximum time to wait in seconds :raises TimeoutError: If the device cache is not populated in time """ if self._device_cache_event.is_set(): return - LOG.debug("Waiting for device cache from websocket (timeout=%ss)", timeout) + LOG.debug("Waiting for device cache from SignalR (timeout=%ss)", timeout) try: await asyncio.wait_for(self._device_cache_event.wait(), timeout=timeout) except asyncio.TimeoutError: LOG.error( - "Timed out waiting for device list from websocket after %ss", + "Timed out waiting for device list from SignalR after %ss", timeout, ) raise def get_device_cache(self, location_id: int) -> list[dict[str, Any]]: - """Return cached devices from websocket. + """Return cached devices from SignalR. :param location_id: Hilo location id (unused but kept for interface compat) :return: List of device dicts ready for generate_device() @@ -618,7 +546,7 @@ def get_device_cache(self, location_id: int) -> list[dict[str, Any]]: def add_to_device_cache(self, devices: list[dict[str, Any]]) -> None: """Append new devices to the existing cache (e.g. from DeviceAdded). - Converts websocket format and adds to the cache without replacing + Converts SignalR format and adds to the cache without replacing existing entries. Skips devices already in cache (by id). """ existing_ids = {d.get("id") for d in self._device_cache} diff --git a/pyhilo/devices.py b/pyhilo/devices.py index ce58585..bfcfd4f 100644 --- a/pyhilo/devices.py +++ b/pyhilo/devices.py @@ -23,7 +23,7 @@ def all(self) -> list[HiloDevice]: @property def attributes_list(self) -> list[Union[int, dict[int, list[str]]]]: - """This is sent to websocket to subscribe to the device attributes updates + """This is sent to the SignalR hub to subscribe to the device attributes updates :return: Dict of devices (key) with their attributes. :rtype: list @@ -99,8 +99,8 @@ def generate_device(self, device: dict) -> HiloDevice: return dev async def update(self) -> None: - """Update device list from websocket cache + gateway from REST.""" - # Get devices from websocket cache (already populated by DeviceListInitialValuesReceived) + """Update device list from SignalR cache + gateway from REST.""" + # Get devices from SignalR cache (already populated by DeviceListInitialValuesReceived) cached_devices = self._api.get_device_cache(self.location_id) generated_devices = [] for raw_device in cached_devices: @@ -140,7 +140,7 @@ async def update(self) -> None: async def update_devicelist_from_signalr( self, values: list[dict[str, Any]] ) -> list[HiloDevice]: - """Process device list received from SignalR websocket. + """Process device list received from SignalR hub. This is called when DeviceListInitialValuesReceived arrives. It populates the API device cache and generates HiloDevice objects. @@ -161,7 +161,7 @@ async def update_devicelist_from_signalr( async def add_device_from_signalr( self, values: list[dict[str, Any]] ) -> list[HiloDevice]: - """Process individual device additions from SignalR websocket. + """Process individual device additions from SignalR hub. This is called when DeviceAdded arrives. It appends to the existing cache rather than replacing it. @@ -181,7 +181,7 @@ async def add_device_from_signalr( async def async_init(self) -> None: """Initialize the Hilo "manager" class. - Gets location IDs from REST API, then waits for the websocket + Gets location IDs from REST API, then waits for the SignalR hub to deliver the device list via DeviceListInitialValuesReceived. The gateway is appended from REST. """ @@ -190,5 +190,5 @@ async def async_init(self) -> None: self.location_id = location_ids[0] self.location_hilo_id = location_ids[1] # Device list will be populated when DeviceListInitialValuesReceived - # arrives on the websocket. The hilo integration's async_init will + # arrives on the SignalR hub. The hilo integration's async_init will # call wait_for_device_cache() and then update() after subscribing. diff --git a/pyhilo/exceptions.py b/pyhilo/exceptions.py index d980214..e5f3c40 100644 --- a/pyhilo/exceptions.py +++ b/pyhilo/exceptions.py @@ -27,37 +27,37 @@ class RequestError(HiloError): pass -class WebsocketError(HiloError): - """An error related to generic websocket errors.""" +class SignalRError(HiloError): + """An error related to generic SignalR errors.""" pass -class CannotConnectError(WebsocketError): - """Define a error when the websocket can't be connected to.""" +class CannotConnectError(SignalRError): + """Define a error when the SignalR hub can't be connected to.""" pass -class ConnectionClosedError(WebsocketError): - """Define a error when the websocket closes unexpectedly.""" +class ConnectionClosedError(SignalRError): + """Define a error when the SignalR hub closes unexpectedly.""" pass -class ConnectionFailedError(WebsocketError): - """Define a error when the websocket connection fails.""" +class ConnectionFailedError(SignalRError): + """Define a error when the SignalR connection fails.""" pass -class InvalidMessageError(WebsocketError): - """Define a error related to an invalid message from the websocket server.""" +class InvalidMessageError(SignalRError): + """Define a error related to an invalid message from the SignalR server.""" pass -class NotConnectedError(WebsocketError): - """Define a error when the websocket isn't properly connected to.""" +class NotConnectedError(SignalRError): + """Define a error when the SignalR hub isn't properly connected to.""" pass diff --git a/pyhilo/graphql.py b/pyhilo/graphql.py index 6e176df..59126c2 100644 --- a/pyhilo/graphql.py +++ b/pyhilo/graphql.py @@ -620,6 +620,95 @@ async def subscribe_to_location_updated( location_hilo_id, ) + # Seconds without any SSE event before treating the connection as stalled. + # Hilo's server sends periodic keepalive comments; 180 s is a safe margin. + _SSE_KEEPALIVE_TIMEOUT = 180 + # Reconnection back-off: starts at _BACKOFF_BASE, doubles each failure, caps at _BACKOFF_MAX. + _BACKOFF_BASE = 5 + _BACKOFF_MAX = 300 # 5 minutes + + def _parse_sse_message(self, sse: Any) -> Optional[Dict[str, Any]]: + """Parse raw SSE event data; returns None if it should be skipped.""" + if not sse.data: + return None + try: + result: Dict[str, Any] = json.loads(sse.data) + return result + except json.JSONDecodeError: + return None + + def _handle_sse_message( + self, + data: Dict[str, Any], + handler: Callable[[Dict[str, Any]], str], + callback: Optional[Callable[[str], None]], + ) -> tuple: + """Dispatch a parsed SSE message. Returns (retry_apq, had_success).""" + if "errors" in data: + if any( + e.get("message") == "PersistedQueryNotFound" for e in data["errors"] + ): + return True, False + LOG.error("GraphQL Subscription Errors: %s", data["errors"]) + return False, False + if "data" in data: + LOG.debug("Received subscription result %s", data["data"]) + handler_result = handler(data["data"]) + if callback: + callback(handler_result) + return False, True + return False, False + + async def _drain_sse_events( + self, + event_source: Any, + handler: Callable[[Dict[str, Any]], str], + callback: Optional[Callable[[str], None]], + ) -> tuple: + """Read events until the stream closes. Returns (retry_apq, had_success).""" + had_success = False + sse_iter = event_source.aiter_sse() + while True: + try: + sse = await asyncio.wait_for( + sse_iter.__anext__(), + timeout=self._SSE_KEEPALIVE_TIMEOUT, + ) + except StopAsyncIteration: + break + except asyncio.TimeoutError: + LOG.info( + "SSE keepalive timeout (%ss without data), reconnecting...", + self._SSE_KEEPALIVE_TIMEOUT, + ) + raise + data = self._parse_sse_message(sse) + if data is None: + continue + retry, success = self._handle_sse_message(data, handler, callback) + if retry: + return True, had_success + if success: + had_success = True + return False, had_success + + async def _run_sse_connection( + self, + url: str, + headers: Dict[str, str], + payload: Dict[str, Any], + handler: Callable[[Dict[str, Any]], str], + callback: Optional[Callable[[str], None]], + ) -> tuple: + """Open one SSE connection and drain it. Returns (retry_apq, had_success).""" + async with httpx.AsyncClient( + http2=True, timeout=None, verify=self._ssl_context + ) as client: + async with aconnect_sse( + client, "POST", url, json=payload, headers=headers + ) as event_source: + return await self._drain_sse_events(event_source, handler, callback) + async def _listen_to_sse( self, query: str, @@ -630,70 +719,62 @@ async def _listen_to_sse( ) -> None: query_hash = hashlib.sha256(query.encode("utf-8")).hexdigest() payload: Dict[str, Any] = { - "extensions": { - "persistedQuery": { - "version": 1, - "sha256Hash": query_hash, - } - }, + "extensions": {"persistedQuery": {"version": 1, "sha256Hash": query_hash}}, "variables": variables, } + backoff_delay = self._BACKOFF_BASE while True: try: access_token = await self._get_access_token() url = f"https://{PLATFORM_HOST}/api/digital-twin/v3/graphql" headers = {"Authorization": f"Bearer {access_token}"} - - retry_with_full_query = False - - async with httpx.AsyncClient( - http2=True, timeout=None, verify=self._ssl_context - ) as client: - async with aconnect_sse( - client, "POST", url, json=payload, headers=headers - ) as event_source: - async for sse in event_source.aiter_sse(): - if not sse.data: - continue - try: - data = json.loads(sse.data) - except json.JSONDecodeError: - continue - - if "errors" in data: - if any( - e.get("message") == "PersistedQueryNotFound" - for e in data["errors"] - ): - retry_with_full_query = True - break - LOG.error( - "GraphQL Subscription Errors: %s", data["errors"] - ) - continue - - if "data" in data: - LOG.debug( - "Received subscription result %s", data["data"] - ) - handler_result = handler(data["data"]) - if callback: - callback(handler_result) - - if retry_with_full_query: + retry_apq, had_success = await self._run_sse_connection( + url, headers, payload, handler, callback + ) + if had_success: + backoff_delay = self._BACKOFF_BASE + if retry_apq: payload["query"] = query continue - - except Exception as e: + # Server closed cleanly — brief pause before reconnecting. LOG.debug( - "Subscription connection lost: %s. Reconnecting in 5 seconds...", e + "SSE connection closed by server. Reconnecting in 5 seconds..." ) await asyncio.sleep(5) - # Reset payload to APQ only on reconnect + + except asyncio.TimeoutError: + # Keepalive timeout: the connection was recently live, so start + # the next reconnect cycle with the base delay (not accumulated). + backoff_delay = self._BACKOFF_BASE + LOG.debug( + "Reconnecting after keepalive timeout in %s seconds...", + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + backoff_delay = min(backoff_delay * 2, self._BACKOFF_MAX) if "query" in payload: del payload["query"] + if location_hilo_id: + try: + await self.call_get_location_query(location_hilo_id) + LOG.debug("call_get_location_query success after reconnect") + except Exception as e2: + LOG.error( + "exception while RE-connecting, retrying: %s", + e2, + ) + except Exception as e: + LOG.debug( + "Subscription connection lost: %s. Reconnecting in %s seconds...", + e, + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + backoff_delay = min(backoff_delay * 2, self._BACKOFF_MAX) + if "query" in payload: + del payload["query"] if location_hilo_id: try: await self.call_get_location_query(location_hilo_id) diff --git a/pyhilo/signalr.py b/pyhilo/signalr.py new file mode 100644 index 0000000..8ad6fa1 --- /dev/null +++ b/pyhilo/signalr.py @@ -0,0 +1,271 @@ +"""Define a connection to the Hilo SignalR hubs via pysignalr.""" + +from __future__ import annotations + +import asyncio +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime +from enum import IntEnum +import ssl +from typing import Any, Callable, Optional, Tuple + +from pysignalr.client import SignalRClient + +from pyhilo.const import LOG + + +class SignalRMsgType(IntEnum): + INVOKE = 0x1 + CLOSE = 0x7 + UNKNOWN = 0xFF + + @classmethod + def has_value(cls, value: int) -> bool: + return value in cls._value2member_map_ + + @classmethod + def value(cls, value: int) -> IntEnum: # type: ignore + return cls._value2member_map_.get(value, cls.UNKNOWN) # type: ignore + + +@dataclass(frozen=True) +class SignalREvent: + """Define a representation of a message.""" + + event_type_id: int + target: str + arguments: list[list] + invocation: int | None + error: str | None + timestamp: datetime = field(default_factory=datetime.now) + event_type: str | None = field(init=False) + + def __post_init__(self) -> None: + if SignalRMsgType.has_value(self.event_type_id): + object.__setattr__( + self, "event_type", SignalRMsgType.value(self.event_type_id).name + ) + if self.event_type_id == SignalRMsgType.CLOSE: + LOG.error( + "Received close event from SignalR: Error: %s Target: %s Args: %s", + self.event_type, + self.target, + self.arguments, + ) + + +def signalr_event_from_payload(payload: dict[str, Any]) -> SignalREvent: + """Create a SignalREvent object from a SignalR event payload.""" + return SignalREvent( + payload.get("type", SignalRMsgType.UNKNOWN), + payload.get("target", ""), + payload.get("arguments", ""), + payload.get("invocationId"), + payload.get("error"), + ) + + +class SignalRHub: + """Wraps pysignalr.SignalRClient for a single hub endpoint. + + Lifecycle: + ``run()`` — negotiate fresh token, build client, block until disconnect + ``invoke()`` — send a hub method invocation + ``disconnect()`` — stop the transport cleanly + """ + + def __init__( + self, + negotiate_callback: Callable[[], Any], + ) -> None: + """Initialize. + + :param negotiate_callback: async callable returning ``(azure_url, access_token)`` + """ + self._negotiate = negotiate_callback + self._client: Optional[SignalRClient] = None + self._connect_callbacks: list[Callable[..., Any]] = [] + self._disconnect_callbacks: list[Callable[..., Any]] = [] + self._event_callbacks: list[Callable[..., Any]] = [] + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + @property + def connected(self) -> bool: + """Return True if a client is currently running.""" + return self._client is not None + + def add_connect_callback(self, callback: Callable[..., Any]) -> Callable[..., None]: + """Register a callback to fire after a successful connection.""" + self._connect_callbacks.append(callback) + + def remove() -> None: + self._connect_callbacks.remove(callback) + + return remove + + def add_disconnect_callback( + self, callback: Callable[..., Any] + ) -> Callable[..., None]: + """Register a callback to fire after a disconnection.""" + self._disconnect_callbacks.append(callback) + + def remove() -> None: + self._disconnect_callbacks.remove(callback) + + return remove + + def add_event_callback(self, callback: Callable[..., Any]) -> Callable[..., None]: + """Register a callback to fire for every inbound hub message.""" + self._event_callbacks.append(callback) + + def remove() -> None: + self._event_callbacks.remove(callback) + + return remove + + async def run(self) -> None: + """Negotiate, create the SignalR client, and block until it disconnects. + + Raises on hard errors (auth failure, negotiation failure, etc.). + Returns normally on a clean server-side disconnect. + """ + azure_url, access_token = await self._negotiate() + LOG.info("SignalRHub: negotiate succeeded, connecting to %s", azure_url) + + # Pre-create SSL context off the event loop to avoid blocking HA's loop monitor (Lesson ❷) + ssl_context = await asyncio.get_running_loop().run_in_executor( + None, ssl.create_default_context + ) + + self._client = SignalRClient( + url=azure_url, + headers={"Authorization": f"Bearer {access_token}"}, + retry_count=0, + ssl=ssl_context, + ) + + # Install catch-all handler — must be set AFTER __init__ and BEFORE run() (Lesson ❹) + hub_self = self + + class _CatchAllDict(defaultdict): # noqa: N805 + def __missing__(self_dict, target: str) -> list[Any]: # noqa: N805 + async def _handler(arguments: Any) -> None: + LOG.debug("SignalRHub: received target=%s", target) + event = SignalREvent( + event_type_id=SignalRMsgType.INVOKE, + target=target, + arguments=( + arguments if isinstance(arguments, list) else [arguments] + ), + invocation=None, + error=None, + ) + for cb in hub_self._event_callbacks: + await cb(event) + + self_dict[target] = [_handler] + return [_handler] + + self._client._message_handlers = _CatchAllDict(list) + + # Wire connect/disconnect/error hooks + self._client.on_open(self._on_open) + self._client.on_close(self._on_close) + self._client.on_error(self._on_error) + + try: + await self._client.run() + finally: + self._client = None + + async def invoke(self, method: str, args: list[Any]) -> None: + """Invoke a hub method on the server. + + :param method: Hub method name + :param args: Positional arguments for the method + """ + if self._client is None: + LOG.warning("SignalRHub: invoke(%s) called while disconnected", method) + return + LOG.debug("SignalRHub: invoke method=%s args=%s", method, args) + await self._client.send(method, args) + + async def disconnect(self) -> None: + """Request the client to stop.""" + if self._client is not None: + LOG.info("SignalRHub: disconnecting") + await self._client.stop() + + # ------------------------------------------------------------------ + # Internal pysignalr hooks + # ------------------------------------------------------------------ + + async def _on_open(self) -> None: + LOG.info("SignalRHub: connected") + for cb in self._connect_callbacks: + try: + result = cb() + if asyncio.iscoroutine(result): + await result + except Exception as exc: + LOG.error("SignalRHub: connect callback error: %s", exc) + + async def _on_close(self) -> None: + LOG.info("SignalRHub: disconnected") + for cb in self._disconnect_callbacks: + try: + result = cb() + if asyncio.iscoroutine(result): + await result + except Exception as exc: + LOG.error("SignalRHub: disconnect callback error: %s", exc) + + async def _on_error(self, message: str) -> None: + LOG.error("SignalRHub: error: %s", message) + + +class SignalRManager: + def __init__( + self, + async_request: Callable[..., Any], + ) -> None: + """Initialize. + + :param async_request: The ``API.async_request`` coroutine-factory + """ + self._async_request = async_request + + async def _negotiate(self, endpoint: str) -> Tuple[str, str]: + """POST ``{endpoint}/negotiate`` and return ``(azure_url, access_token)``. + + :param endpoint: e.g. ``/DeviceHub`` or ``/ChallengeHub`` + :returns: Tuple of (azure_url, access_token) + """ + url = f"{endpoint}/negotiate" + LOG.debug("SignalRManager: negotiating %s", url) + resp = await self._async_request("post", url) + azure_url: str = resp["url"] + access_token: str = resp["accessToken"] + LOG.info("SignalRManager: negotiate succeeded for %s → %s", endpoint, azure_url) + return azure_url, access_token + + def build_hub( + self, + endpoint: str, + ) -> SignalRHub: + """Create a ``SignalRHub`` whose negotiate lambda hits ``endpoint``. + + :param endpoint: Hilo hub path (``/DeviceHub`` or ``/ChallengeHub``) + :returns: A ready-to-run ``SignalRHub`` + """ + + async def _negotiate_for_endpoint() -> Tuple[str, str]: + return await self._negotiate(endpoint) + + return SignalRHub( + negotiate_callback=_negotiate_for_endpoint, + ) diff --git a/pyhilo/util/state.py b/pyhilo/util/state.py index 0f180dd..a28db6b 100644 --- a/pyhilo/util/state.py +++ b/pyhilo/util/state.py @@ -34,23 +34,6 @@ class AndroidDeviceDict(TypedDict): device_id: int -class WebsocketTransportsDict(TypedDict): - """Represents a dictionary containing Websocket connection information.""" - - transport: str - transfer_formats: list[str] - - -class WebsocketDict(TypedDict, total=False): - """Represents a dictionary containing registration information.""" - - token: str - connection_id: str - full_ws_url: str - url: str - available_transports: list[WebsocketTransportsDict] - - class RegistrationDict(TypedDict, total=False): """Represents a dictionary containing registration information.""" @@ -73,7 +56,6 @@ class StateDict(TypedDict, total=False): registration: RegistrationDict firebase: FirebaseDict android: AndroidDeviceDict - websocket: WebsocketDict T = TypeVar("T", bound="StateDict") @@ -189,9 +171,7 @@ async def get_state(state_yaml: str, _already_locked: bool = False) -> StateDict async def set_state( state_yaml: str, key: str, - state: ( - TokenDict | RegistrationDict | FirebaseDict | AndroidDeviceDict | WebsocketDict - ), + state: TokenDict | RegistrationDict | FirebaseDict | AndroidDeviceDict, ) -> None: """Save state yaml. diff --git a/pyhilo/websocket.py b/pyhilo/websocket.py deleted file mode 100755 index b6d69c9..0000000 --- a/pyhilo/websocket.py +++ /dev/null @@ -1,570 +0,0 @@ -"""Define a connection to the Hilo websocket.""" - -from __future__ import annotations - -import asyncio -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from enum import IntEnum -import json -from os import environ -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple -from urllib import parse - -from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType -from aiohttp.client_exceptions import ( - ClientError, - ServerDisconnectedError, - WSServerHandshakeError, -) -from yarl import URL - -from pyhilo.const import ( - AUTOMATION_CHALLENGE_ENDPOINT, - AUTOMATION_DEVICEHUB_ENDPOINT, - DEFAULT_USER_AGENT, - LOG, -) -from pyhilo.exceptions import ( - CannotConnectError, - ConnectionClosedError, - ConnectionFailedError, - InvalidCredentialsError, - InvalidMessageError, - NotConnectedError, -) -from pyhilo.util import schedule_callback - -if TYPE_CHECKING: - from pyhilo import API - -DEFAULT_WATCHDOG_TIMEOUT = timedelta(minutes=5) - - -class SignalRMsgType(IntEnum): - INVOKE = 0x1 - STREAM = 0x2 - COMPLETE = 0x3 - STREAM_INVOCATION = 0x4 - CANCEL_INVOCATION = 0x5 - PING = 0x6 - CLOSE = 0x7 - UNKNOWN = 0xFF - - @classmethod - def has_value(cls, value: int) -> bool: - return value in cls._value2member_map_ - - @classmethod - def value(cls, value: int) -> IntEnum: # type: ignore - return cls._value2member_map_.get(value, cls.UNKNOWN) # type: ignore - - -@dataclass(frozen=True) -class WebsocketEvent: - """Define a representation of a message.""" - - event_type_id: int - target: str - arguments: list[list] - invocation: int | None - error: str | None - timestamp: datetime = field(default=datetime.now()) - event_type: str | None = field(init=False) - - def __post_init__(self) -> None: - if SignalRMsgType.has_value(self.event_type_id): - object.__setattr__( - self, "event_type", SignalRMsgType.value(self.event_type_id).name - ) - if self.event_type_id == SignalRMsgType.CLOSE: - LOG.error( - f"Received close event from SignalR: Error: {self.event_type} Target: {self.target} Args: {self.arguments} Error: {self.error}" - ) - - -def websocket_event_from_payload(payload: dict[str, Any]) -> WebsocketEvent: - """Create a Message object from a websocket event payload.""" - return WebsocketEvent( - payload["type"], - payload.get("target", ""), - payload.get("arguments", ""), - payload.get("invocationId"), - payload.get("error"), - ) - - -class Watchdog: - """Define a watchdog to kick the websocket connection at intervals.""" - - def __init__( - self, action: Callable[..., Any], timeout: timedelta = DEFAULT_WATCHDOG_TIMEOUT - ): - """Initialize.""" - self._action = action - self._action_task: asyncio.Task | None = None - self._loop = asyncio.get_running_loop() - self._timeout_seconds = timeout.total_seconds() - self._timer_task: asyncio.TimerHandle | None = None - - def _on_expire(self) -> None: - """Log and act when the watchdog expires.""" - LOG.warning("Websocket: Watchdog expired") - schedule_callback(self._action) - - def cancel(self) -> None: - """Cancel the watchdog.""" - if self._timer_task: - self._timer_task.cancel() - self._timer_task = None - - def trigger(self) -> None: - """Trigger the watchdog.""" - if self._timer_task: - self._timer_task.cancel() - - self._timer_task = self._loop.call_later(self._timeout_seconds, self._on_expire) - - -class WebsocketClient: - """A websocket connection to the Hilo cloud. - Note that this class shouldn't be instantiated directly; it will be instantiated as - :param api: A :meth:`pyhilo.API` object - :type api: :meth:`pyhilo.API` - """ - - def __init__(self, api: API) -> None: - """Initialize.""" - self._api = api - self._connect_callbacks: list[Callable[..., None]] = [] - self._disconnect_callbacks: list[Callable[..., None]] = [] - self._event_callbacks: list[Callable[..., None]] = [] - self._loop = asyncio.get_running_loop() - self._watchdog = Watchdog(self.async_reconnect) - self._ready_event: asyncio.Event = asyncio.Event() - self._ready: bool = False - self._queued_tasks: list[asyncio.TimerHandle] = [] - - # These will get filled in after initial authentication: - self._client: ClientWebSocketResponse | None = None - - @property - def connected(self) -> bool: - """Return if currently connected to the websocket.""" - return self._client is not None and not self._client.closed - - @staticmethod - def _add_callback( - callback_list: list, callback: Callable[..., Any] - ) -> Callable[..., None]: - """Add a callback callback to a particular list.""" - callback_list.append(callback) - - def remove() -> None: - """Remove the callback.""" - callback_list.remove(callback) - - return remove - - async def _async_receive_json(self) -> list[Dict[str, Any]]: - """Receive a JSON response from the websocket server.""" - assert self._client - - response = await self._client.receive(300) - - if response.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): - LOG.error( - "Websocket: Received event to close connection: %s", response.type - ) - raise ConnectionClosedError("Connection was closed.") - - if response.type == WSMsgType.ERROR: - LOG.error( - f"Websocket: Received error event, Connection failed: {response.type}" - ) - raise ConnectionFailedError - - if response.type != WSMsgType.TEXT: - LOG.error("Websocket: Received invalid message: %s", response) - raise InvalidMessageError(f"Received non-text message: {response.type}") - - messages: list[Dict[str, Any]] = [] - try: - # Sometimes the http lib stacks multiple messages in the buffer, we need to split them to process. - received_messages = response.data.strip().split("\x1e") - for msg in received_messages: - data = json.loads(msg) - messages.append(data) - except ValueError as v_exc: - raise InvalidMessageError("Received invalid JSON") from v_exc - except json.decoder.JSONDecodeError as j_exc: - LOG.error("Received invalid JSON: %s", msg) - LOG.exception(j_exc) - data = {} - - self._watchdog.trigger() - - return messages - - async def _async_send_json(self, payload: dict[str, Any]) -> None: - """Send a JSON message to the websocket server. - Raises NotConnectedError if client is not connected. - """ - if not self.connected: - raise NotConnectedError - - assert self._client - - if self._api.log_traces: - LOG.debug( - "[TRACE] Sending data to websocket %s : %s", - self._api.endpoint, - json.dumps(payload), - ) - # Hilo added a control character (chr(30)) at the end of each payload they send. - # They also expect this char to be there at the end of every payload we send them. - LOG.debug("WebsocketClient _async_send_json payload: %s", payload) - await self._client.send_str(json.dumps(payload) + chr(30)) - - def _parse_message(self, msg: dict[str, Any]) -> None: - """Parse an incoming message.""" - if self._api.log_traces: - LOG.debug( - "[TRACE] Received message on websocket(_parse_message) %s: %s", - self._api.endpoint, - msg, - ) - if msg.get("type") == SignalRMsgType.PING: - schedule_callback(self._async_pong) - return - if isinstance(msg, dict) and not len(msg): - self._ready = True - self._ready_event.set() - LOG.info("Websocket: Ready for data") - return - event = websocket_event_from_payload(msg) - for callback in self._event_callbacks: - schedule_callback(callback, event) - - def add_connect_callback(self, callback: Callable[..., Any]) -> Callable[..., None]: - """Add a callback callback to be called after connecting. - :param callback: The method to call after connecting - :type callback: ``Callable[..., None]`` - """ - return self._add_callback(self._connect_callbacks, callback) - - def add_disconnect_callback( - self, callback: Callable[..., Any] - ) -> Callable[..., None]: - """Add a callback callback to be called after disconnecting. - :param callback: The method to call after disconnecting - :type callback: ``Callable[..., None]`` - """ - return self._add_callback(self._disconnect_callbacks, callback) - - def add_event_callback(self, callback: Callable[..., Any]) -> Callable[..., None]: - """Add a callback to be called upon receiving an event. - Note that callbacks should expect to receive a WebsocketEvent object as a - parameter. - :param callback: The method to call after receiving an event. - :type callback: ``Callable[..., None]`` - """ - return self._add_callback(self._event_callbacks, callback) - - async def async_connect(self) -> None: - """Connect to the websocket server.""" - if self.connected: - LOG.debug("Websocket: async_connect() called but already connected") - return - - if self._api.session.closed: - LOG.error("Websocket: Cannot connect, session is closed") - raise CannotConnectError("Session is closed") - - LOG.info("Websocket: Connecting to server %s", self._api.endpoint) - if self._api.log_traces: - LOG.debug("[TRACE] Websocket URL: %s", self._api.full_ws_url) - headers = { - "Sec-WebSocket-Extensions": "permessage-deflate; client_max_window_bits", - "Pragma": "no-cache", - "Cache-Control": "no-cache", - "User-Agent": DEFAULT_USER_AGENT, - "Origin": "http://localhost", - "Accept-Language": "en-US,en;q=0.9", - } - # NOTE(dvd): for troubleshooting purpose we can pass a mitmproxy as env variable - proxy_env: dict[str, Any] = {} - if proxy := environ.get("WS_PROXY"): - proxy_env["proxy"] = proxy - proxy_env["verify_ssl"] = False - - try: - self._client = await self._api.session.ws_connect( - URL( - self._api.full_ws_url, - encoded=True, - ), - heartbeat=55, - headers=headers, - **proxy_env, - ) - except (ClientError, ServerDisconnectedError, WSServerHandshakeError) as err: - LOG.error("Unable to connect to WS server %s", err) - if hasattr(err, "status") and err.status in (401, 403, 404, 409): - raise InvalidCredentialsError("Invalid credentials") from err - except Exception as err: - LOG.error("Unable to connect to WS server %s", err) - raise CannotConnectError(err) from err - - LOG.info("Connected to websocket server %s", self._api.endpoint) - - # Quick pause to prevent race condition - await asyncio.sleep(0.05) - - self._watchdog.trigger() - for callback in self._connect_callbacks: - schedule_callback(callback) - - async def _clean_queue(self) -> None: - """Removes queued tasks.""" - for task in self._queued_tasks: - task.cancel() - - async def async_disconnect(self) -> None: - """Disconnect from the websocket server.""" - await self._clean_queue() - if not self.connected: - return - - assert self._client - - await self._client.close() - - LOG.info("Disconnected from websocket server") - - async def async_listen(self) -> None: - """Start listening to the websocket server.""" - assert self._client - LOG.info("Websocket: Listen started.") - try: - while not self._client.closed: - messages = await self._async_receive_json() - for msg in messages: - self._parse_message(msg) - except asyncio.CancelledError: - LOG.info("Websocket: Listen cancelled.") - raise - except ConnectionClosedError as err: - LOG.error("Websocket: Closed while listening: %s", err) - LOG.exception(err) - pass - except InvalidMessageError as err: - LOG.warning("Websocket: Received invalid json : %s", err) - pass - finally: - LOG.info("Websocket: Listen completed; cleaning up") - self._watchdog.cancel() - await self._clean_queue() - - for callback in self._disconnect_callbacks: - schedule_callback(callback) - - async def async_reconnect(self) -> None: - """Reconnect (and re-listen, if appropriate) to the websocket.""" - LOG.warning("Websocket: Reconnecting") - await self.async_disconnect() - await asyncio.sleep(5) - await self.async_connect() - - async def send_status(self) -> None: - LOG.debug("Sending status") - self._ready = False - await self._async_send_json({"protocol": "json", "version": 1}) - - async def _async_pong(self) -> None: - await self._async_send_json({"type": SignalRMsgType.PING}) - - async def async_invoke( - self, arg: list, target: str, inv_id: int, inv_type: WSMsgType = WSMsgType.TEXT - ) -> None: - """ - Sends an invocation message over the WebSocket connection. - - Waits for the WebSocket to be ready if it is not already, then sends a message - containing the target method, arguments, and invocation ID. - - Args: - arg (list): The list of arguments to send with the invocation. - target (str): The name of the method or action being invoked on the server. - inv_id (int): A unique identifier for this invocation message. - inv_type (WSMsgType, optional): The WebSocket message type. Defaults to WSMsgType.TEXT. - - Returns: - None - - Notes: - If the WebSocket is not ready within 10 seconds, the invocation is skipped. - """ - if not self._ready: - LOG.warning( - f"Delaying invoke {target} {inv_id} {arg}: Websocket not ready." - ) - try: - await asyncio.wait_for(self._ready_event.wait(), timeout=10) - except asyncio.TimeoutError: - return - self._ready_event.clear() - LOG.debug( - "async_invoke invoke argument: %s, invocationId: %s, target: %s, type: %s", - arg, - inv_id, - target, - type, - ) - await self._async_send_json( - { - "arguments": arg, - "invocationId": str(inv_id), - "target": target, - "type": inv_type, - } - ) - - -@dataclass -class WebsocketConfig: - """Configuration for a websocket connection""" - - endpoint: str - url: Optional[str] = None - token: Optional[str] = None - connection_id: Optional[str] = None - full_ws_url: Optional[str] = None - log_traces: bool = True - session: ClientSession | None = None - - -class WebsocketManager: - """Manages multiple websocket connections for the Hilo API""" - - def __init__( - self, - session: ClientSession, - async_request: Callable[..., Any], - state_yaml: str, - set_state_callback: Callable[..., Any], - ) -> None: - """Initialize the websocket manager. - - Args: - session: The aiohttp client session - async_request: The async request method from the API class - state_yaml: Path to the state file - set_state_callback: Callback to save state - """ - self.session = session - self.async_request = async_request - self._state_yaml = state_yaml - self._set_state = set_state_callback - self._shared_token: Optional[str] = None - # Initialize websocket configurations, more can be added here - self.devicehub = WebsocketConfig( - endpoint=AUTOMATION_DEVICEHUB_ENDPOINT, session=session - ) - self.challengehub = WebsocketConfig( - endpoint=AUTOMATION_CHALLENGE_ENDPOINT, session=session - ) - - async def initialize_websockets(self) -> None: - """Initialize both websocket connections""" - await self.refresh_token(self.devicehub, get_new_token=True) - await self.refresh_token(self.challengehub, get_new_token=True) - - async def refresh_token( - self, config: WebsocketConfig, get_new_token: bool = True - ) -> None: - """Refresh token for a specific websocket configuration. - Args: - config: The websocket configuration to refresh - """ - if get_new_token: - config.url, self._shared_token = await self._negotiate(config) - config.token = self._shared_token - else: - config.url, _ = await self._negotiate(config) - config.token = self._shared_token - - await self._get_websocket_params(config) - - async def _negotiate(self, config: WebsocketConfig) -> Tuple[str, str]: - """Negotiate websocket connection and get URL and token. - Args: - config: The websocket configuration to negotiate - Returns: - Tuple containing the websocket URL and access token - """ - LOG.debug("Getting websocket url for %s", config.endpoint) - url = f"{config.endpoint}/negotiate" - LOG.debug("Negotiate URL is %s", url) - - resp = await self.async_request("post", url) - ws_url = resp.get("url") - ws_token = resp.get("accessToken") - - # Save state - state_key = ( - "websocketDevices" - if config.endpoint == AUTOMATION_DEVICEHUB_ENDPOINT - else "websocketChallenges" - ) - await self._set_state( - self._state_yaml, - state_key, - { - "url": ws_url, - "token": ws_token, - }, - ) - - return ws_url, ws_token - - async def _get_websocket_params(self, config: WebsocketConfig) -> None: - """Get websocket parameters including connection ID. - - Args: - config: The websocket configuration to get parameters for - """ - uri = parse.urlparse(config.url) - LOG.debug("Getting websocket params for %s", config.endpoint) - LOG.debug("Getting uri %s", uri) - - resp = await self.async_request( - "post", - f"{uri.path}negotiate?{uri.query}", # type: ignore - host=uri.netloc, - headers={ - "authorization": f"Bearer {config.token}", - }, - ) - - config.connection_id = resp.get("connectionId", "") - config.full_ws_url = ( - f"{config.url}&id={config.connection_id}&access_token={config.token}" - ) - LOG.debug("Getting full ws URL %s", config.full_ws_url) - - transport_dict = resp.get("availableTransports", []) - websocket_dict = { - "connection_id": config.connection_id, - "available_transports": transport_dict, - "full_url": config.full_ws_url, - } - - # Save state - state_key = ( - "websocketDevices" - if config.endpoint == AUTOMATION_DEVICEHUB_ENDPOINT - else "websocketChallenges" - ) - LOG.debug("Calling set_state %s_params", state_key) - await self._set_state(self._state_yaml, state_key, websocket_dict) diff --git a/pyproject.toml b/pyproject.toml index 8fdbc04..bbe4893 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ classifiers = [ [tool.poetry.dependencies] aiohttp = ">=3.8.0" aiofiles = ">=23.2.1" +pysignalr = ">=1.3.0" aiosignal = ">=1.2.0" async-timeout = ">=4.0.0" attrs = ">=21.2.0" diff --git a/requirements.txt b/requirements.txt index 20593bb..0377cb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiofiles >=23.2.1 aiohttp>=3.8.0 +pysignalr>=1.3.0 aiosignal>=1.1.0 async-timeout>=4.0.0 attrs>=23.1.0