Commit 508ec4ce authored by Emmanuel Gil Peyrot's avatar Emmanuel Gil Peyrot

Util: Type this module

parent 1b683ae4
Pipeline #5442 passed with stages
in 34 seconds
......@@ -25,6 +25,7 @@ import re
import logging
from logging import LoggerAdapter
from collections import defaultdict
from typing import Type, Union, Tuple, Optional, Pattern, Dict, List, Callable, Sequence, Any
from functools import wraps
from functools import lru_cache
......@@ -38,6 +39,9 @@ from nbxmpp.namespaces import Namespace
from nbxmpp.protocol import StanzaMalformed
from nbxmpp.protocol import StreamHeader
from nbxmpp.protocol import WebsocketOpenHeader
from nbxmpp.protocol import Protocol
from nbxmpp.protocol import JID
from nbxmpp.simplexml import Node
from nbxmpp.structs import Properties
from nbxmpp.structs import IqProperties
from nbxmpp.structs import MessageProperties
......@@ -45,13 +49,18 @@ from nbxmpp.structs import PresenceProperties
from nbxmpp.structs import CommonError
from nbxmpp.structs import HTTPUploadError
from nbxmpp.structs import StanzaMalformedError
from nbxmpp.structs import DiscoInfo
from nbxmpp.modules.dataforms import extend_form
from nbxmpp.third_party.hsluv import hsluv_to_rgb
log = logging.getLogger('nbxmpp.util')
Base64 = Union[str, bytes]
Color = Tuple[float, float, float]
Form = Any
def b64decode(data, return_type=str):
def b64decode(data: Base64, return_type: Union[Type[str], Type[bytes]] = str) -> Base64:
if not data:
raise ValueError('No data to decode')
if isinstance(data, str):
......@@ -62,7 +71,7 @@ def b64decode(data, return_type=str):
return result.decode()
def b64encode(data, return_type=str):
def b64encode(data: Base64, return_type: Union[Type[str], Type[bytes]] = str) -> Base64:
if not data:
raise ValueError('No data to encode')
if isinstance(data, str):
......@@ -73,7 +82,7 @@ def b64encode(data, return_type=str):
return result.decode()
def get_properties_struct(name):
def get_properties_struct(name: str) -> Properties:
if name == 'message':
return MessageProperties()
if name == 'iq':
......@@ -84,7 +93,7 @@ def get_properties_struct(name):
def call_on_response(cb):
def response_decorator(func):
def response_decorator(func: Callable[..., Protocol]):
@wraps(func)
def func_wrapper(self, *args, **kwargs):
user_data = kwargs.pop('user_data', None)
......@@ -111,7 +120,7 @@ def call_on_response(cb):
def callback(func):
@wraps(func)
def func_wrapper(self, _con, stanza, **kwargs):
def func_wrapper(self, _con, stanza: Protocol, **kwargs):
cb = kwargs.pop('callback', None)
user_data = kwargs.pop('user_data', None)
......@@ -154,14 +163,14 @@ error_classes = {
Namespace.HTTPUPLOAD_0: HTTPUploadError
}
def error_factory(stanza, condition=None, text=None):
def error_factory(stanza: Protocol, condition: Optional[str] = None, text: Optional[str] = None) -> CommonError:
if condition == 'stanza-malformed':
return StanzaMalformedError(stanza, text)
app_namespace = stanza.getAppErrorNamespace()
return error_classes.get(app_namespace, CommonError)(stanza)
def raise_error(log_method, stanza, condition=None, text=None):
def raise_error(log_method, stanza: Protocol, condition: Optional[str] = None, text: Optional[str] = None) -> CommonError:
if not isErrorNode(stanza) and condition != 'stanza-malformed':
condition = 'stanza-malformed'
if log_method.__name__ not in ('warning', 'error'):
......@@ -181,11 +190,11 @@ def raise_error(log_method, stanza, condition=None, text=None):
return error
def is_error_result(result):
def is_error_result(result) -> bool:
return isinstance(result, CommonError)
def clip_rgb(red, green, blue):
def clip_rgb(red: float, green: float, blue: float) -> Color:
return (
min(max(red, 0), 1),
min(max(green, 0), 1),
......@@ -194,7 +203,7 @@ def clip_rgb(red, green, blue):
@lru_cache(maxsize=1024)
def text_to_color(text, background_color):
def text_to_color(text: str, background_color: Color) -> Color:
# background color = (rb, gb, bb)
hash_ = hashlib.sha1()
hash_.update(text.encode())
......@@ -215,7 +224,7 @@ def text_to_color(text, background_color):
return rc, gc, bc
def compute_caps_hash(info, compare=True):
def compute_caps_hash(info: DiscoInfo, compare: bool = True) -> str:
"""
Compute caps hash according to XEP-0115, V1.5
https://xmpp.org/extensions/xep-0115.html#ver-proc
......@@ -334,12 +343,12 @@ def compute_caps_hash(info, compare=True):
return b64hash
def generate_id():
def generate_id() -> str:
return str(uuid.uuid4())
def get_form(stanza, form_type):
forms = stanza.getTags('x', namespace=Namespace.DATA)
def get_form(node: Node, form_type: str) -> Optional[Form]:
forms = node.getTags('x', namespace=Namespace.DATA)
if not forms:
return None
......@@ -354,7 +363,7 @@ def get_form(stanza, form_type):
return None
def validate_stream_header(stanza, domain, is_websocket):
def validate_stream_header(stanza: Protocol, domain: JID, is_websocket: bool) -> str:
attrs = stanza.getAttrs()
if attrs.get('from') != domain:
raise StanzaMalformed('Invalid from attr in stream header')
......@@ -376,18 +385,18 @@ def validate_stream_header(stanza, domain, is_websocket):
return stream_id
def get_stream_header(domain, lang, is_websocket):
def get_stream_header(domain: JID, lang: str, is_websocket: bool) -> str:
if is_websocket:
return WebsocketOpenHeader(domain, lang)
header = StreamHeader(domain, lang)
return "<?xml version='1.0'?>%s>" % str(header)[:-3]
def get_stanza_id():
def get_stanza_id() -> str:
return str(uuid.uuid4())
def utf8_decode(data):
def utf8_decode(data: bytes) -> Tuple[str, bytes]:
'''
Decodes utf8 byte string to unicode string
Does handle invalid utf8 sequences by splitting
......@@ -406,11 +415,11 @@ def utf8_decode(data):
raise
def get_rand_number():
def get_rand_number() -> int:
return int(binascii.hexlify(os.urandom(6)), 16)
def get_invalid_xml_regex():
def get_invalid_xml_regex() -> Pattern:
# \ufddo -> \ufdef range
c = '\ufdd0'
r = c
......@@ -449,7 +458,7 @@ def convert_tls_error_flags(flags):
return set(filter(lambda error: error & flags, GIO_TLS_ERRORS.keys()))
def get_websocket_close_string(websocket):
def get_websocket_close_string(websocket) -> str:
data = websocket.get_close_data()
code = websocket.get_close_code()
......@@ -458,29 +467,31 @@ def get_websocket_close_string(websocket):
return ' Data: %s Code: %s' % (data, code)
def is_websocket_close(stanza):
def is_websocket_close(stanza: Protocol) -> bool:
return (stanza.getName() == 'close' and
stanza.getNamespace() == Namespace.FRAMING)
def is_websocket_stream_error(stanza):
def is_websocket_stream_error(stanza: Protocol) -> bool:
return (stanza.getName() == 'error' and
stanza.getNamespace() == Namespace.STREAMS)
class Observable:
def __init__(self, log_):
_callbacks: Dict[str, List[Callable[..., None]]]
def __init__(self, log_: Any) -> None:
self._log = log_
self._frozen = False
self._callbacks = defaultdict(list)
def remove_subscriptions(self):
def remove_subscriptions(self) -> None:
self._callbacks = defaultdict(list)
def subscribe(self, signal_name, func):
def subscribe(self, signal_name: str, func: Callable[..., None]) -> None:
self._callbacks[signal_name].append(func)
def notify(self, signal_name, *args, **kwargs):
def notify(self, signal_name: str, *args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None:
if self._frozen:
self._frozen = False
return
......@@ -493,8 +504,8 @@ class Observable:
class LogAdapter(LoggerAdapter):
def set_context(self, context):
def set_context(self, context) -> None:
self.extra['context'] = context
def process(self, msg, kwargs):
def process(self, msg: str, kwargs: Dict[Any, Any]) -> Tuple[str, Dict[Any, Any]]:
return '(%s) %s' % (self.extra['context'], msg), kwargs
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment