diff --git a/gajim/common/modules/mam.py b/gajim/common/modules/mam.py index 3bb17e0781ec462e358f20cb1baa8f60d9c6c362..9cf5aa52a4dbddc51f85c496d2bb359bbe88c715 100644 --- a/gajim/common/modules/mam.py +++ b/gajim/common/modules/mam.py @@ -14,26 +14,38 @@ # XEP-0313: Message Archive Management +from __future__ import annotations + +from typing import Any +from typing import Optional + import time from datetime import datetime from datetime import timedelta import nbxmpp -from nbxmpp.namespaces import Namespace -from nbxmpp.util import generate_id from nbxmpp.errors import StanzaError from nbxmpp.errors import MalformedStanzaError from nbxmpp.errors import is_error +from nbxmpp.namespaces import Namespace +from nbxmpp.protocol import JID +from nbxmpp.protocol import Message +from nbxmpp.structs import DiscoInfo +from nbxmpp.structs import MessageProperties from nbxmpp.structs import StanzaHandler +from nbxmpp.structs import StanzaIDData +from nbxmpp.task import Task +from nbxmpp.util import generate_id from nbxmpp.modules.util import raise_if_error from gajim.common import app +from gajim.common import types from gajim.common.events import ArchivingIntervalFinished from gajim.common.events import FeatureDiscovered from gajim.common.events import MamMessageReceived from gajim.common.events import MessageUpdated from gajim.common.events import RawMamMessageReceived -from gajim.common.const import ArchiveState +from gajim.common.const import ArchiveState, ClientState from gajim.common.const import KindConstant from gajim.common.const import SyncThreshold from gajim.common.helpers import AdditionalDataDict @@ -54,7 +66,7 @@ class MAM(BaseModule): 'make_query', ] - def __init__(self, con): + def __init__(self, con: types.Client) -> None: BaseModule.__init__(self, con) self.handlers = [ @@ -67,15 +79,15 @@ def __init__(self, con): ] self.available = False - self._mam_query_ids = {} + self._mam_query_ids: dict[str, str] = {} # Holds archive jids where catch up was successful - self._catch_up_finished = [] + self._catch_up_finished: list[str] = [] self._con.connect_signal('state-changed', self._on_client_state_changed) self._con.connect_signal('resume-failed', self._on_client_resume_failed) - def pass_disco(self, info): + def pass_disco(self, info: DiscoInfo) -> None: if Namespace.MAM_2 not in info.features: return @@ -86,24 +98,34 @@ def pass_disco(self, info): FeatureDiscovered(account=self._account, feature=Namespace.MAM_2)) - def _on_client_resume_failed(self, _client, _signal_name): + def _on_client_resume_failed(self, + _client: types.Client, + _signal_name: str + ) -> None: self._reset_state() - def _on_client_state_changed(self, _client, _signal_name, state): + def _on_client_state_changed(self, + _client: types.Client, + _signal_name: str, + state: ClientState + ) -> None: if state.is_disconnected: self._reset_state() - def _reset_state(self): + def _reset_state(self) -> None: self._mam_query_ids.clear() self._catch_up_finished.clear() - def _remove_query_id(self, jid): + def _remove_query_id(self, jid: JID) -> None: self._mam_query_ids.pop(jid, None) - def is_catch_up_finished(self, jid): + def is_catch_up_finished(self, jid: str) -> bool: return jid in self._catch_up_finished - def _from_valid_archive(self, _stanza, properties): + def _from_valid_archive(self, + _stanza: Message, + properties: MessageProperties + ) -> bool: if properties.type.is_groupchat: expected_archive = properties.jid else: @@ -111,7 +133,9 @@ def _from_valid_archive(self, _stanza, properties): return properties.mam.archive.bare_match(expected_archive) - def _get_unique_id(self, properties): + def _get_unique_id(self, + properties: MessageProperties + ) -> tuple[Optional[str], Optional[str]]: if properties.type.is_groupchat: return properties.mam.id, None @@ -129,13 +153,19 @@ def _get_unique_id(self, properties): return properties.mam.id, None @staticmethod - def _get_stanza_id(properties, archive_jid): + def _get_stanza_id(properties: MessageProperties, + archive_jid: str + ) -> Optional[StanzaIDData]: for stanza_id in properties.stanza_ids: if stanza_id.by == archive_jid: return stanza_id return None - def _set_message_archive_info(self, _con, _stanza, properties): + def _set_message_archive_info(self, + _con: types.xmppClient, + _stanza: Message, + properties: MessageProperties + ) -> None: if (properties.is_mam_message or properties.is_pubsub or properties.is_muc_subject): @@ -177,7 +207,11 @@ def _set_message_archive_info(self, _con, _stanza, properties): last_mam_id=stanza_id.id, last_muc_timestamp=timestamp) - def _mam_message_received(self, _con, stanza, properties): + def _mam_message_received(self, + _con: types.xmppClient, + stanza: Message, + properties: MessageProperties + ) -> None: if not properties.is_mam_message: return @@ -265,7 +299,7 @@ def _mam_message_received(self, _con, stanza, properties): return stanza_id = message_id - event_attr = { + event_attr: dict[str, Any] = { 'account': self._account, 'jid': jid, 'msgtxt': properties.body, @@ -305,16 +339,16 @@ def _mam_message_received(self, _con, stanza, properties): app.ged.raise_event(MamMessageReceived(**event_attr)) - def _is_valid_request(self, properties): + def _is_valid_request(self, properties: MessageProperties) -> bool: valid_id = self._mam_query_ids.get(properties.mam.archive, None) return valid_id == properties.mam.query_id - def _get_query_id(self, jid): + def _get_query_id(self, jid: str) -> str: query_id = generate_id() self._mam_query_ids[jid] = query_id return query_id - def _get_query_params(self): + def _get_query_params(self) -> tuple[Optional[Å¿tr], Optional[datetime]]: own_jid = self._con.get_own_jid().bare archive = app.storage.archive.get_archive_infos(own_jid) @@ -334,7 +368,10 @@ def _get_query_params(self): own_jid, start_date) return mam_id, start_date - def _get_muc_query_params(self, jid, threshold): + def _get_muc_query_params(self, + jid: JID, + threshold: SyncThreshold + ) -> tuple[Optional[str], Optional[datetime]]: archive = app.storage.archive.get_archive_infos(jid) mam_id = None start_date = None @@ -419,7 +456,9 @@ def request_archive_on_signin(self): oldest_mam_timestamp=start_date.timestamp()) @as_task - def request_archive_on_muc_join(self, jid): + def request_archive_on_muc_join(self, + jid: JID + ): _task = yield threshold = app.settings.get_group_chat_setting(self._account, @@ -455,7 +494,11 @@ def request_archive_on_muc_join(self, jid): last_muc_timestamp=time.time()) @as_task - def _execute_query(self, jid, mam_id, start_date): + def _execute_query(self, + jid: JID, + mam_id: Optional[str], + start_date: Optional[datetime] + ): _task = yield if jid in self._catch_up_finished: @@ -492,10 +535,11 @@ def _execute_query(self, jid, mam_id, start_date): yield result def request_archive_interval(self, - start_date, - end_date, - after=None, - queryid=None): + start_date: datetime, + end_date: datetime, + after: Optional[str] = None, + queryid: Optional[str] = None + ) -> str: jid = self._con.get_own_jid().bare @@ -518,7 +562,7 @@ def request_archive_interval(self, user_data=(queryid, start_date, end_date)) return queryid - def _on_interval_result(self, task): + def _on_interval_result(self, task: Task) -> None: queryid, start_date, end_date = task.get_user_data() try: