diff --git a/nbxmpp/modules/bookmarks.py b/nbxmpp/modules/bookmarks.py index 143748ae025796fb737517d67d81724d2a37c1c1..34bd0dade47617d3def8e3d984c8e286a3b1b72f 100644 --- a/nbxmpp/modules/bookmarks.py +++ b/nbxmpp/modules/bookmarks.py @@ -134,7 +134,7 @@ class Bookmarks(BaseModule): autojoin = False else: try: - autojoin = from_xs_boolean(autojoin) + autojoin = from_xs_boolean(autojoin, default=False) except ValueError as error: self._log.warning(error) self._log.warning(storage) @@ -176,7 +176,7 @@ class Bookmarks(BaseModule): self._log.warning(item) return None - autojoin = conference.getAttr('autojoin') in ('True', 'true', '1') + autojoin = from_xs_boolean(conference.getAttr('autojoin'), default=False) name = conference.getAttr('name') nick = conference.getTagData('nick') diff --git a/nbxmpp/modules/dataforms.py b/nbxmpp/modules/dataforms.py index 4eda96a04920c0c58b076a933345f329a78271a8..1d368c9d4df46efb59c34eb48755da1601dee90a 100644 --- a/nbxmpp/modules/dataforms.py +++ b/nbxmpp/modules/dataforms.py @@ -22,6 +22,7 @@ from nbxmpp.namespaces import Namespace from nbxmpp.protocol import JID from nbxmpp.simplexml import Node +from nbxmpp.util import from_xs_boolean # exceptions used in this module @@ -104,6 +105,22 @@ def extend_form(node): return SimpleDataForm(extend=node) +def get_form(node, form_type): + forms = node.getTags('x', namespace=Namespace.DATA) + if not forms: + return None + + for form in forms: + form = extend_form(node=form) + field = form.vars.get('FORM_TYPE') + if field is None: + continue + + if field.value == form_type: + return form + return None + + class DataField(ExtendedNode): """ Keeps data about one field - var, field type, labels, instructions... Base @@ -309,17 +326,14 @@ class BooleanField(DataField): Value of field. May contain True, False or None """ value = self.getTagData('value') - if value in ('0', 'false'): - return False - if value in ('1', 'true'): - return True - if value is None: - return False # default value is False - raise WrongFieldValue + try: + return from_xs_boolean(value, default=False) + except ValueError: + raise WrongFieldValue @value.setter def value(self, value): - self.setTagData('value', value and '1' or '0') + self.setTagData('value', '1' if value else '0') @value.deleter def value(self): diff --git a/nbxmpp/modules/muc.py b/nbxmpp/modules/muc.py index 0b029046adfc263de47ba793e8e260b405a73262..bb8fdc4560842d61aec08e42ff9dce187d3e659e 100644 --- a/nbxmpp/modules/muc.py +++ b/nbxmpp/modules/muc.py @@ -44,6 +44,7 @@ from nbxmpp.structs import MucDestroyed from nbxmpp.util import call_on_response from nbxmpp.util import callback from nbxmpp.util import raise_error +from nbxmpp.util import from_xs_boolean from nbxmpp.modules.dataforms import extend_form from nbxmpp.modules.base import BaseModule @@ -250,7 +251,7 @@ class MUC(BaseModule): data['from_'] = properties.jid data['reason'] = direct.getAttr('reason') data['password'] = direct.getAttr('password') - data['continued'] = direct.getAttr('continue') == 'true' + data['continued'] = from_xs_boolean(direct.getAttr('continue'), default=False) data['thread'] = direct.getAttr('thread') data['type'] = InviteType.DIRECT properties.muc_invite = InviteData(**data) diff --git a/nbxmpp/modules/register.py b/nbxmpp/modules/register.py index 45104c85ebb3af8f6eec09a126ab057f141f24f3..5d1e97308a9c6af796de43bb11918505751242c6 100644 --- a/nbxmpp/modules/register.py +++ b/nbxmpp/modules/register.py @@ -25,9 +25,9 @@ from nbxmpp.structs import ChangePasswordResult from nbxmpp.util import call_on_response from nbxmpp.util import callback from nbxmpp.util import raise_error -from nbxmpp.util import get_form from nbxmpp.const import REGISTER_FIELDS from nbxmpp.modules.bits_of_binary import parse_bob_data +from nbxmpp.modules.dataforms import get_form from nbxmpp.modules.dataforms import extend_form from nbxmpp.modules.dataforms import create_field from nbxmpp.modules.dataforms import SimpleDataForm diff --git a/nbxmpp/smacks.py b/nbxmpp/smacks.py index d75cb19a408e1a8776f8f37874eeeadc73af61c9..5dd6ff8ac31d4611b9347e3a9983cc34033a2903 100644 --- a/nbxmpp/smacks.py +++ b/nbxmpp/smacks.py @@ -22,6 +22,7 @@ from nbxmpp.namespaces import Namespace from nbxmpp.simplexml import Node from nbxmpp.const import StreamState from nbxmpp.util import LogAdapter +from nbxmpp.util import from_xs_boolean from nbxmpp.structs import StanzaHandler @@ -112,7 +113,7 @@ class Smacks: self._log.error('Received "enabled", but SM is already enabled') return resume = stanza.getAttr('resume') - if resume in ('true', 'True', '1'): + if from_xs_boolean(resume, default=False): self.resume_supported = True self._session_id = stanza.getAttr('id') diff --git a/nbxmpp/util.py b/nbxmpp/util.py index 1884df4c8e257b2638a5ca0f7ac1e9208eab460a..361f3a053a7faaf4520c1745c75090f1d63b0490 100644 --- a/nbxmpp/util.py +++ b/nbxmpp/util.py @@ -25,6 +25,7 @@ import re import logging from logging import LoggerAdapter from collections import defaultdict +from typing import Optional from functools import wraps from functools import lru_cache @@ -45,7 +46,6 @@ from nbxmpp.structs import PresenceProperties from nbxmpp.structs import CommonError from nbxmpp.structs import HTTPUploadError from nbxmpp.structs import StanzaMalformedError -from nbxmpp.modules.dataforms import extend_form from nbxmpp.third_party.hsluv import hsluv_to_rgb log = logging.getLogger('nbxmpp.util') @@ -124,17 +124,20 @@ def callback(func): return func_wrapper -def from_xs_boolean(value): - if value in ('1', 'true', 'True'): +def from_xs_boolean(value: Optional[str], *, default: Optional[bool] = None) -> bool: + if value in ('1', 'true'): return True - if value in ('0', 'false', 'False', ''): + if value in ('0', 'false'): return False + if value is None and default is not None: + return default + raise ValueError('Cant convert %s to python boolean' % value) -def to_xs_boolean(value): +def to_xs_boolean(value: Optional[bool]) -> str: # Convert to xs:boolean ('true', 'false') # from a python boolean (True, False) or None if value is True: @@ -338,22 +341,6 @@ def generate_id(): return str(uuid.uuid4()) -def get_form(stanza, form_type): - forms = stanza.getTags('x', namespace=Namespace.DATA) - if not forms: - return None - - for form in forms: - form = extend_form(node=form) - field = form.vars.get('FORM_TYPE') - if field is None: - continue - - if field.value == form_type: - return form - return None - - def validate_stream_header(stanza, domain, is_websocket): attrs = stanza.getAttrs() if attrs.get('from') != domain: