diff --git a/custom_components/versatile_thermostat/underlyings.py b/custom_components/versatile_thermostat/underlyings.py index e558ef6..6a78512 100644 --- a/custom_components/versatile_thermostat/underlyings.py +++ b/custom_components/versatile_thermostat/underlyings.py @@ -220,9 +220,7 @@ class UnderlyingSwitch(UnderlyingEntity): @overrides def startup(self): super().startup() - self._keep_alive.set_async_action( - self.turn_on if self.is_device_active else self.turn_off - ) + self._keep_alive.set_async_action(self._keep_alive_callback) # @overrides this breaks some unit tests TypeError: object MagicMock can't be used in 'await' expression async def set_hvac_mode(self, hvac_mode: HVACMode) -> bool: @@ -247,9 +245,14 @@ class UnderlyingSwitch(UnderlyingEntity): not self.is_inversed and real_state ) + async def _keep_alive_callback(self): + """Keep alive: Turn on if already turned on, turn off if already turned off.""" + await (self.turn_on() if self.is_device_active else self.turn_off()) + # @overrides this breaks some unit tests TypeError: object MagicMock can't be used in 'await' expression async def turn_off(self): """Turn heater toggleable device off.""" + self._keep_alive.cancel() # Cancel early to avoid a turn_on/turn_off race condition _LOGGER.debug("%s - Stopping underlying entity %s", self, self._entity_id) command = SERVICE_TURN_OFF if not self.is_inversed else SERVICE_TURN_ON domain = self._entity_id.split(".")[0] @@ -258,7 +261,7 @@ class UnderlyingSwitch(UnderlyingEntity): try: data = {ATTR_ENTITY_ID: self._entity_id} await self._hass.services.async_call(domain, command, data) - self._keep_alive.set_async_action(self.turn_off) + self._keep_alive.set_async_action(self._keep_alive_callback) except Exception: self._keep_alive.cancel() raise @@ -267,6 +270,7 @@ class UnderlyingSwitch(UnderlyingEntity): async def turn_on(self): """Turn heater toggleable device on.""" + self._keep_alive.cancel() # Cancel early to avoid a turn_on/turn_off race condition _LOGGER.debug("%s - Starting underlying entity %s", self, self._entity_id) command = SERVICE_TURN_ON if not self.is_inversed else SERVICE_TURN_OFF domain = self._entity_id.split(".")[0] @@ -274,7 +278,7 @@ class UnderlyingSwitch(UnderlyingEntity): try: data = {ATTR_ENTITY_ID: self._entity_id} await self._hass.services.async_call(domain, command, data) - self._keep_alive.set_async_action(self.turn_on) + self._keep_alive.set_async_action(self._keep_alive_callback) except Exception: self._keep_alive.cancel() raise diff --git a/tests/test_switch_keep_alive.py b/tests/test_switch_keep_alive.py index b714841..a6c9336 100644 --- a/tests/test_switch_keep_alive.py +++ b/tests/test_switch_keep_alive.py @@ -210,6 +210,7 @@ class TestKeepAlive: common_mocks, [call("switch", SERVICE_TURN_ON, {"entity_id": "switch.mock_switch"})], ) + common_mocks.mock_is_state.return_value = True # Call the keep-alive callback a few times (as if `async_track_time_interval` # had done it) and assert that the callback function is replaced each time. @@ -240,6 +241,7 @@ class TestKeepAlive: common_mocks, [call("switch", SERVICE_TURN_OFF, {"entity_id": "switch.mock_switch"})], ) + common_mocks.mock_is_state.return_value = False # Call the keep-alive callback a few times (as if `async_track_time_interval` # had done it) and assert that the callback function is replaced each time.