Commit beeff59c authored by Emmanuel Gil Peyrot's avatar Emmanuel Gil Peyrot
Browse files

Util: Type this module

parent 514b4460
Pipeline #5406 failed with stages
......@@ -25,6 +25,7 @@ import re
import logging
from logging import LoggerAdapter
from collections import defaultdict
from typing import AnyStr, Type, Union, Tuple, Optional, Pattern, Dict, Any, List, Callable, Sequence
from functools import wraps
from functools import lru_cache
......@@ -42,6 +43,8 @@ from nbxmpp.protocol import NS_FRAMING
from nbxmpp.protocol import StanzaMalformed
from nbxmpp.protocol import StreamHeader
from nbxmpp.protocol import WebsocketOpenHeader
from nbxmpp.protocol import Protocol
from nbxmpp.simplexml import Node
from nbxmpp.structs import Properties
from nbxmpp.structs import IqProperties
from nbxmpp.structs import MessageProperties
......@@ -49,13 +52,18 @@ from nbxmpp.structs import PresenceProperties
from nbxmpp.structs import CommonError
from nbxmpp.structs import HTTPUploadError
from nbxmpp.structs import StanzaMalformedError
from nbxmpp.structs import DiscoInfo
from nbxmpp.modules.dataforms import extend_form
from nbxmpp.third_party.hsluv import hsluv_to_rgb
log = logging.getLogger('nbxmpp.util')
Base64 = Union[str, bytes]
Color = Tuple[float, float, float]
Form = Any
def b64decode(data, return_type=str):
def b64decode(data: Base64, return_type: Union[Type[str], Type[bytes]] = str) -> Base64:
if not data:
raise ValueError('No data to decode')
if isinstance(data, str):
......@@ -66,7 +74,7 @@ def b64decode(data, return_type=str):
return result.decode()
def b64encode(data, return_type=str):
def b64encode(data: Base64, return_type: Union[Type[str], Type[bytes]] = str) -> Base64:
if not data:
raise ValueError('No data to encode')
if isinstance(data, str):
......@@ -77,7 +85,7 @@ def b64encode(data, return_type=str):
return result.decode()
def get_properties_struct(name):
def get_properties_struct(name: str) -> Properties:
if name == 'message':
return MessageProperties()
if name == 'iq':
......@@ -88,7 +96,7 @@ def get_properties_struct(name):
def call_on_response(cb):
def response_decorator(func):
def response_decorator(func: Callable[..., Protocol]):
@wraps(func)
def func_wrapper(self, *args, **kwargs):
user_data = kwargs.pop('user_data', None)
......@@ -115,7 +123,7 @@ def call_on_response(cb):
def callback(func):
@wraps(func)
def func_wrapper(self, _con, stanza, **kwargs):
def func_wrapper(self, _con, stanza: Protocol, **kwargs):
cb = kwargs.pop('callback', None)
user_data = kwargs.pop('user_data', None)
......@@ -158,14 +166,14 @@ error_classes = {
NS_HTTPUPLOAD_0: HTTPUploadError
}
def error_factory(stanza, condition=None, text=None):
def error_factory(stanza: Protocol, condition: Optional[str] = None, text: Optional[str] = None) -> CommonError:
if condition == 'stanza-malformed':
return StanzaMalformedError(stanza, text)
app_namespace = stanza.getAppErrorNamespace()
return error_classes.get(app_namespace, CommonError)(stanza)
def raise_error(log_method, stanza, condition=None, text=None):
def raise_error(log_method, stanza: Protocol, condition: Optional[str] = None, text: Optional[str] = None) -> CommonError:
if not isErrorNode(stanza) and condition != 'stanza-malformed':
condition = 'stanza-malformed'
if log_method.__name__ not in ('warning', 'error'):
......@@ -185,11 +193,11 @@ def raise_error(log_method, stanza, condition=None, text=None):
return error
def is_error_result(result):
def is_error_result(result) -> bool:
return isinstance(result, CommonError)
def clip_rgb(red, green, blue):
def clip_rgb(red: float, green: float, blue: float) -> Color:
return (
min(max(red, 0), 1),
min(max(green, 0), 1),
......@@ -198,7 +206,7 @@ def clip_rgb(red, green, blue):
@lru_cache(maxsize=1024)
def text_to_color(text, background_color):
def text_to_color(text: str, background_color: Color) -> Color:
# background color = (rb, gb, bb)
hash_ = hashlib.sha1()
hash_.update(text.encode())
......@@ -219,7 +227,7 @@ def text_to_color(text, background_color):
return rc, gc, bc
def compute_caps_hash(info, compare=True):
def compute_caps_hash(info: DiscoInfo, compare: bool = True) -> str:
"""
Compute caps hash according to XEP-0115, V1.5
https://xmpp.org/extensions/xep-0115.html#ver-proc
......@@ -338,12 +346,12 @@ def compute_caps_hash(info, compare=True):
return b64hash
def generate_id():
def generate_id() -> str:
return str(uuid.uuid4())
def get_form(stanza, form_type):
forms = stanza.getTags('x', namespace=NS_DATA)
def get_form(node: Node, form_type: str) -> Optional[Form]:
forms = node.getTags('x', namespace=NS_DATA)
if not forms:
return None
......@@ -358,7 +366,7 @@ def get_form(stanza, form_type):
return None
def validate_stream_header(stanza, domain, is_websocket):
def validate_stream_header(stanza: Protocol, domain: JID, is_websocket: bool) -> str:
attrs = stanza.getAttrs()
if attrs.get('from') != domain:
raise StanzaMalformed('Invalid from attr in stream header')
......@@ -380,18 +388,18 @@ def validate_stream_header(stanza, domain, is_websocket):
return stream_id
def get_stream_header(domain, lang, is_websocket):
def get_stream_header(domain: JID, lang: str, is_websocket: bool) -> str:
if is_websocket:
return WebsocketOpenHeader(domain, lang)
header = StreamHeader(domain, lang)
return "<?xml version='1.0'?>%s>" % str(header)[:-3]
def get_stanza_id():
def get_stanza_id() -> str:
return str(uuid.uuid4())
def utf8_decode(data):
def utf8_decode(data: bytes) -> Tuple[str, bytes]:
'''
Decodes utf8 byte string to unicode string
Does handle invalid utf8 sequences by splitting
......@@ -410,11 +418,11 @@ def utf8_decode(data):
raise
def get_rand_number():
def get_rand_number() -> int:
return int(binascii.hexlify(os.urandom(6)), 16)
def get_invalid_xml_regex():
def get_invalid_xml_regex() -> Pattern:
# \ufddo -> \ufdef range
c = '\ufdd0'
r = c
......@@ -453,7 +461,7 @@ def convert_tls_error_flags(flags):
return set(filter(lambda error: error & flags, GIO_TLS_ERRORS.keys()))
def get_websocket_close_string(websocket):
def get_websocket_close_string(websocket) -> str:
data = websocket.get_close_data()
code = websocket.get_close_code()
......@@ -462,27 +470,29 @@ def get_websocket_close_string(websocket):
return ' Data: %s Code: %s' % (data, code)
def is_websocket_close(stanza):
def is_websocket_close(stanza: Protocol) -> bool:
return stanza.getName() == 'close' and stanza.getNamespace() == NS_FRAMING
def is_websocket_stream_error(stanza):
def is_websocket_stream_error(stanza: Protocol) -> bool:
return stanza.getName() == 'error' and stanza.getNamespace() == NS_STREAMS
class Observable:
def __init__(self, log_):
_callbacks: Dict[str, List[Callable[..., None]]]
def __init__(self, log_: Any) -> None:
self._log = log_
self._frozen = False
self._callbacks = defaultdict(list)
def remove_subscriptions(self):
def remove_subscriptions(self) -> None:
self._callbacks = defaultdict(list)
def subscribe(self, signal_name, func):
def subscribe(self, signal_name: str, func: Callable[..., None]) -> None:
self._callbacks[signal_name].append(func)
def notify(self, signal_name, *args, **kwargs):
def notify(self, signal_name: str, *args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None:
if self._frozen:
self._frozen = False
return
......@@ -495,8 +505,8 @@ class Observable:
class LogAdapter(LoggerAdapter):
def set_context(self, context):
def set_context(self, context) -> None:
self.extra['context'] = context
def process(self, msg, kwargs):
def process(self, msg: str, kwargs: Dict[Any, Any]) -> Tuple[str, Dict[Any, Any]]:
return '(%s) %s' % (self.extra['context'], msg), kwargs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment