# Source code for greengrasssdk.stream_manager.streammanagerclient

"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
"""

import asyncio
import logging
import os
from threading import Thread
from typing import List, Optional

import cbor2

from .data import (
    AppendMessageRequest,
    AppendMessageResponse,
    ConnectRequest,
    ConnectResponse,
    CreateMessageStreamRequest,
    CreateMessageStreamResponse,
    DeleteMessageStreamRequest,
    DeleteMessageStreamResponse,
    DescribeMessageStreamRequest,
    DescribeMessageStreamResponse,
    ListStreamsRequest,
    ListStreamsResponse,
    Message,
    MessageFrame,
    MessageStreamDefinition,
    MessageStreamInfo,
    Operation,
    ReadMessagesOptions,
    ReadMessagesRequest,
    ReadMessagesResponse,
    ResponseStatusCode,
    UnknownOperationError,
    UpdateMessageStreamRequest,
    UpdateMessageStreamResponse,
    VersionInfo,
)
from .exceptions import ClientException, ConnectFailedException, StreamManagerException, ValidationException
from .utilinternal import UtilInternal

# Version of the Python SDK. It is reported to the server as ``sdk_version`` in the connect request.
# NOTE: This version is independent of the StreamManager PROTOCOL_VERSION, which versions the data format
#  sent over the wire. When bumping the PROTOCOL_VERSION, consider adding the old version to the
#  __OLD_SUPPORTED_PROTOCOL_VERSIONS list (if you intend to keep supporting it). Bumping this SDK_VERSION
#  requires no other change.
SDK_VERSION = "1.1.1"


[docs]class StreamManagerClient: """ Creates a client for the Greengrass StreamManager. All parameters are optional. :param host: The host which StreamManager server is running on. Default is localhost. :param port: The port which StreamManager server is running on. Default is found in environment variables. :param connect_timeout: The timeout in seconds for connecting to the server. Default is 3 seconds. :param request_timeout: The timeout in seconds for all operations. Default is 60 seconds. :param logger: A logger to use for client logging. Default is Python's builtin logger. :raises: :exc:`~.exceptions.StreamManagerException` and subtypes if authenticating to the server fails. :raises: :exc:`asyncio.TimeoutError` if the request times out. :raises: :exc:`ConnectionError` if the client is unable to connect to the server. """ # List of supported protocol protocol. # These are meant to be used for graceful degradation if the server does not support the current SDK version. __OLD_SUPPORTED_PROTOCOL_VERSIONS = ["1.0.0"] __CONNECT_VERSION = 1 def __init__( self, host="127.0.0.1", port=None, connect_timeout=3, request_timeout=60, logger=logging.getLogger("StreamManagerClient"), ): self.host = host if port is None: port = int(os.getenv("STREAM_MANAGER_SERVER_PORT", 8088)) self.port = port self.__requests = {} self.connect_timeout = connect_timeout self.request_timeout = request_timeout self.logger = logger self.auth_token = os.getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN") # Python Logging doesn't have a TRACE level # so we will add our own at level 5. 
(Debug is level 10) if logger.level <= 5: logging.addLevelName(5, "TRACE") self.__loop = asyncio.new_event_loop() self.__closed = False self.__reader = None self.__writer = None # Defines a function to be run in a separate thread to run the event loop # this enables our synchronous interface without locks def run_event_loop(loop: asyncio.AbstractEventLoop): try: loop.run_forever() finally: loop.close() # Making the thread a daemon will kill the thread once the main thread closes self.__event_loop_thread = Thread(target=run_event_loop, args=(self.__loop,), daemon=True) self.__event_loop_thread.start() self.connected = False UtilInternal.sync(self.__connect(), loop=self.__loop) def __enter__(self): return self def __exit__(self, type, value, traceback): self.close() async def _close(self): if self.__writer is not None: self.__closed = True self.connected = False self.__reader = None # Drain any existing data waiting to be sent await self.__writer.drain() self.__writer.close() try: # Only added in Python 3.7, so try to call it, but otherwise just skip it await self.__writer.wait_closed() except AttributeError: pass self.__writer = None def __check_closed(self): if self.__closed: raise StreamManagerException("Client is closed. Create a new client first.") async def __connect(self): self.__check_closed() if self.connected: return try: self.logger.debug("Opening connection to %s:%d", self.host, self.port) future = asyncio.open_connection(self.host, self.port) self.__reader, self.__writer = await asyncio.wait_for( future, timeout=self.connect_timeout ) await asyncio.wait_for(self.__connect_request_response(), timeout=self.request_timeout) self.logger.debug("Socket connected successfully. 
Starting read loop.") self.connected = True self.__loop.create_task(self.__read_loop()) except ConnectionError as e: self.logger.error("Connection error while connecting to server: %s", e) raise def __log_trace(self, *args, **kwargs): self.logger.log(5, *args, **kwargs) async def __read_message_frame(self): length_bytes = await self.__reader.read(n=4) if len(length_bytes) == 0: raise asyncio.IncompleteReadError(length_bytes, 4) length = UtilInternal.int_from_bytes(length_bytes) operation = UtilInternal.int_from_bytes(await self.__reader.read(n=1)) # Read from the socket until we have read the full packet payload = bytearray() read_bytes = 1 while read_bytes < length: next_payload = await self.__reader.read(n=length - read_bytes) if len(next_payload) == 0: raise asyncio.IncompleteReadError(next_payload, length - read_bytes) payload.extend(next_payload) read_bytes += len(next_payload) try: op = Operation.from_dict(operation) except ValueError: self.logger.error("Found unknown operation %d", operation) op = Operation.Unknown return MessageFrame(operation=op, payload=bytes(payload)) async def __read_loop(self): # Continually try to read packets from the socket while not self.__closed: try: try: self.__log_trace("Starting long poll read") response = await self.__read_message_frame() self.__log_trace("Got message frame from server: %s", response) except asyncio.IncompleteReadError: if self.__closed: return self.logger.error("Unable to read from socket, likely socket is closed or server died") self.connected = False try: await self.__connect() except ConnectionError: # Already logged in __connect, so just ignore it here pass except ConnectFailedException: # Already logged in __connect_request_response, so just ignore it here pass return payload = cbor2.loads(response.payload) await self.__handle_read_response(payload, response) except Exception: self.logger.exception("Unhandled exception occurred") return async def __handle_read_response(self, payload, response): if 
response.operation == Operation.ReadMessagesResponse: response = ReadMessagesResponse.from_dict(payload) self.logger.debug("Received ReadMessagesResponse from server") await self.__requests[response.request_id].put(response) elif response.operation == Operation.CreateMessageStreamResponse: response = CreateMessageStreamResponse.from_dict(payload) self.logger.debug("Received CreateMessageStreamResponse from server: %s", response) await self.__requests[response.request_id].put(response) elif response.operation == Operation.DeleteMessageStreamResponse: response = DeleteMessageStreamResponse.from_dict(payload) self.logger.debug("Received DeleteMessageStreamResponse from server: %s", response) await self.__requests[response.request_id].put(response) elif response.operation == Operation.UpdateMessageStreamResponse: response = UpdateMessageStreamResponse.from_dict(payload) self.logger.debug("Received UpdateMessageStreamResponse from server: %s", response) await self.__requests[response.request_id].put(response) elif response.operation == Operation.AppendMessageResponse: response = AppendMessageResponse.from_dict(payload) self.logger.debug("Received AppendMessageResponse from server: %s", response) await self.__requests[response.request_id].put(response) elif response.operation == Operation.ListStreamsResponse: response = ListStreamsResponse.from_dict(payload) self.logger.debug("Received ListStreamsResponse from server: %s", response) await self.__requests[response.request_id].put(response) elif response.operation == Operation.DescribeMessageStreamResponse: response = DescribeMessageStreamResponse.from_dict(payload) self.logger.debug("Received DescribeMessageStreamResponse from server: %s", response) await self.__requests[response.request_id].put(response) elif response.operation == Operation.UnknownOperationError: self.logger.error( "Received response with unsupported operation from server: %s. 
" "You should update your server version", response.operation, ) response = UnknownOperationError.from_dict(payload) await self.__requests[response.request_id].put(response) elif response.operation == Operation.Unknown: self.logger.error("Received response with unknown operation from server: %s", response) try: request_id = cbor2.loads(response.payload)["requestId"] await self.__requests[request_id].put(response) except Exception: # We tried our best to figure out the request id, but it failed. # We already logged the unknown operation, so there's nothing # else we can do at this point pass else: self.logger.error("Received data with unhandled operation %s.", response.operation) async def __connect_request_response(self): data = ConnectRequest() data.request_id = UtilInternal.get_request_id() data.sdk_version = SDK_VERSION data.other_supported_protocol_versions = self.__OLD_SUPPORTED_PROTOCOL_VERSIONS data.protocol_version = VersionInfo.PROTOCOL_VERSION.value if self.auth_token is not None: data.auth_token = self.auth_token # Write the connect version self.__writer.write(UtilInternal.int_to_bytes(self.__CONNECT_VERSION, 1)) # Write request to socket frame = MessageFrame(operation=Operation.Connect, payload=cbor2.dumps(data.as_dict())) for b in UtilInternal.encode_frame(frame): self.__writer.write(b) await self.__writer.drain() # Read connect version connect_response_version_byte = await self.__reader.read(n=1) if len(connect_response_version_byte) == 0: raise asyncio.IncompleteReadError(connect_response_version_byte, 1) connect_response_version = UtilInternal.int_from_bytes(connect_response_version_byte) if connect_response_version != self.__CONNECT_VERSION: self.logger.error("Unexpected response from the server, Connect version: %s.", connect_response_version) raise ConnectFailedException("Failed to establish connection with the server") # Read connect response response = await self.__read_message_frame() # type: MessageFrame if response.operation == 
Operation.ConnectResponse: payload = cbor2.loads(response.payload) response = ConnectResponse.from_dict(payload) # type: ConnectResponse self.logger.debug("Received ConnectResponse from server: %s", response) else: self.logger.error("Received data with unexpected operation %s.", response.operation) raise ConnectFailedException("Failed to establish connection with the server") if response.status != ResponseStatusCode.Success: self.logger.error("Received ConnectResponse with unexpected status %s.", response.status) raise ConnectFailedException("Failed to establish connection with the server") if data.protocol_version != response.protocol_version: self.logger.warn( "SDK with version %s using Protocol version %s is not fully compatible with Server with version %s. " "Client has connected in a compatibility mode using protocol version %s. " "Some features will not work as expected", SDK_VERSION, data.protocol_version, response.server_version, response.protocol_version, ) async def __send_and_receive(self, operation, data): async def inner(operation, data): if data.request_id is None: data.request_id = UtilInternal.get_request_id() validation = UtilInternal.is_invalid(data) if validation: raise ValidationException(validation) # If we're not connected, immediately try to reconnect if not self.connected: await self.__connect() self.__requests[data.request_id] = asyncio.Queue(1) # Write request to socket frame = MessageFrame(operation=operation, payload=cbor2.dumps(data.as_dict())) for b in UtilInternal.encode_frame(frame): self.__writer.write(b) await self.__writer.drain() # Wait for reader to come back with the response result = await self.__requests[data.request_id].get() # Drop async queue from request map del self.__requests[data.request_id] if isinstance(result, MessageFrame) and result.operation == Operation.Unknown: raise ClientException("Received response with unknown operation from server") return result # Perform the actual work as async so that we can put a 
timeout on the whole operation try: return await asyncio.wait_for(inner(operation, data), timeout=self.request_timeout) except asyncio.TimeoutError: # Drop async queue from request map del self.__requests[data.request_id] raise def __validate_read_message_options(self, options: Optional[ReadMessagesOptions]): if options is not None: if not isinstance(options, ReadMessagesOptions): raise ValidationException("options argument to read_messages must be a ReadMessageOptions object") validation = UtilInternal.is_invalid(options) if validation: raise ValidationException(validation) if ( options.min_message_count is not None and options.max_message_count is not None and options.min_message_count > options.max_message_count ): raise ValidationException("min_message_count must be less than or equal to max_message_count") if options.read_timeout_millis is not None and options.read_timeout_millis > self.request_timeout * 1000: raise ValidationException( "read_timeout_millis must be less than or equal to the client's request_timeout" ) async def _append_message(self, stream_name: str, data: bytes) -> int: append_message_request = AppendMessageRequest(name=stream_name, payload=data) append_message_response = await self.__send_and_receive( Operation.AppendMessage, data=append_message_request ) # type: AppendMessageResponse UtilInternal.raise_on_error_response(append_message_response) return append_message_response.sequence_number async def _create_message_stream(self, definition: MessageStreamDefinition) -> None: if not isinstance(definition, MessageStreamDefinition): raise ValidationException("definition argument to create_stream must be a MessageStreamDefinition object") create_stream_request = CreateMessageStreamRequest(definition=definition) create_stream_response = await self.__send_and_receive( Operation.CreateMessageStream, data=create_stream_request ) # type: CreateMessageStreamResponse UtilInternal.raise_on_error_response(create_stream_response) async def 
_delete_message_stream(self, stream_name: str) -> None: delete_stream_request = DeleteMessageStreamRequest(name=stream_name) delete_stream_response = await self.__send_and_receive( Operation.DeleteMessageStream, data=delete_stream_request ) # type: DeleteMessageStreamResponse UtilInternal.raise_on_error_response(delete_stream_response) async def _update_message_stream(self, definition: MessageStreamDefinition) -> None: if not isinstance(definition, MessageStreamDefinition): raise ValidationException( "definition argument to update_message_stream must be a MessageStreamDefinition object" ) update_stream_request = UpdateMessageStreamRequest(definition=definition) update_stream_response = await self.__send_and_receive( Operation.UpdateMessageStream, data=update_stream_request ) # type: UpdateMessageStreamResponse UtilInternal.raise_on_error_response(update_stream_response) async def _read_messages(self, stream_name: str, options: ReadMessagesOptions = None) -> List[Message]: self.__validate_read_message_options(options) read_messages_request = ReadMessagesRequest(stream_name=stream_name, read_messages_options=options) read_messages_response = await self.__send_and_receive( Operation.ReadMessages, data=read_messages_request ) # type: ReadMessagesResponse UtilInternal.raise_on_error_response(read_messages_response) return read_messages_response.messages async def _list_streams(self) -> List[str]: list_streams_response = await self.__send_and_receive( Operation.ListStreams, data=ListStreamsRequest() ) # type: ListStreamsResponse UtilInternal.raise_on_error_response(list_streams_response) return list_streams_response.streams async def _describe_message_stream(self, stream_name: str) -> MessageStreamInfo: describe_message_stream_response = await self.__send_and_receive( Operation.DescribeMessageStream, data=DescribeMessageStreamRequest(name=stream_name) ) # type: DescribeMessageStreamResponse UtilInternal.raise_on_error_response(describe_message_stream_response) return 
describe_message_stream_response.message_stream_info #################### # PUBLIC API # ####################
[docs] def read_messages(self, stream_name: str, options: Optional[ReadMessagesOptions] = None) -> List[Message]: """ Read message(s) from a chosen stream with options. If no options are specified it will try to read 1 message from the stream. :param stream_name: The name of the stream to read from. :param options: (Optional) Options used when reading from the stream of type :class:`.data.ReadMessagesOptions`. Defaults are: * desired_start_sequence_number: 0, * min_message_count: 1, * max_message_count: 1, * read_timeout_millis: 0 ``# Where 0 here represents that the server will immediately return the messages`` ``# or an exception if there were not enough messages available.`` If desired_start_sequence_number is specified in the options and is less than the current beginning of the stream, returned messages will start at the beginning of the stream and not necessarily the desired_start_sequence_number. :return: List of at least 1 message. :raises: :exc:`~.exceptions.StreamManagerException` and subtypes based on the precise error. :raises: :exc:`asyncio.TimeoutError` if the request times out. :raises: :exc:`ConnectionError` if the client is unable to reconnect to the server. """ self.__check_closed() return UtilInternal.sync(self._read_messages(stream_name, options), loop=self.__loop)
[docs] def append_message(self, stream_name: str, data: bytes) -> int: """ Append a message into the specified message stream. Returns the sequence number of the message if it was successfully appended. :param stream_name: The name of the stream to append to. :param data: Bytes type data. :return: Sequence number that the message was assigned if it was appended. :raises: :exc:`~.exceptions.StreamManagerException` and subtypes based on the precise error. :raises: :exc:`asyncio.TimeoutError` if the request times out. :raises: :exc:`ConnectionError` if the client is unable to reconnect to the server. """ self.__check_closed() return UtilInternal.sync(self._append_message(stream_name, data), loop=self.__loop)
[docs] def create_message_stream(self, definition: MessageStreamDefinition) -> None: """ Create a message stream with a given definition. :param definition: :class:`~.data.MessageStreamDefinition` definition object. :return: Nothing is returned if the request succeeds. :raises: :exc:`~.exceptions.StreamManagerException` and subtypes based on the precise error. :raises: :exc:`asyncio.TimeoutError` if the request times out. :raises: :exc:`ConnectionError` if the client is unable to reconnect to the server. """ self.__check_closed() return UtilInternal.sync(self._create_message_stream(definition), loop=self.__loop)
[docs] def delete_message_stream(self, stream_name: str) -> None: """ Deletes a message stream based on its name. Nothing is returned if the request succeeds, a subtype of :exc:`~.exceptions.StreamManagerException` will be raised if an error occurs. :param stream_name: The name of the stream to be deleted. :return: Nothing is returned if the request succeeds. :raises: :exc:`~.exceptions.StreamManagerException` and subtypes based on the precise error. :raises: :exc:`asyncio.TimeoutError` if the request times out. :raises: :exc:`ConnectionError` if the client is unable to reconnect to the server. """ self.__check_closed() return UtilInternal.sync(self._delete_message_stream(stream_name), loop=self.__loop)
[docs] def update_message_stream(self, definition: MessageStreamDefinition) -> None: """ Updates a message stream based on a given definition. Minimum version requirements: StreamManager server version 1.1 (or AWS IoT Greengrass Core 1.11.0) :param definition: class:`~.data.MessageStreamDefinition` definition object. :return: Nothing is returned if the request succeeds. :raises: :exc:`~.exceptions.StreamManagerException` and subtypes based on the precise error. :raises: :exc:`asyncio.TimeoutError` if the request times out. :raises: :exc:`ConnectionError` if the client is unable to reconnect to the server. """ self.__check_closed() return UtilInternal.sync(self._update_message_stream(definition), loop=self.__loop)
[docs] def list_streams(self) -> List[str]: """ List the streams in StreamManager. Returns a list of their names. :return: List of stream names. :raises: :exc:`~.exceptions.StreamManagerException` and subtypes based on the precise error. :raises: :exc:`asyncio.TimeoutError` if the request times out. :raises: :exc:`ConnectionError` if the client is unable to reconnect to the server. """ self.__check_closed() return UtilInternal.sync(self._list_streams(), loop=self.__loop)
[docs] def describe_message_stream(self, stream_name: str) -> MessageStreamInfo: """ Describe a message stream to get metadata including the stream's definition, size, and exporter statuses. :param stream_name: The name of the stream to describe. :return: :class:`~.data.MessageStreamInfo` type containing the stream information. :raises: :exc:`~.exceptions.StreamManagerException` and subtypes based on the precise error. :raises: :exc:`asyncio.TimeoutError` if the request times out. :raises: :exc:`ConnectionError` if the client is unable to reconnect to the server. """ self.__check_closed() return UtilInternal.sync(self._describe_message_stream(stream_name), loop=self.__loop)
[docs] def close(self): """ Call to shutdown the client and close all existing connections. Once a client is closed it cannot be reused. :raises: :exc:`~.exceptions.StreamManagerException` and subtypes based on the precise error. """ if not self.__closed: UtilInternal.sync(self._close(), loop=self.__loop) if not self.__loop.is_closed(): self.__loop.call_soon_threadsafe(self.__loop.stop)