"""MQTT related stuffs"""
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from asyncio import Future, get_running_loop
from collections.abc import Callable
from ssl import SSLContext
from typing import Any, cast

import paho.mqtt.client as mqtt

from .cloud_api import DeyeApiResponseFogPlatformDeviceProperties, DeyeCloudApi
from .const import QUERY_DEVICE_STATE_COMMAND_CLASSIC
from .device_command import DeyeDeviceCommand
from .device_state import DeyeDeviceState
[docs]
class BaseDeyeMqttClient(ABC):
    """Base class for MQTT clients connected to Deye MQTT servers.

    The paho client runs its network loop on a background thread
    (``loop_start``), so the ``_mqtt_on_*`` callbacks execute off the asyncio
    event loop. Work that belongs on the loop (subscriber callbacks, refreshing
    the connection info) is handed over with the asyncio ``*_threadsafe``
    helpers.
    """

    # Populated by ``_set_mqtt_info`` before ``connect`` uses them.
    _mqtt_host: str
    _mqtt_ssl_port: int

    def __init__(
        self,
        cloud_api: DeyeCloudApi,
        tls_context: SSLContext | None = None,
    ) -> None:
        """Create the client; must be called while an event loop is running.

        :param cloud_api: Cloud API used to obtain the MQTT connection info.
        :param tls_context: Custom TLS context; when omitted, paho's default
            ``tls_set()`` configuration is used.
        """
        self._loop = get_running_loop()
        self._cloud_api = cloud_api
        self._mqtt = mqtt.Client()
        if tls_context is not None:
            self._mqtt.tls_set_context(tls_context)
        else:
            self._mqtt.tls_set()
        self._mqtt.on_connect = self._mqtt_on_connect
        self._mqtt.on_message = self._mqtt_on_message
        self._mqtt.on_disconnect = self._mqtt_on_disconnect
        # topic -> callbacks receiving the processed payload of that topic
        self._subscribers: dict[str, set[Callable[[Any], None]]] = {}
        # (topic, payload) pairs published while disconnected; flushed on connect
        self._pending_commands: list[tuple[str, bytes]] = []

    @abstractmethod
    async def _set_mqtt_info(self) -> None:
        """Get the MQTT info from the cloud API."""
        raise NotImplementedError

    async def connect(self) -> None:
        """Connect the MQTT client to the server.

        Fetches the connection info, then starts a non-blocking connect and
        paho's background network loop. ``_mqtt_on_connect`` runs once the
        broker answers.
        """
        await self._set_mqtt_info()
        self._mqtt.connect_async(self._mqtt_host, self._mqtt_ssl_port)
        self._mqtt.loop_start()

    def disconnect(self) -> None:
        """Disconnect the MQTT client from the server and stop its network loop."""
        self._mqtt.disconnect()
        self._mqtt.loop_stop()

    def _mqtt_on_connect(
        self,
        _mqtt: mqtt.Client | None = None,
        _userdata: Any = None,
        _flags: Any = None,
        result_code: int = 0,
        *_args: Any,
    ) -> None:
        """paho ``on_connect`` callback (runs on paho's network thread).

        On a successful connection, (re)subscribes every topic that still has
        callbacks and flushes the commands queued while offline.

        :param result_code: CONNACK result; 0 means the broker accepted us.
        """
        if result_code != 0:
            # Connection refused: publishing now would drop the queued
            # commands, so keep them for the next successful connection.
            return
        # Snapshot: the event-loop thread may add topics concurrently.
        for topic, callbacks in list(self._subscribers.items()):
            if callbacks:
                self._mqtt.subscribe(topic)
        # Pop one at a time rather than iterate-then-clear, so a command
        # appended from the event-loop thread during the flush is sent
        # instead of being cleared away unsent.
        while self._pending_commands:
            topic, command = self._pending_commands.pop(0)
            self._mqtt.publish(topic, command)

    def _mqtt_on_disconnect(
        self,
        _mqtt: mqtt.Client,
        _userdata: None,
        result_code: int,
        *_args: Any,
    ) -> None:
        """paho ``on_disconnect`` callback (runs on paho's network thread).

        For unexpected disconnects, refreshes the connection info on the
        event loop and waits for it, so paho's automatic reconnect uses the
        new credentials.
        """
        if result_code == 0:  # User initiated disconnect
            return
        # NOTE(review): paho reconnects to the host/port given to
        # connect_async(); a changed _mqtt_host/_mqtt_ssl_port is not applied
        # to that reconnect — confirm whether the host can actually change.
        try:
            asyncio.run_coroutine_threadsafe(
                self._set_mqtt_info(), self._loop
            ).result()
        except Exception:  # pylint: disable=broad-except
            # With paho's default settings an exception escaping this
            # callback propagates out of its network loop and ends the
            # background thread, and with it every automatic reconnect.
            # Log and let paho retry with the previous info instead.
            logging.getLogger(__name__).exception(
                "Failed to refresh MQTT connection info"
            )

    @abstractmethod
    def _process_message_payload(self, msg: mqtt.MQTTMessage) -> Any:
        """Process the message payload."""
        raise NotImplementedError

    def _mqtt_on_message(
        self, _mqtt: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
    ) -> None:
        """paho ``on_message`` callback: dispatch the payload to subscribers.

        The payload is processed once and the same object is passed to every
        callback of the topic, scheduled on the event loop. Malformed payloads
        are dropped.
        """
        callbacks = self._subscribers.get(msg.topic)
        if not callbacks:
            return
        try:
            payload = self._process_message_payload(msg)
        except (ValueError, KeyError, TypeError):
            # ValueError covers json.JSONDecodeError and invalid UTF-8;
            # KeyError/TypeError cover JSON of an unexpected shape.
            return
        for callback in callbacks.copy():
            self._loop.call_soon_threadsafe(callback, payload)

    def _subscribe_topic(
        self,
        topic: str,
        callback: Callable[[Any], None],
    ) -> Callable[[], None]:
        """Register ``callback`` for processed payloads published on ``topic``.

        The broker subscription is made for the first callback of a topic
        (immediately if connected, otherwise by ``_mqtt_on_connect``).

        :returns: A function that removes the callback; it is safe to call
            more than once.
        """
        callbacks = self._subscribers.setdefault(topic, set())
        is_first_callback = not callbacks
        callbacks.add(callback)
        if self._mqtt.is_connected() and is_first_callback:
            self._mqtt.subscribe(topic)

        def unsubscribe() -> None:
            topic_callbacks = self._subscribers[topic]
            if callback not in topic_callbacks:
                # Already removed: repeated calls are a no-op instead of a
                # KeyError or a duplicate UNSUBSCRIBE.
                return
            topic_callbacks.remove(callback)
            if self._mqtt.is_connected() and not topic_callbacks:
                self._mqtt.unsubscribe(topic)

        return unsubscribe

    @abstractmethod
    def subscribe_state_change(
        self,
        product_id: str,
        device_id: str,
        callback: Callable[[DeyeDeviceState], None],
    ) -> Callable[[], None]:
        """Subscribe to state changes of specified device."""
        raise NotImplementedError

    @abstractmethod
    def subscribe_availability_change(
        self,
        product_id: str,
        device_id: str,
        callback: Callable[[bool], None],
    ) -> Callable[[], None]:
        """Subscribe to availability changes of specified device."""
        raise NotImplementedError

    @abstractmethod
    async def publish_command(
        self,
        product_id: str,
        device_id: str,
        command: DeyeDeviceCommand,
        properties: dict[str, int] | None = None,
    ) -> None:
        """Publish commands to a device"""
        raise NotImplementedError

    @abstractmethod
    async def query_device_state(
        self, product_id: str, device_id: str
    ) -> DeyeDeviceState:
        """Query the latest device state."""
        raise NotImplementedError
[docs]
class DeyeClassicMqttClient(BaseDeyeMqttClient):
    """MQTT client for the Classic platform."""

    def _get_topic_prefix(self, product_id: str, device_id: str) -> str:
        """Return the ``<endpoint>/<product_id>/<device_id>`` topic prefix.

        ``self._endpoint`` is only set by ``_set_mqtt_info``.
        """
        return f"{self._endpoint}/{product_id}/{device_id}"

    async def _set_mqtt_info(self) -> None:
        """Fetch broker address, credentials and topic endpoint from the cloud API."""
        mqtt_info = await self._cloud_api.get_deye_platform_mqtt_info()
        self._mqtt_host = mqtt_info["mqtthost"]
        self._mqtt_ssl_port = mqtt_info["sslport"]
        self._mqtt.username_pw_set(mqtt_info["loginname"], mqtt_info["password"])
        self._endpoint = mqtt_info["endpoint"]

    def _process_message_payload(self, msg: mqtt.MQTTMessage) -> Any:
        """Process the message payload for Classic platform.

        Returns the ``data`` member of the JSON message body.
        """
        return json.loads(msg.payload)["data"]

    def subscribe_state_change(
        self,
        product_id: str,
        device_id: str,
        callback: Callable[[DeyeDeviceState], None],
    ) -> Callable[[], None]:
        """Subscribe to state changes of specified device.

        Listens on ``<prefix>/status/hex`` and wraps each payload in a
        ``DeyeDeviceState``.
        """
        return self._subscribe_topic(
            f"{self._get_topic_prefix(product_id, device_id)}/status/hex",
            lambda payload: callback(DeyeDeviceState(payload)),
        )

    def subscribe_availability_change(
        self,
        product_id: str,
        device_id: str,
        callback: Callable[[bool], None],
    ) -> Callable[[], None]:
        """Subscribe to availability changes of specified device.

        Listens on ``<prefix>/online/json`` and forwards the ``online`` field.
        """
        return self._subscribe_topic(
            f"{self._get_topic_prefix(product_id, device_id)}/online/json",
            lambda payload: callback(payload["online"]),
        )

    async def publish_command(
        self,
        product_id: str,
        device_id: str,
        command: DeyeDeviceCommand | bytes,
        properties: dict[str, int] | None = None,
    ) -> None:
        """Publish commands to a device.

        The command goes to ``<prefix>/command/hex``. It is published
        immediately when connected; otherwise it is queued and sent by the
        base class once the connection is established. ``properties`` is
        unused on this platform.
        """
        topic = f"{self._get_topic_prefix(product_id, device_id)}/command/hex"
        command_bytes = (
            command.to_bytes() if isinstance(command, DeyeDeviceCommand) else command
        )
        if self._mqtt.is_connected():
            self._mqtt.publish(topic, command_bytes)
        else:
            self._pending_commands.append((topic, command_bytes))

    async def query_device_state(
        self, product_id: str, device_id: str
    ) -> DeyeDeviceState:
        """Query the latest device state.

        Subscribes to the device's state topic, publishes the query command
        and returns the first state that arrives. The temporary subscription
        is removed however the call ends: success, error or cancellation.
        """
        future: Future[DeyeDeviceState] = self._loop.create_future()

        def on_message(state: DeyeDeviceState) -> None:
            # Later reports that arrive before this coroutine resumes are ignored.
            if not future.done():
                future.set_result(state)

        unsubscribe = self.subscribe_state_change(product_id, device_id, on_message)
        try:
            await self.publish_command(
                product_id, device_id, QUERY_DEVICE_STATE_COMMAND_CLASSIC
            )
            return await future
        finally:
            # Without this, a failed publish or a cancelled/timed-out await
            # would leave on_message registered forever.
            unsubscribe()
[docs]
class DeyeFogMqttClient(BaseDeyeMqttClient):
    """MQTT client for the Fog platform.

    Every event for the account arrives on one shared topic, and the
    subscription helpers filter it by device id and message kind. Commands
    and state queries use the cloud API instead of MQTT.
    """

    async def _set_mqtt_info(self) -> None:
        """Fetch broker address, credentials and the account topic from the cloud API."""
        info = await self._cloud_api.get_fog_platform_mqtt_info()
        self._mqtt_host = info["mqtt_host"]
        self._mqtt_ssl_port = int(info["ssl_port"])
        username = info["username"]
        self._mqtt.username_pw_set(username, info["password"])
        self._topic = f"fogcloud/app/{username}/sub"

    def _process_message_payload(self, msg: mqtt.MQTTMessage) -> Any:
        """Process the message payload for Fog platform (the whole JSON body)."""
        return json.loads(msg.payload)

    def subscribe_state_change(
        self,
        product_id: str,
        device_id: str,
        callback: Callable[[DeyeDeviceState], None],
    ) -> Callable[[], None]:
        """Subscribe to state changes of specified device.

        Only ``device_data`` / ``thing_property`` messages for ``device_id``
        reach ``callback``.
        """

        def on_payload(payload: Any) -> None:
            if payload["device_id"] != device_id:
                return
            if payload["biz_code"] != "device_data":
                return
            data = payload["data"]
            if data["message_type"] != "thing_property":
                return
            properties = cast(
                DeyeApiResponseFogPlatformDeviceProperties, data["properties"]
            )
            callback(DeyeDeviceState(properties))

        return self._subscribe_topic(self._topic, on_payload)

    def subscribe_availability_change(
        self,
        product_id: str,
        device_id: str,
        callback: Callable[[bool], None],
    ) -> Callable[[], None]:
        """Subscribe to availability changes of specified device.

        ``device_status`` messages for ``device_id`` are reported as
        ``True`` when their status is ``"online"``.
        """

        def on_payload(payload: Any) -> None:
            is_relevant = (
                payload["device_id"] == device_id
                and payload["biz_code"] == "device_status"
            )
            if is_relevant:
                callback(payload["data"]["status"] == "online")

        return self._subscribe_topic(self._topic, on_payload)

    async def publish_command(
        self,
        product_id: str,
        device_id: str,
        command: DeyeDeviceCommand,
        properties: dict[str, int] | None = None,
    ) -> None:
        """
        For Fog platform, commands are not published via MQTT.
        Instead, use the cloud API to send commands.

        Explicit ``properties`` take precedence over ``command.to_json()``.
        """
        if properties is None:
            properties = command.to_json()
        await self._cloud_api.set_fog_platform_device_properties(
            device_id, properties
        )

    async def query_device_state(
        self, product_id: str, device_id: str
    ) -> DeyeDeviceState:
        """Query the latest device state via the cloud API."""
        return DeyeDeviceState(
            await self._cloud_api.get_fog_platform_device_properties(device_id)
        )