Commit 9c53fcbd authored by Philipp Hörist's avatar Philipp Hörist

Use nbxmpp's Discovery module

parent 4757fd95
Pipeline #3794 passed with stages
in 2 minutes and 47 seconds
...@@ -38,6 +38,7 @@ from pathlib import Path ...@@ -38,6 +38,7 @@ from pathlib import Path
from collections import namedtuple from collections import namedtuple
import nbxmpp import nbxmpp
from nbxmpp.structs import DiscoIdentity
from gi.repository import Gdk from gi.repository import Gdk
import gajim import gajim
...@@ -143,7 +144,9 @@ socks5queue = None ...@@ -143,7 +144,9 @@ socks5queue = None
gupnp_igd = None gupnp_igd = None
gajim_identity = {'type': 'pc', 'category': 'client', 'name': 'Gajim'} gajim_identity = DiscoIdentity(category='client',
type='pc',
name='Gajim')
gajim_common_features = [ gajim_common_features = [
nbxmpp.NS_BYTESTREAM, nbxmpp.NS_BYTESTREAM,
......
...@@ -30,7 +30,6 @@ through ClientCaps objects which are hold by contact instances. ...@@ -30,7 +30,6 @@ through ClientCaps objects which are hold by contact instances.
import base64 import base64
import hashlib import hashlib
import logging import logging
from collections import namedtuple
import nbxmpp import nbxmpp
from nbxmpp.const import Affiliation from nbxmpp.const import Affiliation
...@@ -114,19 +113,18 @@ def compute_caps_hash(identities, features, dataforms=None, hash_method='sha-1') ...@@ -114,19 +113,18 @@ def compute_caps_hash(identities, features, dataforms=None, hash_method='sha-1')
dataforms = [] dataforms = []
def sort_identities_key(i): def sort_identities_key(i):
return (i['category'], i.get('type', ''), i.get('xml:lang', '')) return (i.category, i.type, i.lang or '')
def sort_dataforms_key(dataform): def sort_dataforms_key(dataform):
f = dataform.getField('FORM_TYPE') return dataform['FORM_TYPE'].value
return (bool(f), f.getValue())
S = '' S = ''
identities.sort(key=sort_identities_key) identities.sort(key=sort_identities_key)
for i in identities: for i in identities:
c = i['category'] c = i.category
type_ = i.get('type', '') type_ = i.type
lang = i.get('xml:lang', '') lang = i.lang or ''
name = i.get('name', '') name = i.name or ''
S += '%s/%s/%s/%s<' % (c, type_, lang, name) S += '%s/%s/%s/%s<' % (c, type_, lang, name)
features.sort() features.sort()
for f in features: for f in features:
...@@ -135,15 +133,16 @@ def compute_caps_hash(identities, features, dataforms=None, hash_method='sha-1') ...@@ -135,15 +133,16 @@ def compute_caps_hash(identities, features, dataforms=None, hash_method='sha-1')
for dataform in dataforms: for dataform in dataforms:
# fields indexed by var # fields indexed by var
fields = {} fields = {}
for f in dataform.getChildren(): for f in dataform.iter_fields():
fields[f.getVar()] = f values = f.getTags('value')
fields[f.var] = [value.getData() for value in values]
form_type = fields.get('FORM_TYPE') form_type = fields.get('FORM_TYPE')
if form_type: if form_type:
S += form_type.getValue() + '<' S += form_type[0] + '<'
del fields['FORM_TYPE'] del fields['FORM_TYPE']
for var in sorted(fields.keys()): for var in sorted(fields.keys()):
S += '%s<' % var S += '%s<' % var
values = sorted(fields[var].getValues()) values = sorted(fields[var])
for value in values: for value in values:
S += '%s<' % value S += '%s<' % value
...@@ -331,27 +330,10 @@ class CapsCache: ...@@ -331,27 +330,10 @@ class CapsCache:
features = property(_get_features, _set_features) features = property(_get_features, _set_features)
def _get_identities(self): def _get_identities(self):
list_ = [] return self._identities
for i in self._identities:
# transforms it back in a dict
d = dict()
d['category'] = i[0]
if i[1]:
d['type'] = i[1]
if i[2]:
d['xml:lang'] = i[2]
if i[3]:
d['name'] = i[3]
list_.append(d)
return list_
def _set_identities(self, value): def _set_identities(self, value):
self._identities = [] self._identities = value
for identity in value:
# dict are not hashable, so transform it into a tuple
t = (identity['category'], identity.get('type'),
identity.get('xml:lang'), identity.get('name'))
self._identities.append(self.__names.setdefault(t, t))
identities = property(_get_identities, _set_identities) identities = property(_get_identities, _set_identities)
...@@ -432,37 +414,15 @@ class CapsCache: ...@@ -432,37 +414,15 @@ class CapsCache:
class MucCapsCache: class MucCapsCache:
DiscoInfo = namedtuple('DiscoInfo', ['identities', 'features', 'data'])
def __init__(self): def __init__(self):
self.cache = {} self.cache = {}
def append(self, stanza): def append(self, info):
jid = stanza.getFrom() if nbxmpp.NS_MUC not in info.features:
identities, features, data = [], [], []
query_childs = stanza.getQueryChildren()
if not query_childs:
log.warning('%s returned empty disco info', jid)
return
for child in query_childs:
if child.getName() == 'identity':
attr = {}
for key in child.getAttrs().keys():
attr[key] = child.getAttr(key)
identities.append(attr)
elif child.getName() == 'feature':
features.append(child.getAttr('var'))
elif child.getName() == 'x':
if child.getNamespace() == nbxmpp.NS_DATA:
from gajim.common.modules import dataforms
data.append(dataforms.extend_form(node=child))
if nbxmpp.NS_MUC not in features:
# Not a MUC, don't cache info # Not a MUC, don't cache info
return return
self.cache[jid] = self.DiscoInfo(identities, features, data) self.cache[info.jid] = info
def is_cached(self, jid): def is_cached(self, jid):
return jid in self.cache return jid in self.cache
...@@ -498,7 +458,7 @@ class MucCapsCache: ...@@ -498,7 +458,7 @@ class MucCapsCache:
return allowed return allowed
if jid in self.cache: if jid in self.cache:
for form in self.cache[jid].data: for form in self.cache[jid].dataforms:
try: try:
allowed = form['muc#roominfo_changesubject'].value allowed = form['muc#roominfo_changesubject'].value
except KeyError: except KeyError:
...@@ -520,7 +480,7 @@ class MucCapsCache: ...@@ -520,7 +480,7 @@ class MucCapsCache:
def get_room_infos(self, jid): def get_room_infos(self, jid):
room_info = {} room_info = {}
if jid in self.cache: if jid in self.cache:
for form in self.cache[jid].data: for form in self.cache[jid].dataforms:
try: try:
room_info['name'] = form['muc#roomconfig_roomname'].value room_info['name'] = form['muc#roomconfig_roomname'].value
except KeyError: except KeyError:
......
...@@ -38,6 +38,8 @@ from gzip import GzipFile ...@@ -38,6 +38,8 @@ from gzip import GzipFile
from io import BytesIO from io import BytesIO
from gi.repository import GLib from gi.repository import GLib
from nbxmpp.structs import DiscoIdentity
from gajim.common import exceptions from gajim.common import exceptions
from gajim.common import app from gajim.common import app
from gajim.common import configpaths from gajim.common import configpaths
...@@ -1027,8 +1029,10 @@ class Logger: ...@@ -1027,8 +1029,10 @@ class Logger:
type_ = data[i + 1] type_ = data[i + 1]
lang = data[i + 2] lang = data[i + 2]
name = data[i + 3] name = data[i + 3]
identities.append({'category': category, 'type': type_, identities.append(DiscoIdentity(category=category,
'xml:lang': lang, 'name': name}) type=type_,
lang=lang,
name=name))
i += 4 i += 4
i += 1 i += 1
while i < len(data): while i < len(data):
...@@ -1046,10 +1050,12 @@ class Logger: ...@@ -1046,10 +1050,12 @@ class Logger:
data = [] data = []
for identity in identities: for identity in identities:
# there is no FEAT category # there is no FEAT category
if identity['category'] == 'FEAT': if identity.category == 'FEAT':
return return
data.extend((identity.get('category'), identity.get('type', ''), data.extend((identity.category,
identity.get('xml:lang', ''), identity.get('name', ''))) identity.type,
identity.lang or '',
identity.name or ''))
data.append('FEAT') data.append('FEAT')
data.extend(features) data.extend(features)
data = '\0'.join(data) data = '\0'.join(data)
......
...@@ -50,8 +50,8 @@ class Blocking(BaseModule): ...@@ -50,8 +50,8 @@ class Blocking(BaseModule):
self.supported = False self.supported = False
def pass_disco(self, from_, _identities, features, _data, _node): def pass_disco(self, info):
if nbxmpp.NS_BLOCKING not in features: if nbxmpp.NS_BLOCKING not in info.features:
return return
self.supported = True self.supported = True
...@@ -60,7 +60,7 @@ class Blocking(BaseModule): ...@@ -60,7 +60,7 @@ class Blocking(BaseModule):
account=self._account, account=self._account,
feature=nbxmpp.NS_BLOCKING)) feature=nbxmpp.NS_BLOCKING))
self._log.info('Discovered blocking: %s', from_) self._log.info('Discovered blocking: %s', info.jid)
def _blocking_list_received(self, result): def _blocking_list_received(self, result):
if is_error_result(result): if is_error_result(result):
......
...@@ -82,11 +82,11 @@ class Bookmarks(BaseModule): ...@@ -82,11 +82,11 @@ class Bookmarks(BaseModule):
app.nec.push_incoming_event( app.nec.push_incoming_event(
NetworkEvent('bookmarks-received', account=self._account)) NetworkEvent('bookmarks-received', account=self._account))
def pass_disco(self, from_, _identities, features, _data, _node): def pass_disco(self, info):
if nbxmpp.NS_BOOKMARK_CONVERSION not in features: if nbxmpp.NS_BOOKMARK_CONVERSION not in info.features:
return return
self._conversion = True self._conversion = True
self._log.info('Discovered Bookmarks Conversion: %s', from_) self._log.info('Discovered Bookmarks Conversion: %s', info.jid)
def _act_on_changed_bookmarks(self, old_bookmarks): def _act_on_changed_bookmarks(self, old_bookmarks):
new_bookmarks = self._convert_to_set(self._bookmarks) new_bookmarks = self._convert_to_set(self._bookmarks)
......
...@@ -95,16 +95,16 @@ class Bytestream(BaseModule): ...@@ -95,16 +95,16 @@ class Bytestream(BaseModule):
callback=self._ResultCB), callback=self._ResultCB),
] ]
def pass_disco(self, from_, _identities, features, _data, _node): def pass_disco(self, info):
if nbxmpp.NS_BYTESTREAM not in features: if nbxmpp.NS_BYTESTREAM not in info.features:
return return
if app.config.get_per('accounts', self._account, 'use_ft_proxies'): if app.config.get_per('accounts', self._account, 'use_ft_proxies'):
log.info('Discovered proxy: %s', from_) log.info('Discovered proxy: %s', info.jid)
our_fjid = self._con.get_own_jid() our_fjid = self._con.get_own_jid()
testit = app.config.get_per( testit = app.config.get_per(
'accounts', self._account, 'test_ft_proxies_on_startup') 'accounts', self._account, 'test_ft_proxies_on_startup')
app.proxy65_manager.resolve( app.proxy65_manager.resolve(
from_, self._con.connection, str(our_fjid), info.jid, self._con.connection, str(our_fjid),
default=self._account, testit=testit) default=self._account, testit=testit)
raise nbxmpp.NodeProcessed raise nbxmpp.NodeProcessed
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
import nbxmpp import nbxmpp
from nbxmpp.structs import StanzaHandler from nbxmpp.structs import StanzaHandler
from nbxmpp.util import is_error_result
from gajim.common import caps_cache from gajim.common import caps_cache
from gajim.common import app from gajim.common import app
...@@ -94,16 +95,21 @@ class Caps(BaseModule): ...@@ -94,16 +95,21 @@ class Caps(BaseModule):
self._account, room_jid, resource) self._account, room_jid, resource)
return contact return contact
def contact_info_received(self, from_, identities, features, data, node): def contact_info_received(self, info):
""" """
callback to update our caps cache with queried information after callback to update our caps cache with queried information after
we have retrieved an unknown caps hash via a disco we have retrieved an unknown caps hash via a disco
""" """
bare_jid = from_.getStripped()
contact = self._get_contact_or_gc_contact_for_jid(from_) if is_error_result(info):
self._log.info(info)
return
bare_jid = info.jid.getBare()
contact = self._get_contact_or_gc_contact_for_jid(info.jid)
if not contact: if not contact:
self._log.info('Received Disco from unknown contact %s', from_) self._log.info('Received Disco from unknown contact %s', info.jid)
return return
lookup = contact.client_caps.get_cache_lookup_strategy() lookup = contact.client_caps.get_cache_lookup_strategy()
...@@ -115,10 +121,10 @@ class Caps(BaseModule): ...@@ -115,10 +121,10 @@ class Caps(BaseModule):
return return
validate = contact.client_caps.get_hash_validation_strategy() validate = contact.client_caps.get_hash_validation_strategy()
hash_is_valid = validate(identities, features, data) hash_is_valid = validate(info.identities, info.features, info.dataforms)
if hash_is_valid: if hash_is_valid:
cache_item.set_and_store(identities, features) cache_item.set_and_store(info.identities, info.features)
else: else:
node = caps_hash = hash_method = None node = caps_hash = hash_method = None
contact.client_caps = self._create_suitable_client_caps( contact.client_caps = self._create_suitable_client_caps(
...@@ -126,11 +132,10 @@ class Caps(BaseModule): ...@@ -126,11 +132,10 @@ class Caps(BaseModule):
self._log.warning( self._log.warning(
'Computed and retrieved caps hash differ. Ignoring ' 'Computed and retrieved caps hash differ. Ignoring '
'caps of contact %s', contact.get_full_jid()) 'caps of contact %s', contact.get_full_jid())
app.nec.push_incoming_event( app.nec.push_incoming_event(
NetworkEvent('caps-update', NetworkEvent('caps-update',
conn=self._con, conn=self._con,
fjid=str(from_), fjid=str(info.jid),
jid=bare_jid)) jid=bare_jid))
......
...@@ -25,12 +25,12 @@ class Carbons(BaseModule): ...@@ -25,12 +25,12 @@ class Carbons(BaseModule):
self.supported = False self.supported = False
def pass_disco(self, from_, _identities, features, _data, _node): def pass_disco(self, info):
if nbxmpp.NS_CARBONS not in features: if nbxmpp.NS_CARBONS not in info.features:
return return
self.supported = True self.supported = True
self._log.info('Discovered carbons: %s', from_) self._log.info('Discovered carbons: %s', info.jid)
iq = nbxmpp.Iq('set') iq = nbxmpp.Iq('set')
iq.setTag('enable', namespace=nbxmpp.NS_CARBONS) iq.setTag('enable', namespace=nbxmpp.NS_CARBONS)
......
This diff is collapsed.
...@@ -71,22 +71,22 @@ class HTTPUpload(BaseModule): ...@@ -71,22 +71,22 @@ class HTTPUpload(BaseModule):
ged.OUT_PREGUI, ged.OUT_PREGUI,
self.handle_outgoing_stanza) self.handle_outgoing_stanza)
def pass_disco(self, from_, _identities, features, data, _node): def pass_disco(self, info):
if NS_HTTPUPLOAD_0 in features: if NS_HTTPUPLOAD_0 in info.features:
self.httpupload_namespace = NS_HTTPUPLOAD_0 self.httpupload_namespace = NS_HTTPUPLOAD_0
elif NS_HTTPUPLOAD in features: elif NS_HTTPUPLOAD in info.features:
self.httpupload_namespace = NS_HTTPUPLOAD self.httpupload_namespace = NS_HTTPUPLOAD
else: else:
return return
self.component = from_ self.component = info.jid
self._log.info('Discovered component: %s', from_) self._log.info('Discovered component: %s', info.jid)
for form in data: for form in info.dataforms:
form_dict = form.asDict() form_dict = form.asDict()
if form_dict.get('FORM_TYPE', None) != self.httpupload_namespace: if form_dict.get('FORM_TYPE') != self.httpupload_namespace:
continue continue
size = form_dict.get('max-file-size', None) size = form_dict.get('max-file-size')
if size is not None: if size is not None:
self.max_file_size = int(size) self.max_file_size = int(size)
break break
......
...@@ -52,16 +52,17 @@ class MAM(BaseModule): ...@@ -52,16 +52,17 @@ class MAM(BaseModule):
# Holds archive jids where catch up was successful # Holds archive jids where catch up was successful
self._catch_up_finished = [] self._catch_up_finished = []
def pass_disco(self, from_, _identities, features, _data, _node): def pass_disco(self, info):
if nbxmpp.NS_MAM_2 in features: if nbxmpp.NS_MAM_2 in info.features:
self.archiving_namespace = nbxmpp.NS_MAM_2 self.archiving_namespace = nbxmpp.NS_MAM_2
elif nbxmpp.NS_MAM_1 in features: elif nbxmpp.NS_MAM_1 in info.features:
self.archiving_namespace = nbxmpp.NS_MAM_1 self.archiving_namespace = nbxmpp.NS_MAM_1
else: else:
return return
self.available = True self.available = True
self._log.info('Discovered MAM %s: %s', self.archiving_namespace, from_) self._log.info('Discovered MAM %s: %s',
self.archiving_namespace, info.jid)
app.nec.push_incoming_event( app.nec.push_incoming_event(
NetworkEvent('feature-discovered', NetworkEvent('feature-discovered',
......
...@@ -103,16 +103,16 @@ class MUC(BaseModule): ...@@ -103,16 +103,16 @@ class MUC(BaseModule):
self._muc_data = {} self._muc_data = {}
def pass_disco(self, from_, identities, features, _data, _node): def pass_disco(self, info):
for identity in identities: for identity in info.identities:
if identity.get('category') != 'conference': if identity.category != 'conference':
continue continue
if identity.get('type') != 'text': if identity.type != 'text':
continue continue
if nbxmpp.NS_MUC in features: if nbxmpp.NS_MUC in info.features:
self._log.info('Discovered MUC: %s', from_) self._log.info('Discovered MUC: %s', info.jid)
# TODO: make this nicer # TODO: make this nicer
self._con.muc_jid['jabber'] = from_ self._con.muc_jid['jabber'] = str(info.jid)
raise nbxmpp.NodeProcessed raise nbxmpp.NodeProcessed
def _get_muc_data(self, room_jid): def _get_muc_data(self, room_jid):
......
...@@ -15,12 +15,8 @@ ...@@ -15,12 +15,8 @@
# XEP-0163: Personal Eventing Protocol # XEP-0163: Personal Eventing Protocol
from typing import Any from typing import Any
from typing import Dict
from typing import List
from typing import Tuple from typing import Tuple
import nbxmpp
from gajim.common.types import ConnectionT from gajim.common.types import ConnectionT
from gajim.common.modules.base import BaseModule from gajim.common.modules.base import BaseModule
...@@ -31,16 +27,11 @@ class PEP(BaseModule): ...@@ -31,16 +27,11 @@ class PEP(BaseModule):
self.supported = False self.supported = False
def pass_disco(self, def pass_disco(self, info):
from_: nbxmpp.JID, for identity in info.identities:
identities: List[Dict[str, str]], if identity.category == 'pubsub':
_features: List[str], if identity.type == 'pep':
_data: List[nbxmpp.DataForm], self._log.info('Discovered PEP support: %s', info.jid)
_node: str) -> None:
for identity in identities:
if identity['category'] == 'pubsub':
if identity.get('type') == 'pep':
self._log.info('Discovered PEP support: %s', from_)
self.supported = True self.supported = True
......
...@@ -45,12 +45,12 @@ class PrivacyLists(BaseModule): ...@@ -45,12 +45,12 @@ class PrivacyLists(BaseModule):
self.supported = False self.supported = False
def pass_disco(self, from_, _identities, features, _data, _node): def pass_disco(self, info):
if nbxmpp.NS_PRIVACY not in features: if nbxmpp.NS_PRIVACY not in info.features:
return return
self.supported = True self.supported = True
self._log.info('Discovered XEP-0016: Privacy Lists: %s', from_) self._log.info('Discovered XEP-0016: Privacy Lists: %s', info.jid)
app.nec.push_incoming_event( app.nec.push_incoming_event(
NetworkEvent('feature-discovered', NetworkEvent('feature-discovered',
......
...@@ -34,12 +34,12 @@ class PubSub(BaseModule): ...@@ -34,12 +34,12 @@ class PubSub(BaseModule):
self.publish_options = False self.publish_options = False
def pass_disco(self, from_, _identities, features, _data, _node): def pass_disco(self, info):
if nbxmpp.NS_PUBSUB_PUBLISH_OPTIONS not in features: if nbxmpp.NS_PUBSUB_PUBLISH_OPTIONS not in info.features:
# Remove stored bookmarks accessible to everyone. # Remove stored bookmarks accessible to everyone.
self._con.get_module('Bookmarks').purge_pubsub_bookmarks() self._con.get_module('Bookmarks').purge_pubsub_bookmarks()
return return
self._log.info('Discovered Pubsub publish options: %s', from_) self._log.info('Discovered Pubsub publish options: %s', info.jid)
self.publish_options = True self.publish_options = True
def send_pb_subscription_query(self, jid, cb, **kwargs): def send_pb_subscription_query(self, jid, cb, **kwargs):
......
...@@ -28,12 +28,12 @@ class SecLabels(BaseModule): ...@@ -28,12 +28,12 @@ class SecLabels(BaseModule):
self._catalogs = {} self._catalogs = {}
self.supported = False self.supported = False
def pass_disco(self, from_, _identities, features, _data, _node): def pass_disco(self, info):
if nbxmpp.NS_SECLABEL not in features: if nbxmpp.NS_SECLABEL not in info.features:
return return
self.supported = True self.supported = True
self._log.info('Discovered security labels: %s', from_) self._log.info('Discovered security labels: %s', info.jid)
def request_catalog(self, jid): def request_catalog(self, jid):
server = app.get_jid_from_account(self._account).split("@")[1] server = app.get_jid_from_account(self._account).split("@")[1]
......
...@@ -36,12 +36,12 @@ class VCardTemp(BaseModule): ...@@ -36,12 +36,12 @@ class VCardTemp(BaseModule):
self.room_jids = [] self.room_jids = []
self.supported = False self.supported = False
def pass_disco(self, from_, _identities, features, _data, _node): def pass_disco(self, info):
if nbxmpp.NS_VCARD not in features: if nbxmpp.NS_VCARD not in info.features:
return return
self.supported = True self.supported = True
self._log.info('Discovered vcard-temp: %s', from_) self._log.info('Discovered vcard-temp: %s', info.jid)
app.nec.push_incoming_event(NetworkEvent('feature-discovered', app.nec.push_incoming_event(NetworkEvent('feature-discovered',
account=self._account, account=self._account,
......
This diff is collapsed.
...@@ -12,10 +12,14 @@ ...@@ -12,10 +12,14 @@