Skip to content
Snippets Groups Projects
Commit efaf720e authored by Daniel Brötzmann's avatar Daniel Brötzmann
Browse files

chore: Caps: Add type annotations

parent 7f1f80e2
No related branches found
No related tags found
No related merge requests found
......@@ -17,16 +17,27 @@
# XEP-0115: Entity Capabilities
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Optional
import weakref
from collections import defaultdict
from nbxmpp.errors import StanzaError
from nbxmpp.namespaces import Namespace
from nbxmpp.structs import StanzaHandler
from nbxmpp.protocol import JID
from nbxmpp.protocol import Presence
from nbxmpp.structs import DiscoIdentity
from nbxmpp.structs import StanzaHandler
from nbxmpp.structs import PresenceProperties
from nbxmpp.task import Task as nbxmpp_Task
from nbxmpp.util import compute_caps_hash
from nbxmpp.errors import StanzaError
from gajim.common import app
from gajim.common import types
from gajim.common.const import COMMON_FEATURES
from gajim.common.const import Entity
from gajim.common.helpers import get_optional_features
......@@ -42,7 +53,7 @@ class Caps(BaseModule):
'set_caps'
]
def __init__(self, con):
def __init__(self, con: types.Client) -> None:
BaseModule.__init__(self, con)
self.handlers = [
......@@ -56,10 +67,11 @@ def __init__(self, con):
DiscoIdentity(category='client', type='pc', name='Gajim')
]
self._queued_tasks_by_hash = defaultdict(set)
self._queued_tasks_by_jid = {}
self._queued_tasks_by_hash: defaultdict[
str, set[EntityCapsTask]] = defaultdict(set)
self._queued_tasks_by_jid: dict[JID, EntityCapsTask] = {}
def _queue_task(self, task):
def _queue_task(self, task: EntityCapsTask) -> None:
old_task = self._get_task(task.entity.jid)
if old_task is not None:
self._remove_task(old_task)
......@@ -69,24 +81,28 @@ def _queue_task(self, task):
self._queued_tasks_by_jid[task.entity.jid] = task
app.task_manager.add_task(task)
def _get_task(self, jid):
def _get_task(self, jid: JID) -> Optional[EntityCapsTask]:
return self._queued_tasks_by_jid.get(jid)
def _get_similar_tasks(self, task):
def _get_similar_tasks(self, task: EntityCapsTask) -> set[EntityCapsTask]:
return self._queued_tasks_by_hash.pop(task.entity.hash)
def _remove_task(self, task):
def _remove_task(self, task: EntityCapsTask) -> None:
task.set_obsolete()
del self._queued_tasks_by_jid[task.entity.jid]
self._queued_tasks_by_hash[task.entity.hash].discard(task)
def _remove_all_tasks(self):
def _remove_all_tasks(self) -> None:
for task in self._queued_tasks_by_jid.values():
task.set_obsolete()
self._queued_tasks_by_jid.clear()
self._queued_tasks_by_hash.clear()
def _entity_caps(self, _con, _stanza, properties):
def _entity_caps(self,
_con: types.xmppClient,
_stanza: Presence,
properties: PresenceProperties
) -> None:
if properties.type.is_error or properties.type.is_unavailable:
return
......@@ -113,7 +129,7 @@ def _entity_caps(self, _con, _stanza, properties):
contact = self._con.get_module('Contacts').get_contact(properties.jid)
contact.notify('caps-update')
def _execute_task(self, task):
def _execute_task(self, task: EntityCapsTask) -> None:
self._log.info('Request %s from %s', task.entity.hash, task.entity.jid)
self._con.get_module('Discovery').disco_info(
task.entity.jid,
......@@ -121,7 +137,7 @@ def _execute_task(self, task):
callback=self._on_disco_info,
user_data=task.entity.jid)
def _on_disco_info(self, nbxmpp_task):
def _on_disco_info(self, nbxmpp_task: nbxmpp_Task) -> None:
jid = nbxmpp_task.get_user_data()
task = self._get_task(jid)
if task is None:
......@@ -162,7 +178,7 @@ def _on_disco_info(self, nbxmpp_task):
task.entity.jid)
contact.notify('caps-update')
def update_caps(self):
def update_caps(self) -> None:
if not app.account_is_connected(self._account):
return
......@@ -178,13 +194,17 @@ def update_caps(self):
app.connections[self._account].status,
app.connections[self._account].status_message)
def cleanup(self):
def cleanup(self) -> None:
self._remove_all_tasks()
BaseModule.cleanup(self)
class EntityCapsTask(Task):
def __init__(self, account, properties, callback):
def __init__(self,
account: str,
properties: PresenceProperties,
callback: Callable[..., Any]
) -> None:
Task.__init__(self)
self._account = account
self._callback = weakref.WeakMethod(callback)
......@@ -196,12 +216,12 @@ def __init__(self, account, properties, callback):
self._from_muc = properties.from_muc
def execute(self):
def execute(self) -> None:
callback = self._callback()
if callback is not None:
callback(self)
def preconditions_met(self):
def preconditions_met(self) -> bool:
try:
client = app.get_client(self._account)
except Exception:
......@@ -218,8 +238,8 @@ def preconditions_met(self):
return client.state.is_available
def __repr__(self):
def __repr__(self) -> str:
return f'Entity Caps ({self.entity.jid} {self.entity.hash})'
def __hash__(self):
def __hash__(self) -> int:
return hash(self.entity)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment