# Source code for pydoover.docker.device_agent.device_agent

import asyncio
import copy
import logging
import re
import sys
import json

from collections.abc import Callable
from datetime import datetime, timezone
from typing import Any

from ...utils.snowflake import generate_snowflake_id_at

import grpc

from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct
from google.protobuf.json_format import MessageToDict


from ...models.generated.device_agent import device_agent_pb2, device_agent_pb2_grpc
from ...models.data import (
    Aggregate,
    AggregateUpdateEvent,
    ChannelSyncEvent,
    EventSubscription,
    File,
    Message,
    MessageCreateEvent,
    MessageUpdateEvent,
    OneShotMessage,
    TurnCredential,
    Attachment,
)
from ..grpc_interface import GRPCInterface
from ...models.data.exceptions import DooverAPIError, NotFoundError
from ...cli.decorators import command as cli_command

log = logging.getLogger(__name__)

_VALID_KEY_RE = re.compile(r"^[a-zA-Z0-9_-]+$")
_SCALAR_TYPES = (bool, int, float, str, type(None))


def validate_payload(data, _path=""):
    """Validate that a payload is compatible with doover channel data.

    The top level must be a dict. Keys must be non-empty strings containing
    only alphanumeric characters, hyphens, and underscores. Values may be
    dicts, lists, strings, numbers, booleans, or None; dicts and lists are
    validated recursively.

    Parameters
    ----------
    data
        The payload (or, on recursive calls, a nested value) to validate.
    _path : str
        Internal: dotted/indexed path of ``data`` within the top-level
        payload, used to build error messages. Callers should not pass it.

    Raises
    ------
    ValueError
        With a clear path to the offending key/value.
    """
    if not _path and not isinstance(data, dict):
        raise ValueError(f"Payload must be a dict, got {type(data).__name__}")

    if isinstance(data, dict):
        for key, value in data.items():
            if not isinstance(key, str):
                raise ValueError(
                    f"Keys must be strings, "
                    f"got {type(key).__name__} ({key!r}) at '{_path or 'root'}'"
                )
            # fullmatch, not match: with match() the pattern's trailing ``$``
            # also matches just before a final newline, so a key such as
            # "abc\n" would slip through.
            if not _VALID_KEY_RE.fullmatch(key):
                raise ValueError(
                    f"Key '{key}' at '{_path or 'root'}' contains invalid characters — "
                    f"only a-z, A-Z, 0-9, hyphens and underscores are allowed"
                )
            key_path = f"{_path}.{key}" if _path else key
            validate_payload(value, key_path)
    elif isinstance(data, list):
        for i, item in enumerate(data):
            validate_payload(item, f"{_path}[{i}]")
    elif not isinstance(data, _SCALAR_TYPES):
        raise ValueError(
            f"Unsupported type {type(data).__name__} at '{_path}' — "
            f"allowed types: dict, list, str, int, float, bool, None"
        )


[docs] class DeviceAgentInterface(GRPCInterface): """Interface for interacting with the Device Agent gRPC service. Attributes ---------- dda_timeout : int Timeout for requests to the Device Agent service. max_connection_attempts : int Maximum number of attempts to connect to the Device Agent service. time_between_connection_attempts : int Time to wait between connection attempts to the Device Agent service. is_dda_available : bool Whether the Device Agent service is available. This is set to True once a successful request has been made to the service. is_dda_online: bool Whether the Device Agent service is currently online. has_dda_been_online: bool Whether the Device Agent service has been online at least once since the interface was created. last_channel_message_ts : dict A dictionary that stores the last time a message was received from each channel. """ stub = device_agent_pb2_grpc.deviceAgentStub def __init__( self, app_key: str, dda_uri: str = "127.0.0.1:50051", dda_timeout: int = 7, max_conn_attempts: int = 5, time_between_connection_attempts: int = 10, service_name: str = "doover.DeviceAgent", ): super().__init__(app_key, dda_uri, service_name, dda_timeout) self.dda_timeout = dda_timeout self.max_connection_attempts = max_conn_attempts self.time_between_connection_attempts = time_between_connection_attempts self.is_dda_available = False self.is_dda_online = False self.has_dda_been_online = False self.agent_id = None # Single event stream per channel, distributing to all registered callbacks self._event_callbacks: dict[str, list[tuple[Callable, EventSubscription]]] = {} self._stream_tasks: dict[str, asyncio.Task] = {} # Aggregate state tracking self._synced_channels: dict[str, bool] = {} self._aggregates: dict[str, Aggregate] = {} self.last_channel_message_ts: dict[str, datetime] = {}
[docs] @staticmethod def has_persistent_connection(): """For the Device Agent, this always returns `True`. This method exists to provide interoperability with the API client.""" return True
@cli_command() def get_is_dda_available(self): return self.is_dda_available @cli_command() def get_is_dda_online(self): return self.is_dda_online @cli_command() def get_has_dda_been_online(self): return self.has_dda_been_online async def wait_until_healthy(self, timeout: float = 10): start_time = datetime.now(tz=timezone.utc) backoff = 1 while True: try: healthy = await self.health_check() except Exception as e: log.error(f"Failed to get DDA comms: {e}") healthy = False if healthy: log.info("DDA is available.") return True if (datetime.now(tz=timezone.utc) - start_time).seconds > timeout: log.warning( f"Timed out waiting {timeout} seconds for DDA to become available" ) return False log.info(f"DDA is not available. Retrying in {backoff} seconds...") await asyncio.sleep(backoff) backoff = min(backoff * 2, 1) def _ensure_stream(self, channel_name: str) -> None: """Ensure a single event stream is running for this channel.""" if channel_name not in self._stream_tasks: self._stream_tasks[channel_name] = asyncio.create_task( self._run_channel_stream(channel_name) )
[docs] def add_event_callback( self, channel_name: str, callback: Callable, events: EventSubscription = EventSubscription.all, ) -> None: """Register a callback for events on a channel. The callback receives a single event argument, one of ``MessageCreateEvent``, ``MessageUpdateEvent``, ``AggregateUpdateEvent``, or ``OneShotMessage``, filtered by the ``events`` parameter. The channel name is accessible via the event payload itself. Starts the event stream for the channel if not already running. Parameters ---------- channel_name : str Name of channel to subscribe to. callback : Callable An async callback ``(event) -> None``. events : EventSubscription, optional Which event types to deliver. Defaults to ``EventSubscription.all``. """ entry = (callback, events) try: self._event_callbacks[channel_name].append(entry) except KeyError: self._event_callbacks[channel_name] = [entry] self._ensure_stream(channel_name)
@staticmethod def _event_type_to_flag(event) -> EventSubscription | None: if isinstance(event, OneShotMessage): return EventSubscription.oneshot_message elif isinstance(event, MessageCreateEvent): return EventSubscription.message_create elif isinstance(event, MessageUpdateEvent): return EventSubscription.message_update elif isinstance(event, AggregateUpdateEvent): return EventSubscription.aggregate_update elif isinstance(event, ChannelSyncEvent): return EventSubscription.channel_sync return None async def _run_channel_stream(self, channel_name: str): """Single event stream per channel. Seeds aggregate cache, then distributes events.""" await self.wait_until_healthy() # Seed the aggregate cache, then fire ChannelSyncEvent so subscribers get the initial state. agg = None try: agg = await self.fetch_channel_aggregate(channel_name) self._aggregates[channel_name] = agg except NotFoundError: log.info( f"Channel '{channel_name}' not found, creating with empty aggregate" ) try: agg = await self.update_channel_aggregate(channel_name, {}) except Exception as e: log.error(f"Failed to create channel '{channel_name}': {e}") else: self._aggregates[channel_name] = agg except Exception as e: log.error(f"Failed to seed aggregate cache for '{channel_name}': {e}") self._synced_channels[channel_name] = True if agg is not None: sync_event = ChannelSyncEvent(aggregate=agg) for callback, events in self._event_callbacks.get(channel_name, []): if EventSubscription.channel_sync not in events: continue try: asyncio.create_task(callback(sync_event)) except Exception as e: log.error( f"Error dispatching channel sync callback for {channel_name}: {e}", exc_info=e, ) # Wrap the event loop in a retry loop so any uncaught exception from # stream_channel_events or the dispatch body does not silently kill the # subscription task. 
Without this, an unexpected error leaves the task # dead and _ensure_stream has no mechanism to restart it, causing # subscriptions to stop firing until the process restarts. while True: try: async for event in self.stream_channel_events(channel_name): # Update internal aggregate state on AggregateUpdate if isinstance(event, AggregateUpdateEvent): self._aggregates[channel_name] = event.aggregate self._synced_channels[channel_name] = True self.last_channel_message_ts[channel_name] = datetime.now( tz=timezone.utc ) # Determine which flag this event corresponds to event_flag = self._event_type_to_flag(event) # Distribute to matching registered callbacks for callback, events in self._event_callbacks.get(channel_name, []): if event_flag is None or event_flag not in events: continue try: asyncio.create_task(callback(event)) except Exception as e: log.error( f"Error dispatching event callback for {channel_name}: {e}", exc_info=e, ) except asyncio.CancelledError: raise except BaseException as e: log.exception( f"Stream task for {channel_name} crashed, restarting: {e}" ) await asyncio.sleep(1) continue async def stream_channel_events(self, channel_name: str): backoff = 1 while True: try: async with grpc.aio.insecure_channel(self.uri) as channel: pl = device_agent_pb2.ChannelEventSubscriptionRequest( channel_name=channel_name ) channel_stream = device_agent_pb2_grpc.deviceAgentStub( channel ).ChannelEventSubscription(pl) backoff = 1 # reset on successful connection while True: try: response: device_agent_pb2.ChannelEventSubscriptionResponse = await channel_stream.read() log.debug( f"Received event response from subscription request on {channel_name}: {str(response)[:120]}" ) if not response.response_header.success: raise RuntimeError( f"Failed to subscribe to channel {channel_name}: {response.response_header.response_message}" ) match response.event_name: case "MessageCreate": yield MessageCreateEvent.from_dict( MessageToDict(response.data) ) case "MessageUpdate": yield 
MessageUpdateEvent.from_dict( MessageToDict(response.data) ) case "AggregateUpdate": yield AggregateUpdateEvent.from_dict( MessageToDict(response.data) ) case "OneShotMessage": yield OneShotMessage.from_dict( MessageToDict(response.data) ) except StopAsyncIteration: log.debug("Channel event stream ended.") break except Exception as e: log.error( f"Error in channel event stream for {channel_name}: {e}", exc_info=e, ) await asyncio.sleep(backoff) backoff = min(backoff * 2, self.time_between_connection_attempts) def process_response(self, stub_call: str, response, *args, **kwargs): if response is not None: self.update_dda_status(response.response_header) return super().process_response(stub_call, response, *args, **kwargs) def update_dda_status(self, header): if header.success: self.is_dda_available = True else: self.is_dda_available = False if header.cloud_synced: self.is_dda_online = True if not self.has_dda_been_online: log.info("Device Agent is online") self.has_dda_been_online = True else: self.is_dda_online = False
[docs] def is_channel_synced(self, channel_name): """Check if a channel is synced with DDA. During normal operation, this should always return `True` while DDA is active. It is only really useful for timing during the startup process. Parameters ---------- channel_name : str Name of the channel to check. Returns ------- bool True if the channel is synced, False otherwise. """ if channel_name not in self._event_callbacks: return False if channel_name not in self._synced_channels: return False return self._synced_channels[channel_name]
[docs] async def wait_for_channels_sync( self, channel_names: list[str], timeout: int = 5, inter_wait: float = 0.2 ) -> bool: """Wait for all specified channels to be synced with DDA. This is invoked internally at startup to ensure that all channels are ready before proceeding with operations that depend on them. You shouldn't need to use this during normal operation. Parameters ---------- channel_names : list[str] List of channel names to check for sync status. timeout : int Maximum time to wait for all channels to sync, in seconds. inter_wait : float Time to wait between checks, in seconds. Returns ------- bool True if all channels are synced within the timeout, False otherwise. """ start_time = datetime.now(tz=timezone.utc) while not all( [self.is_channel_synced(channel_name) for channel_name in channel_names] ): if (datetime.now(tz=timezone.utc) - start_time).seconds > timeout: return False await asyncio.sleep(inter_wait) return True
[docs] @cli_command() async def fetch_channel_aggregate(self, channel_name: str) -> Aggregate: """Fetch a channel's current aggregate payload. If the channel has been subscribed to via :meth:`add_event_callback`, the cached aggregate is returned. Otherwise, a gRPC call is made to fetch it. Examples -------- >>> aggregate = await self.device_agent.fetch_channel_aggregate("my_channel") >>> print(aggregate.data) Parameters ---------- channel_name : str Name of channel to get aggregate from. Returns ------- Aggregate Aggregate from channel. Raises ------ NotFoundError If the channel does not exist. DooverAPIError If the request fails. """ if channel_name in self._aggregates: return copy.deepcopy(self._aggregates[channel_name]) log.debug(f"Getting channel aggregate for {channel_name}") resp = await self.make_request( "GetAggregate", device_agent_pb2.GetAggregateRequest(channel_name=channel_name), ) return Aggregate.from_proto(resp.aggregate)
@cli_command() async def fetch_turn_token( self, ) -> TurnCredential: resp = await self.make_request( "GetTurnCredential", device_agent_pb2.TurnCredentialRequest( header=device_agent_pb2.RequestHeader(app_id=self.app_key) ), ) return TurnCredential.from_proto(resp.turn_credential) @cli_command() async def fetch_message( self, channel_name: str, message_id: int, ) -> Message: resp = await self.make_request( "GetMessage", device_agent_pb2.GetMessageRequest( channel_name=channel_name, message_id=message_id, ), ) return Message.from_proto(resp.message) @cli_command() async def list_messages( self, channel_name: str, before: int | datetime | None = None, after: int | datetime | None = None, limit: int | None = None, field_names: list[str] | None = None, ) -> list[Message]: kwargs = {} if before is not None: kwargs["before"] = ( before if isinstance(before, int) else generate_snowflake_id_at(before) ) if after is not None: kwargs["after"] = ( after if isinstance(after, int) else generate_snowflake_id_at(after) ) if limit is not None: kwargs["limit"] = limit if field_names is not None: if isinstance(field_names, str): field_names = [f.strip() for f in field_names.split(",")] kwargs["field_names"] = field_names resp = await self.make_request( "GetMessages", device_agent_pb2.GetMessagesRequest( channel_name=channel_name, **kwargs, ), ) return [Message.from_proto(m) for m in resp.messages] @cli_command() async def create_message( self, channel_name: str, data: dict[str, Any], files: list[File] = None, timestamp: datetime = None, ) -> int: validate_payload(data) d = Struct() json_format.ParseDict(data, d) files = files or [] timestamp = (timestamp or datetime.now(tz=timezone.utc)).timestamp() * 1000 req = device_agent_pb2.CreateMessageRequest( header=device_agent_pb2.RequestHeader(app_id=self.app_key), channel_name=channel_name, data=d, files=[file.to_proto() for file in files], timestamp=int(timestamp), ) resp = await self.make_request("CreateMessage", req) return 
resp.message_id @cli_command() async def send_oneshot_message( self, channel_name: str, data: dict[str, Any], timestamp: datetime | None = None ) -> bool: # Oneshot messages are fire-and-forget (e.g. live tags) — best-effort by # nature, with no retry or delivery guarantee. A DDA that's too old to accept # the request rejects it, so swallow transport/rejection errors and return # False rather than crashing the caller's loop. Local payload errors (e.g. a # bad ``data`` dict failing ParseDict) are caller bugs and still raise. d = Struct() json_format.ParseDict(data, d) req = device_agent_pb2.SendOneShotMessageRequest( header=device_agent_pb2.RequestHeader(app_id=self.app_key), channel_name=channel_name, data=d, ) if timestamp is not None: req.timestamp = int(timestamp.timestamp() * 1000) try: await self.make_request("SendOneShotMessage", req) except DooverAPIError as e: log.warning( "Oneshot message to '%s' dropped (device agent rejected or " "unreachable): %s", channel_name, e, ) return False return True @cli_command() async def update_message( self, channel_name: str, message_id: int, data: dict[str, Any], files: list[File] = None, replace_data: bool = False, clear_attachments: bool = False, ) -> Message: validate_payload(data) d = Struct() json_format.ParseDict(data, d) files = files or [] req = device_agent_pb2.UpdateMessageRequest( header=device_agent_pb2.RequestHeader(app_id=self.app_key), channel_name=channel_name, message_id=str(message_id), data=d, files=[file.to_proto() for file in files], clear_attachments=clear_attachments, replace_data=replace_data, ) resp = await self.make_request("UpdateMessage", req) return Message.from_proto(resp.message) @cli_command() async def update_channel_aggregate( self, channel_name: str, data: dict[str, Any], files: list[File] = None, clear_attachments: bool = False, replace_data: bool = False, max_age_secs: float = None, ): validate_payload(data) d = Struct() json_format.ParseDict(data, d) files = files or [] req = 
device_agent_pb2.UpdateAggregateRequest( channel_name=channel_name, data=d, files=[file.to_proto() for file in files], clear_attachments=clear_attachments, replace_data=replace_data, max_age_secs=max_age_secs, ) resp = await self.make_request("UpdateAggregate", req) return Aggregate.from_proto(resp.aggregate) @cli_command() async def fetch_message_attachment(self, attachment: Attachment) -> File: req = device_agent_pb2.FetchAttachmentRequest( attachment=attachment.to_proto(), ) resp = await self.make_request("FetchAttachment", req) return File.from_proto(resp.file) async def close(self): for task in self._stream_tasks.values(): task.cancel() self._stream_tasks.clear() logging.info("Closing device agent interface...")
[docs] @cli_command() async def listen_channel(self, channel_name: str) -> None: """Listen to channel events, printing the output to the console. Parameters ---------- channel_name : str Name of channel to listen to. """ try: async for event in self.stream_channel_events(channel_name): print(json.dumps(obj=event.to_dict())) sys.stdout.flush() except asyncio.CancelledError: await self.close()
class MockDeviceAgentInterface(DeviceAgentInterface):
    """In-memory stand-in for :class:`DeviceAgentInterface`.

    Lets code that depends on the Device Agent Interface be tested without a
    real Device Agent service: aggregates live in a local dict, no event
    stream is opened, and ``make_request`` always raises.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The mock reports itself as available and online from the start.
        self.is_dda_online = True
        self.is_dda_available = True
        self.has_dda_been_online = True

    @staticmethod
    def _blank_aggregate() -> Aggregate:
        """Build a fresh, empty aggregate."""
        return Aggregate(data={}, attachments=[], last_updated=None)

    async def wait_for_channels_sync(
        self, channel_names: list[str], timeout: int = 5, inter_wait: float = 0.2
    ):
        """Mark every named channel as synced immediately and return True.

        Channels without a stored aggregate get an empty one.
        """
        for name in channel_names:
            if name not in self._aggregates:
                self._aggregates[name] = self._blank_aggregate()
            self._synced_channels[name] = True
        return True

    async def _run_channel_stream(self, channel_name: str):
        """Do nothing: the mock has no real event stream to listen to."""
        return

    async def fetch_channel_aggregate(self, channel_name):
        """Return a deep copy of the stored aggregate, or an empty one if absent."""
        stored = self._aggregates.get(channel_name, self._blank_aggregate())
        return copy.deepcopy(stored)

    async def wait_until_healthy(self, timeout: float = 10):
        """Report healthy straight away."""
        return True

    async def make_request(self, *args, **kwargs):
        """Always raise: the mock never talks to a gRPC service."""
        raise NotImplementedError("make_request is not implemented")

    async def update_channel_aggregate(self, channel_name, data, **kwargs):
        """Merge ``data`` into the stored aggregate and return a deep copy of it.

        Extra keyword arguments are accepted and ignored.
        """
        current = self._aggregates.get(channel_name, self._blank_aggregate())
        current.data.update(data)
        self._aggregates[channel_name] = current
        return copy.deepcopy(current)

    async def create_message(self, channel_name, data, **kwargs):
        """Pretend to create a message; always returns message ID 0."""
        return 0