Skip to content
Snippets Groups Projects
Commit 4475e1cd authored by Daniel Brötzmann's avatar Daniel Brötzmann
Browse files

chore: MAM: Add type annotations

parent 5c9f034e
No related branches found
No related tags found
No related merge requests found
......@@ -14,26 +14,38 @@
# XEP-0313: Message Archive Management
from __future__ import annotations
from typing import Any
from typing import Optional
import time
from datetime import datetime
from datetime import timedelta
import nbxmpp
from nbxmpp.namespaces import Namespace
from nbxmpp.util import generate_id
from nbxmpp.errors import StanzaError
from nbxmpp.errors import MalformedStanzaError
from nbxmpp.errors import is_error
from nbxmpp.namespaces import Namespace
from nbxmpp.protocol import JID
from nbxmpp.protocol import Message
from nbxmpp.structs import DiscoInfo
from nbxmpp.structs import MessageProperties
from nbxmpp.structs import StanzaHandler
from nbxmpp.structs import StanzaIDData
from nbxmpp.task import Task
from nbxmpp.util import generate_id
from nbxmpp.modules.util import raise_if_error
from gajim.common import app
from gajim.common import types
from gajim.common.events import ArchivingIntervalFinished
from gajim.common.events import FeatureDiscovered
from gajim.common.events import MamMessageReceived
from gajim.common.events import MessageUpdated
from gajim.common.events import RawMamMessageReceived
from gajim.common.const import ArchiveState
from gajim.common.const import ArchiveState, ClientState
from gajim.common.const import KindConstant
from gajim.common.const import SyncThreshold
from gajim.common.helpers import AdditionalDataDict
......@@ -54,7 +66,7 @@ class MAM(BaseModule):
'make_query',
]
def __init__(self, con):
def __init__(self, con: types.Client) -> None:
BaseModule.__init__(self, con)
self.handlers = [
......@@ -67,15 +79,15 @@ def __init__(self, con):
]
self.available = False
self._mam_query_ids = {}
self._mam_query_ids: dict[str, str] = {}
# Holds archive jids where catch up was successful
self._catch_up_finished = []
self._catch_up_finished: list[str] = []
self._con.connect_signal('state-changed', self._on_client_state_changed)
self._con.connect_signal('resume-failed', self._on_client_resume_failed)
def pass_disco(self, info):
def pass_disco(self, info: DiscoInfo) -> None:
if Namespace.MAM_2 not in info.features:
return
......@@ -86,24 +98,34 @@ def pass_disco(self, info):
FeatureDiscovered(account=self._account,
feature=Namespace.MAM_2))
def _on_client_resume_failed(self, _client, _signal_name):
def _on_client_resume_failed(self,
_client: types.Client,
_signal_name: str
) -> None:
self._reset_state()
def _on_client_state_changed(self, _client, _signal_name, state):
def _on_client_state_changed(self,
_client: types.Client,
_signal_name: str,
state: ClientState
) -> None:
if state.is_disconnected:
self._reset_state()
def _reset_state(self):
def _reset_state(self) -> None:
self._mam_query_ids.clear()
self._catch_up_finished.clear()
def _remove_query_id(self, jid):
def _remove_query_id(self, jid: JID) -> None:
self._mam_query_ids.pop(jid, None)
def is_catch_up_finished(self, jid):
def is_catch_up_finished(self, jid: str) -> bool:
return jid in self._catch_up_finished
def _from_valid_archive(self, _stanza, properties):
def _from_valid_archive(self,
_stanza: Message,
properties: MessageProperties
) -> bool:
if properties.type.is_groupchat:
expected_archive = properties.jid
else:
......@@ -111,7 +133,9 @@ def _from_valid_archive(self, _stanza, properties):
return properties.mam.archive.bare_match(expected_archive)
def _get_unique_id(self, properties):
def _get_unique_id(self,
properties: MessageProperties
) -> tuple[Optional[str], Optional[str]]:
if properties.type.is_groupchat:
return properties.mam.id, None
......@@ -129,13 +153,19 @@ def _get_unique_id(self, properties):
return properties.mam.id, None
@staticmethod
def _get_stanza_id(properties, archive_jid):
def _get_stanza_id(properties: MessageProperties,
archive_jid: str
) -> Optional[StanzaIDData]:
for stanza_id in properties.stanza_ids:
if stanza_id.by == archive_jid:
return stanza_id
return None
def _set_message_archive_info(self, _con, _stanza, properties):
def _set_message_archive_info(self,
_con: types.xmppClient,
_stanza: Message,
properties: MessageProperties
) -> None:
if (properties.is_mam_message or
properties.is_pubsub or
properties.is_muc_subject):
......@@ -177,7 +207,11 @@ def _set_message_archive_info(self, _con, _stanza, properties):
last_mam_id=stanza_id.id,
last_muc_timestamp=timestamp)
def _mam_message_received(self, _con, stanza, properties):
def _mam_message_received(self,
_con: types.xmppClient,
stanza: Message,
properties: MessageProperties
) -> None:
if not properties.is_mam_message:
return
......@@ -265,7 +299,7 @@ def _mam_message_received(self, _con, stanza, properties):
return
stanza_id = message_id
event_attr = {
event_attr: dict[str, Any] = {
'account': self._account,
'jid': jid,
'msgtxt': properties.body,
......@@ -305,16 +339,16 @@ def _mam_message_received(self, _con, stanza, properties):
app.ged.raise_event(MamMessageReceived(**event_attr))
def _is_valid_request(self, properties):
def _is_valid_request(self, properties: MessageProperties) -> bool:
valid_id = self._mam_query_ids.get(properties.mam.archive, None)
return valid_id == properties.mam.query_id
def _get_query_id(self, jid):
def _get_query_id(self, jid: str) -> str:
query_id = generate_id()
self._mam_query_ids[jid] = query_id
return query_id
def _get_query_params(self):
def _get_query_params(self) -> tuple[Optional[ſtr], Optional[datetime]]:
own_jid = self._con.get_own_jid().bare
archive = app.storage.archive.get_archive_infos(own_jid)
......@@ -334,7 +368,10 @@ def _get_query_params(self):
own_jid, start_date)
return mam_id, start_date
def _get_muc_query_params(self, jid, threshold):
def _get_muc_query_params(self,
jid: JID,
threshold: SyncThreshold
) -> tuple[Optional[str], Optional[datetime]]:
archive = app.storage.archive.get_archive_infos(jid)
mam_id = None
start_date = None
......@@ -419,7 +456,9 @@ def request_archive_on_signin(self):
oldest_mam_timestamp=start_date.timestamp())
@as_task
def request_archive_on_muc_join(self, jid):
def request_archive_on_muc_join(self,
jid: JID
):
_task = yield
threshold = app.settings.get_group_chat_setting(self._account,
......@@ -455,7 +494,11 @@ def request_archive_on_muc_join(self, jid):
last_muc_timestamp=time.time())
@as_task
def _execute_query(self, jid, mam_id, start_date):
def _execute_query(self,
jid: JID,
mam_id: Optional[str],
start_date: Optional[datetime]
):
_task = yield
if jid in self._catch_up_finished:
......@@ -492,10 +535,11 @@ def _execute_query(self, jid, mam_id, start_date):
yield result
def request_archive_interval(self,
start_date,
end_date,
after=None,
queryid=None):
start_date: datetime,
end_date: datetime,
after: Optional[str] = None,
queryid: Optional[str] = None
) -> str:
jid = self._con.get_own_jid().bare
......@@ -518,7 +562,7 @@ def request_archive_interval(self,
user_data=(queryid, start_date, end_date))
return queryid
def _on_interval_result(self, task):
def _on_interval_result(self, task: Task) -> None:
queryid, start_date, end_date = task.get_user_data()
try:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment