diff --git a/gajim/common/modules/caps.py b/gajim/common/modules/caps.py index 007c1d594d1bc3e05a949b2314b608b482d1cea5..6121d811e7c0793fa78b469681c3532a80246d7b 100644 --- a/gajim/common/modules/caps.py +++ b/gajim/common/modules/caps.py @@ -17,16 +17,27 @@ # XEP-0115: Entity Capabilities +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Optional + import weakref from collections import defaultdict +from nbxmpp.errors import StanzaError from nbxmpp.namespaces import Namespace -from nbxmpp.structs import StanzaHandler +from nbxmpp.protocol import JID +from nbxmpp.protocol import Presence from nbxmpp.structs import DiscoIdentity +from nbxmpp.structs import StanzaHandler +from nbxmpp.structs import PresenceProperties +from nbxmpp.task import Task as nbxmpp_Task from nbxmpp.util import compute_caps_hash -from nbxmpp.errors import StanzaError from gajim.common import app +from gajim.common import types from gajim.common.const import COMMON_FEATURES from gajim.common.const import Entity from gajim.common.helpers import get_optional_features @@ -42,7 +53,7 @@ class Caps(BaseModule): 'set_caps' ] - def __init__(self, con): + def __init__(self, con: types.Client) -> None: BaseModule.__init__(self, con) self.handlers = [ @@ -56,10 +67,11 @@ def __init__(self, con): DiscoIdentity(category='client', type='pc', name='Gajim') ] - self._queued_tasks_by_hash = defaultdict(set) - self._queued_tasks_by_jid = {} + self._queued_tasks_by_hash: defaultdict[ + str, set[EntityCapsTask]] = defaultdict(set) + self._queued_tasks_by_jid: dict[JID, EntityCapsTask] = {} - def _queue_task(self, task): + def _queue_task(self, task: EntityCapsTask) -> None: old_task = self._get_task(task.entity.jid) if old_task is not None: self._remove_task(old_task) @@ -69,24 +81,28 @@ def _queue_task(self, task): self._queued_tasks_by_jid[task.entity.jid] = task app.task_manager.add_task(task) - def _get_task(self, jid): + def _get_task(self, jid: JID) -> Optional[EntityCapsTask]: return self._queued_tasks_by_jid.get(jid) - def _get_similar_tasks(self, task): + def _get_similar_tasks(self, task: EntityCapsTask) -> set[EntityCapsTask]: return self._queued_tasks_by_hash.pop(task.entity.hash) - def _remove_task(self, task): + def _remove_task(self, task: EntityCapsTask) -> None: task.set_obsolete() del self._queued_tasks_by_jid[task.entity.jid] self._queued_tasks_by_hash[task.entity.hash].discard(task) - def _remove_all_tasks(self): + def _remove_all_tasks(self) -> None: for task in self._queued_tasks_by_jid.values(): task.set_obsolete() self._queued_tasks_by_jid.clear() self._queued_tasks_by_hash.clear() - def _entity_caps(self, _con, _stanza, properties): + def _entity_caps(self, + _con: types.xmppClient, + _stanza: Presence, + properties: PresenceProperties + ) -> None: if properties.type.is_error or properties.type.is_unavailable: return @@ -113,7 +129,7 @@ def _entity_caps(self, _con, _stanza, properties): contact = self._con.get_module('Contacts').get_contact(properties.jid) contact.notify('caps-update') - def _execute_task(self, task): + def _execute_task(self, task: EntityCapsTask) -> None: self._log.info('Request %s from %s', task.entity.hash, task.entity.jid) self._con.get_module('Discovery').disco_info( task.entity.jid, @@ -121,7 +137,7 @@ def _execute_task(self, task): callback=self._on_disco_info, user_data=task.entity.jid) - def _on_disco_info(self, nbxmpp_task): + def _on_disco_info(self, nbxmpp_task: nbxmpp_Task) -> None: jid = nbxmpp_task.get_user_data() task = self._get_task(jid) if task is None: @@ -162,7 +178,7 @@ def _on_disco_info(self, nbxmpp_task): task.entity.jid) contact.notify('caps-update') - def update_caps(self): + def update_caps(self) -> None: if not app.account_is_connected(self._account): return @@ -178,13 +194,17 @@ def update_caps(self): app.connections[self._account].status, app.connections[self._account].status_message) - def cleanup(self): + def cleanup(self) -> None: self._remove_all_tasks() BaseModule.cleanup(self) class EntityCapsTask(Task): - def __init__(self, account, properties, callback): + def __init__(self, + account: str, + properties: PresenceProperties, + callback: Callable[..., Any] + ) -> None: Task.__init__(self) self._account = account self._callback = weakref.WeakMethod(callback) @@ -196,12 +216,12 @@ def __init__(self, account, properties, callback): self._from_muc = properties.from_muc - def execute(self): + def execute(self) -> None: callback = self._callback() if callback is not None: callback(self) - def preconditions_met(self): + def preconditions_met(self) -> bool: try: client = app.get_client(self._account) except Exception: @@ -218,8 +238,8 @@ def preconditions_met(self): return client.state.is_available - def __repr__(self): + def __repr__(self) -> str: return f'Entity Caps ({self.entity.jid} {self.entity.hash})' - def __hash__(self): + def __hash__(self) -> int: return hash(self.entity)