diff --git a/gajim/common/modules/mam.py b/gajim/common/modules/mam.py
index 3bb17e0781ec462e358f20cb1baa8f60d9c6c362..9cf5aa52a4dbddc51f85c496d2bb359bbe88c715 100644
--- a/gajim/common/modules/mam.py
+++ b/gajim/common/modules/mam.py
@@ -14,26 +14,38 @@
 
 # XEP-0313: Message Archive Management
 
+from __future__ import annotations
+
+from typing import Any
+from typing import Optional
+
 import time
 from datetime import datetime
 from datetime import timedelta
 
 import nbxmpp
-from nbxmpp.namespaces import Namespace
-from nbxmpp.util import generate_id
 from nbxmpp.errors import StanzaError
 from nbxmpp.errors import MalformedStanzaError
 from nbxmpp.errors import is_error
+from nbxmpp.namespaces import Namespace
+from nbxmpp.protocol import JID
+from nbxmpp.protocol import Message
+from nbxmpp.structs import DiscoInfo
+from nbxmpp.structs import MessageProperties
 from nbxmpp.structs import StanzaHandler
+from nbxmpp.structs import StanzaIDData
+from nbxmpp.task import Task
+from nbxmpp.util import generate_id
 from nbxmpp.modules.util import raise_if_error
 
 from gajim.common import app
+from gajim.common import types
 from gajim.common.events import ArchivingIntervalFinished
 from gajim.common.events import FeatureDiscovered
 from gajim.common.events import MamMessageReceived
 from gajim.common.events import MessageUpdated
 from gajim.common.events import RawMamMessageReceived
-from gajim.common.const import ArchiveState
+from gajim.common.const import ArchiveState, ClientState
 from gajim.common.const import KindConstant
 from gajim.common.const import SyncThreshold
 from gajim.common.helpers import AdditionalDataDict
@@ -54,7 +66,7 @@ class MAM(BaseModule):
         'make_query',
     ]
 
-    def __init__(self, con):
+    def __init__(self, con: types.Client) -> None:
         BaseModule.__init__(self, con)
 
         self.handlers = [
@@ -67,15 +79,15 @@ def __init__(self, con):
         ]
 
         self.available = False
-        self._mam_query_ids = {}
+        self._mam_query_ids: dict[str, str] = {}
 
         # Holds archive jids where catch up was successful
-        self._catch_up_finished = []
+        self._catch_up_finished: list[str] = []
 
         self._con.connect_signal('state-changed', self._on_client_state_changed)
         self._con.connect_signal('resume-failed', self._on_client_resume_failed)
 
-    def pass_disco(self, info):
+    def pass_disco(self, info: DiscoInfo) -> None:
         if Namespace.MAM_2 not in info.features:
             return
 
@@ -86,24 +98,34 @@ def pass_disco(self, info):
             FeatureDiscovered(account=self._account,
                               feature=Namespace.MAM_2))
 
-    def _on_client_resume_failed(self, _client, _signal_name):
+    def _on_client_resume_failed(self,
+                                 _client: types.Client,
+                                 _signal_name: str
+                                 ) -> None:
         self._reset_state()
 
-    def _on_client_state_changed(self, _client, _signal_name, state):
+    def _on_client_state_changed(self,
+                                 _client: types.Client,
+                                 _signal_name: str,
+                                 state: ClientState
+                                 ) -> None:
         if state.is_disconnected:
             self._reset_state()
 
-    def _reset_state(self):
+    def _reset_state(self) -> None:
         self._mam_query_ids.clear()
         self._catch_up_finished.clear()
 
-    def _remove_query_id(self, jid):
+    def _remove_query_id(self, jid: JID) -> None:
         self._mam_query_ids.pop(jid, None)
 
-    def is_catch_up_finished(self, jid):
+    def is_catch_up_finished(self, jid: str) -> bool:
         return jid in self._catch_up_finished
 
-    def _from_valid_archive(self, _stanza, properties):
+    def _from_valid_archive(self,
+                            _stanza: Message,
+                            properties: MessageProperties
+                            ) -> bool:
         if properties.type.is_groupchat:
             expected_archive = properties.jid
         else:
@@ -111,7 +133,9 @@ def _from_valid_archive(self, _stanza, properties):
 
         return properties.mam.archive.bare_match(expected_archive)
 
-    def _get_unique_id(self, properties):
+    def _get_unique_id(self,
+                       properties: MessageProperties
+                       ) -> tuple[Optional[str], Optional[str]]:
         if properties.type.is_groupchat:
             return properties.mam.id, None
 
@@ -129,13 +153,19 @@ def _get_unique_id(self, properties):
         return properties.mam.id, None
 
     @staticmethod
-    def _get_stanza_id(properties, archive_jid):
+    def _get_stanza_id(properties: MessageProperties,
+                       archive_jid: str
+                       ) -> Optional[StanzaIDData]:
         for stanza_id in properties.stanza_ids:
             if stanza_id.by == archive_jid:
                 return stanza_id
         return None
 
-    def _set_message_archive_info(self, _con, _stanza, properties):
+    def _set_message_archive_info(self,
+                                  _con: types.xmppClient,
+                                  _stanza: Message,
+                                  properties: MessageProperties
+                                  ) -> None:
         if (properties.is_mam_message or
                 properties.is_pubsub or
                 properties.is_muc_subject):
@@ -177,7 +207,11 @@ def _set_message_archive_info(self, _con, _stanza, properties):
             last_mam_id=stanza_id.id,
             last_muc_timestamp=timestamp)
 
-    def _mam_message_received(self, _con, stanza, properties):
+    def _mam_message_received(self,
+                              _con: types.xmppClient,
+                              stanza: Message,
+                              properties: MessageProperties
+                              ) -> None:
         if not properties.is_mam_message:
             return
 
@@ -265,7 +299,7 @@ def _mam_message_received(self, _con, stanza, properties):
                 return
             stanza_id = message_id
 
-        event_attr = {
+        event_attr: dict[str, Any] = {
             'account': self._account,
             'jid': jid,
             'msgtxt': properties.body,
@@ -305,16 +339,16 @@ def _mam_message_received(self, _con, stanza, properties):
 
         app.ged.raise_event(MamMessageReceived(**event_attr))
 
-    def _is_valid_request(self, properties):
+    def _is_valid_request(self, properties: MessageProperties) -> bool:
         valid_id = self._mam_query_ids.get(properties.mam.archive, None)
         return valid_id == properties.mam.query_id
 
-    def _get_query_id(self, jid):
+    def _get_query_id(self, jid: str) -> str:
         query_id = generate_id()
         self._mam_query_ids[jid] = query_id
         return query_id
 
-    def _get_query_params(self):
+    def _get_query_params(self) -> tuple[Optional[Å¿tr], Optional[datetime]]:
         own_jid = self._con.get_own_jid().bare
         archive = app.storage.archive.get_archive_infos(own_jid)
 
@@ -334,7 +368,10 @@ def _get_query_params(self):
                            own_jid, start_date)
         return mam_id, start_date
 
-    def _get_muc_query_params(self, jid, threshold):
+    def _get_muc_query_params(self,
+                              jid: JID,
+                              threshold: SyncThreshold
+                              ) -> tuple[Optional[str], Optional[datetime]]:
         archive = app.storage.archive.get_archive_infos(jid)
         mam_id = None
         start_date = None
@@ -419,7 +456,9 @@ def request_archive_on_signin(self):
                 oldest_mam_timestamp=start_date.timestamp())
 
     @as_task
-    def request_archive_on_muc_join(self, jid):
+    def request_archive_on_muc_join(self,
+                                    jid: JID
+                                    ):
         _task = yield
 
         threshold = app.settings.get_group_chat_setting(self._account,
@@ -455,7 +494,11 @@ def request_archive_on_muc_join(self, jid):
                 last_muc_timestamp=time.time())
 
     @as_task
-    def _execute_query(self, jid, mam_id, start_date):
+    def _execute_query(self,
+                       jid: JID,
+                       mam_id: Optional[str],
+                       start_date: Optional[datetime]
+                       ):
         _task = yield
 
         if jid in self._catch_up_finished:
@@ -492,10 +535,11 @@ def _execute_query(self, jid, mam_id, start_date):
         yield result
 
     def request_archive_interval(self,
-                                 start_date,
-                                 end_date,
-                                 after=None,
-                                 queryid=None):
+                                 start_date: datetime,
+                                 end_date: datetime,
+                                 after: Optional[str] = None,
+                                 queryid: Optional[str] = None
+                                 ) -> str:
 
         jid = self._con.get_own_jid().bare
 
@@ -518,7 +562,7 @@ def request_archive_interval(self,
                         user_data=(queryid, start_date, end_date))
         return queryid
 
-    def _on_interval_result(self, task):
+    def _on_interval_result(self, task: Task) -> None:
         queryid, start_date, end_date = task.get_user_data()
 
         try: