Commit bc7cac2b authored by Philipp Hörist's avatar Philipp Hörist
Browse files

Add annotations

parent 53bfd31b
......@@ -18,6 +18,13 @@ data structures, including jabber-objects like JID or different stanzas and
sub- stanzas) handling routines
"""
from __future__ import annotations
from typing import Any
from typing import Union
from typing import Optional
from typing import cast
import time
import hashlib
import functools
......@@ -469,7 +476,7 @@ def deprecation_warning(message):
@functools.lru_cache(maxsize=None)
def validate_localpart(localpart):
def validate_localpart(localpart: str) -> str:
if not localpart or len(localpart.encode()) > 1023:
raise LocalpartByteLimit
......@@ -484,7 +491,7 @@ def validate_localpart(localpart):
@functools.lru_cache(maxsize=None)
def validate_resourcepart(resourcepart):
def validate_resourcepart(resourcepart: str) -> str:
if not resourcepart or len(resourcepart.encode()) > 1023:
raise ResourcepartByteLimit
......@@ -496,7 +503,7 @@ def validate_resourcepart(resourcepart):
@functools.lru_cache(maxsize=None)
def validate_domainpart(domainpart):
def validate_domainpart(domainpart: str) -> str:
if not domainpart:
raise DomainpartByteLimit
......@@ -520,12 +527,12 @@ def validate_domainpart(domainpart):
@functools.lru_cache(maxsize=None)
def idna_encode(domain):
def idna_encode(domain: str) -> str:
return idna.encode(domain, uts46=True).decode()
@functools.lru_cache(maxsize=None)
def escape_localpart(localpart):
def escape_localpart(localpart: str) -> str:
# https://xmpp.org/extensions/xep-0106.html#bizrules-algorithm
#
# If there are any instances of character sequences that correspond
......@@ -547,7 +554,7 @@ def escape_localpart(localpart):
@functools.lru_cache(maxsize=None)
def unescape_localpart(localpart):
def unescape_localpart(localpart: str) -> str:
if localpart.startswith('\\20') or localpart.endswith('\\20'):
# Escaped JIDs are not allowed to start or end with \20
# so this localpart must be already unescaped
......@@ -566,11 +573,15 @@ def unescape_localpart(localpart):
@dataclass(frozen=True)
class JID:
localpart: str = None
domain: str = None
resource: str = None
localpart: Optional[str] = None
domain: Optional[str] = None
resource: Optional[str] = None
def __init__(self,
localpart: Optional[str] = None,
domain: Optional[str] = None,
resource: Optional[str] = None):
def __init__(self, localpart=None, domain=None, resource=None):
if localpart is not None:
localpart = validate_localpart(localpart)
object.__setattr__(self, "localpart", localpart)
......@@ -584,7 +595,7 @@ class JID:
@classmethod
@functools.lru_cache(maxsize=None)
def from_string(cls, jid_string):
def from_string(cls, jid_string: str) -> JID:
# https://tools.ietf.org/html/rfc7622#section-3.2
# Remove any portion from the first '/' character to the end of the
......@@ -609,7 +620,7 @@ class JID:
@classmethod
@functools.lru_cache(maxsize=None)
def from_user_input(cls, user_input, escaped=False):
def from_user_input(cls, user_input: str, escaped: bool = False) -> JID:
# Use this if we want JIDs to be escaped according to XEP-0106
# The standard JID parsing cannot be applied because user_input is
# not a valid JID.
......@@ -640,11 +651,11 @@ class JID:
domain=domainpart,
resource=None)
def __str__(self):
def __str__(self) -> str:
if self.localpart:
jid = f'{self.localpart}@{self.domain}'
else:
jid = self.domain
jid = cast(str, self.domain)
if self.resource is not None:
return f'{jid}/{self.resource}'
......@@ -653,7 +664,7 @@ class JID:
def __hash__(self):
return hash(str(self))
def __eq__(self, other):
def __eq__(self, other: Union[str, JID]) -> bool:
if isinstance(other, str):
deprecation_warning('comparing string with JID is deprected')
try:
......@@ -661,57 +672,54 @@ class JID:
except Exception:
return False
if not isinstance(other, JID):
raise TypeError('eq with type (%s) not supported' % type(other))
return (self.localpart == other.localpart and
self.domain == other.domain and
self.resource == other.resource)
def __ne__(self, other):
def __ne__(self, other: Union[str, JID]) -> bool:
return not self.__eq__(other)
def domain_to_ascii(self):
def domain_to_ascii(self) -> str:
return idna_encode(self.domain)
@property
def bare(self):
def bare(self) -> Optional[str]:
if self.localpart is not None:
return f'{self.localpart}@{self.domain}'
return self.domain
@property
def is_bare(self):
def is_bare(self) -> bool:
return self.resource is None
def new_as_bare(self):
def new_as_bare(self) -> JID:
if self.resource is None:
return self
new = asdict(self)
new.pop('resource')
return JID(**new)
def bare_match(self, other):
def bare_match(self, other: Union[str, JID]) -> bool:
if isinstance(other, str):
other = JID.from_string(other)
return self.bare == other.bare
@property
def is_domain(self):
def is_domain(self) -> bool:
return self.localpart is None and self.resource is None
@property
def is_full(self):
def is_full(self) -> bool:
return (self.localpart is not None and
self.domain is not None and
self.resource is not None)
def new_with(self, **kwargs):
def new_with(self, **kwargs: dict[str, str]) -> JID:
new = asdict(self)
new.update(kwargs)
return JID(**new)
def to_user_string(self, show_punycode=True):
def to_user_string(self, show_punycode: bool = True) -> str:
domain = self.domain_to_ascii()
if domain.startswith('xn--') and show_punycode:
domain_encoded = f' ({domain})'
......@@ -727,7 +735,7 @@ class JID:
return f'{localpart}@{self.domain}{domain_encoded}'
return f'{localpart}@{self.domain}/{self.resource}{domain_encoded}'
def copy(self):
def copy(self) -> JID:
deprecation_warning('copy() is not needed, JID is immutable')
return self
......@@ -736,7 +744,7 @@ class StreamErrorNode(Node):
def __init__(self, node):
Node.__init__(self, node=node)
self._text = {}
self._text: dict[Optional[str], str] = {}
text_elements = self.getTags('text', namespace=Namespace.XMPP_STREAMS)
for element in text_elements:
......@@ -744,14 +752,14 @@ class StreamErrorNode(Node):
text = element.getData()
self._text[lang] = text
def get_condition(self):
def get_condition(self) -> Optional[str]:
for tag in self.getChildren():
if (tag.getName() != 'text' and
tag.getNamespace() == Namespace.XMPP_STREAMS):
return tag.getName()
return None
def get_text(self, pref_lang=None):
def get_text(self, pref_lang: Optional[str] = None) -> str:
if pref_lang is not None:
text = self._text.get(pref_lang)
if text is not None:
......@@ -884,13 +892,13 @@ class Protocol(Node):
return JID.from_string(attr)
return attr
def getID(self):
def getID(self) -> Optional[str]:
"""
Return the value of the 'id' attribute
"""
return self.getAttr('id')
def setTo(self, val):
def setTo(self, val: Union[str, JID]):
"""
Set the value of the 'to' attribute
"""
......@@ -898,13 +906,13 @@ class Protocol(Node):
val = JID.from_string(val)
self.setAttr('to', val)
def getType(self):
def getType(self) -> Optional[str]:
"""
Return the value of the 'type' attribute
"""
return self.getAttr('type')
def setFrom(self, val):
def setFrom(self, val: Union[str, JID]):
"""
Set the value of the 'from' attribute
"""
......@@ -912,13 +920,13 @@ class Protocol(Node):
val = JID.from_string(val)
self.setAttr('from', val)
def setType(self, val):
def setType(self, val: str):
"""
Set the value of the 'type' attribute
"""
self.setAttr('type', val)
def setID(self, val):
def setID(self, val: str):
"""
Set the value of the 'id' attribute
"""
......@@ -1035,7 +1043,11 @@ class Protocol(Node):
props.append(prop)
return props
def getTag(self, name, attrs=None, namespace=None, protocol=False):
def getTag(self,
name: str,
attrs: Optional[dict[str, Any]] = None,
namespace: Optional[str] = None,
protocol: bool = False) -> Optional[Node]:
"""
Return the Node instance for the tag.
If protocol is True convert to a new Protocol/Message instance.
......@@ -1047,14 +1059,14 @@ class Protocol(Node):
return Protocol(node=tag)
return tag
def __setitem__(self, item, val):
def __setitem__(self, item: str, val: Union[str, JID]):
"""
Set the item 'item' to the value 'val'
"""
if item in ['to', 'from']:
if not isinstance(val, JID):
val = JID.from_string(val)
return self.setAttr(item, val)
self.setAttr(item, val)
class Message(Protocol):
......@@ -1637,13 +1649,15 @@ class Hashes2(Node):
class BindRequest(Iq):
def __init__(self, resource):
if resource is not None:
resource = Node('resource', payload=resource)
def __init__(self, resource: Optional[str]):
if resource is None:
res = resource
else:
res = Node('resource', payload=resource)
Iq.__init__(self, typ='set')
self.addChild(node=Node('bind',
{'xmlns': Namespace.BIND},
payload=resource))
payload=res))
class TLSRequest(Node):
......@@ -1658,7 +1672,7 @@ class SessionRequest(Iq):
class StreamHeader(Node):
def __init__(self, domain, lang=None):
def __init__(self, domain: str, lang: Optional[str] = None):
if lang is None:
lang = 'en'
Node.__init__(self,
......@@ -1671,7 +1685,7 @@ class StreamHeader(Node):
class WebsocketOpenHeader(Node):
def __init__(self, domain, lang=None):
def __init__(self, domain: str, lang: Optional[str] = None):
if lang is None:
lang = 'en'
Node.__init__(self,
......@@ -1687,7 +1701,7 @@ class WebsocketCloseHeader(Node):
class Features(Node):
def __init__(self, node):
def __init__(self, node: Node):
Node.__init__(self, node=node)
def has_starttls(self):
......@@ -1700,7 +1714,7 @@ class Features(Node):
def has_sasl(self):
return self.getTag('mechanisms', namespace=Namespace.SASL) is not None
def get_mechs(self):
def get_mechs(self) -> set[str]:
mechanisms = self.getTag('mechanisms', namespace=Namespace.SASL)
if mechanisms is None:
return set()
......
......@@ -276,7 +276,7 @@ class Node:
"""
if not isinstance(node, Node):
node = self.getTag(node, attrs)
assert isinstance(node, Node)
self.kids.remove(node)
return node
......
......@@ -15,6 +15,9 @@
# You should have received a copy of the GNU General Public License
# along with this program; If not, see <http://www.gnu.org/licenses/>.
from typing import Any
from typing import Optional
from typing import Callable
from typing import Literal
from typing import Union
......@@ -38,6 +41,7 @@ from nbxmpp.namespaces import Namespace
from nbxmpp.protocol import StanzaMalformed
from nbxmpp.protocol import StreamHeader
from nbxmpp.protocol import WebsocketOpenHeader
from nbxmpp.simplexml import Node
from nbxmpp.structs import Properties
from nbxmpp.structs import IqProperties
from nbxmpp.structs import MessageProperties
......@@ -112,7 +116,9 @@ error_classes = {
Namespace.HTTPUPLOAD_0: HTTPUploadError
}
def error_factory(stanza, condition=None, text=None):
def error_factory(stanza: Node,
condition: Optional[str] = None,
text: Optional[str] = None) -> Any:
if condition == 'stanza-malformed':
return StanzaMalformedError(stanza, text)
app_namespace = stanza.getAppErrorNamespace()
......@@ -272,7 +278,7 @@ def generate_id() -> str:
return str(uuid.uuid4())
def get_form(stanza, form_type):
def get_form(stanza: Node, form_type: Any) -> Any:
forms = stanza.getTags('x', namespace=Namespace.DATA)
if not forms:
return None
......@@ -288,7 +294,7 @@ def get_form(stanza, form_type):
return None
def validate_stream_header(stanza, domain, is_websocket):
def validate_stream_header(stanza: Node, domain: str, is_websocket: bool) -> str:
attrs = stanza.getAttrs()
if attrs.get('from') != domain:
raise StanzaMalformed('Invalid from attr in stream header')
......@@ -310,9 +316,9 @@ def validate_stream_header(stanza, domain, is_websocket):
return stream_id
def get_stream_header(domain, lang, is_websocket):
def get_stream_header(domain: str, lang: str, is_websocket: bool) -> str:
if is_websocket:
return WebsocketOpenHeader(domain, lang)
return str(WebsocketOpenHeader(domain, lang))
header = StreamHeader(domain, lang)
return "<?xml version='1.0'?>%s>" % str(header)[:-3]
......@@ -379,7 +385,7 @@ def convert_tls_error_flags(flags):
return set(filter(lambda error: error & flags, GIO_TLS_ERRORS.keys()))
def get_websocket_close_string(websocket):
def get_websocket_close_string(websocket: Any) -> str:
data = websocket.get_close_data()
code = websocket.get_close_code()
......@@ -388,29 +394,29 @@ def get_websocket_close_string(websocket):
return ' Data: %s Code: %s' % (data, code)
def is_websocket_close(stanza):
def is_websocket_close(stanza: Node) -> bool:
return (stanza.getName() == 'close' and
stanza.getNamespace() == Namespace.FRAMING)
def is_websocket_stream_error(stanza):
def is_websocket_stream_error(stanza: Node) -> bool:
return (stanza.getName() == 'error' and
stanza.getNamespace() == Namespace.STREAMS)
class Observable:
def __init__(self, log_):
def __init__(self, log_: logging.Logger):
self._log = log_
self._frozen = False
self._callbacks = defaultdict(list)
self._callbacks: defaultdict[str, list[Callable[..., Any]]] = defaultdict(list)
def remove_subscriptions(self):
self._callbacks = defaultdict(list)
def subscribe(self, signal_name, func):
def subscribe(self, signal_name: str, func: Callable[..., Any]):
self._callbacks[signal_name].append(func)
def notify(self, signal_name, *args, **kwargs):
def notify(self, signal_name: str, *args: Any, **kwargs: dict[str, Any]):
if self._frozen:
self._frozen = False
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment