Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Link Mauve
python-nbxmpp
Commits
508ec4ce
Commit
508ec4ce
authored
Apr 18, 2020
by
Emmanuel Gil Peyrot
Browse files
Util: Type this module
parent
1b683ae4
Pipeline
#5442
passed with stages
in 34 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nbxmpp/util.py
View file @
508ec4ce
...
...
@@ -25,6 +25,7 @@ import re
import
logging
from
logging
import
LoggerAdapter
from
collections
import
defaultdict
from
typing
import
Type
,
Union
,
Tuple
,
Optional
,
Pattern
,
Dict
,
List
,
Callable
,
Sequence
,
Any
from
functools
import
wraps
from
functools
import
lru_cache
...
...
@@ -38,6 +39,9 @@ from nbxmpp.namespaces import Namespace
from
nbxmpp.protocol
import
StanzaMalformed
from
nbxmpp.protocol
import
StreamHeader
from
nbxmpp.protocol
import
WebsocketOpenHeader
from
nbxmpp.protocol
import
Protocol
from
nbxmpp.protocol
import
JID
from
nbxmpp.simplexml
import
Node
from
nbxmpp.structs
import
Properties
from
nbxmpp.structs
import
IqProperties
from
nbxmpp.structs
import
MessageProperties
...
...
@@ -45,13 +49,18 @@ from nbxmpp.structs import PresenceProperties
from
nbxmpp.structs
import
CommonError
from
nbxmpp.structs
import
HTTPUploadError
from
nbxmpp.structs
import
StanzaMalformedError
from
nbxmpp.structs
import
DiscoInfo
from
nbxmpp.modules.dataforms
import
extend_form
from
nbxmpp.third_party.hsluv
import
hsluv_to_rgb
log
=
logging
.
getLogger
(
'nbxmpp.util'
)
Base64
=
Union
[
str
,
bytes
]
Color
=
Tuple
[
float
,
float
,
float
]
Form
=
Any
def
b64decode
(
data
,
return_type
=
str
):
def
b64decode
(
data
:
Base64
,
return_type
:
Union
[
Type
[
str
],
Type
[
bytes
]]
=
str
)
->
Base64
:
if
not
data
:
raise
ValueError
(
'No data to decode'
)
if
isinstance
(
data
,
str
):
...
...
@@ -62,7 +71,7 @@ def b64decode(data, return_type=str):
return
result
.
decode
()
def
b64encode
(
data
,
return_type
=
str
)
:
def
b64encode
(
data
:
Base64
,
return_type
:
Union
[
Type
[
str
],
Type
[
bytes
]]
=
str
)
->
Base64
:
if
not
data
:
raise
ValueError
(
'No data to encode'
)
if
isinstance
(
data
,
str
):
...
...
@@ -73,7 +82,7 @@ def b64encode(data, return_type=str):
return
result
.
decode
()
def
get_properties_struct
(
name
)
:
def
get_properties_struct
(
name
:
str
)
->
Properties
:
if
name
==
'message'
:
return
MessageProperties
()
if
name
==
'iq'
:
...
...
@@ -84,7 +93,7 @@ def get_properties_struct(name):
def
call_on_response
(
cb
):
def
response_decorator
(
func
):
def
response_decorator
(
func
:
Callable
[...,
Protocol
]
):
@
wraps
(
func
)
def
func_wrapper
(
self
,
*
args
,
**
kwargs
):
user_data
=
kwargs
.
pop
(
'user_data'
,
None
)
...
...
@@ -111,7 +120,7 @@ def call_on_response(cb):
def
callback
(
func
):
@
wraps
(
func
)
def
func_wrapper
(
self
,
_con
,
stanza
,
**
kwargs
):
def
func_wrapper
(
self
,
_con
,
stanza
:
Protocol
,
**
kwargs
):
cb
=
kwargs
.
pop
(
'callback'
,
None
)
user_data
=
kwargs
.
pop
(
'user_data'
,
None
)
...
...
@@ -154,14 +163,14 @@ error_classes = {
Namespace
.
HTTPUPLOAD_0
:
HTTPUploadError
}
def
error_factory
(
stanza
,
condition
=
None
,
text
=
None
)
:
def
error_factory
(
stanza
:
Protocol
,
condition
:
Optional
[
str
]
=
None
,
text
:
Optional
[
str
]
=
None
)
->
CommonError
:
if
condition
==
'stanza-malformed'
:
return
StanzaMalformedError
(
stanza
,
text
)
app_namespace
=
stanza
.
getAppErrorNamespace
()
return
error_classes
.
get
(
app_namespace
,
CommonError
)(
stanza
)
def
raise_error
(
log_method
,
stanza
,
condition
=
None
,
text
=
None
)
:
def
raise_error
(
log_method
,
stanza
:
Protocol
,
condition
:
Optional
[
str
]
=
None
,
text
:
Optional
[
str
]
=
None
)
->
CommonError
:
if
not
isErrorNode
(
stanza
)
and
condition
!=
'stanza-malformed'
:
condition
=
'stanza-malformed'
if
log_method
.
__name__
not
in
(
'warning'
,
'error'
):
...
...
@@ -181,11 +190,11 @@ def raise_error(log_method, stanza, condition=None, text=None):
return
error
def
is_error_result
(
result
):
def
is_error_result
(
result
)
->
bool
:
return
isinstance
(
result
,
CommonError
)
def
clip_rgb
(
red
,
green
,
blue
)
:
def
clip_rgb
(
red
:
float
,
green
:
float
,
blue
:
float
)
->
Color
:
return
(
min
(
max
(
red
,
0
),
1
),
min
(
max
(
green
,
0
),
1
),
...
...
@@ -194,7 +203,7 @@ def clip_rgb(red, green, blue):
@
lru_cache
(
maxsize
=
1024
)
def
text_to_color
(
text
,
background_color
)
:
def
text_to_color
(
text
:
str
,
background_color
:
Color
)
->
Color
:
# background color = (rb, gb, bb)
hash_
=
hashlib
.
sha1
()
hash_
.
update
(
text
.
encode
())
...
...
@@ -215,7 +224,7 @@ def text_to_color(text, background_color):
return
rc
,
gc
,
bc
def
compute_caps_hash
(
info
,
compare
=
True
)
:
def
compute_caps_hash
(
info
:
DiscoInfo
,
compare
:
bool
=
True
)
->
str
:
"""
Compute caps hash according to XEP-0115, V1.5
https://xmpp.org/extensions/xep-0115.html#ver-proc
...
...
@@ -334,12 +343,12 @@ def compute_caps_hash(info, compare=True):
return
b64hash
def
generate_id
():
def
generate_id
()
->
str
:
return
str
(
uuid
.
uuid4
())
def
get_form
(
stanza
,
form_type
)
:
forms
=
stanza
.
getTags
(
'x'
,
namespace
=
Namespace
.
DATA
)
def
get_form
(
node
:
Node
,
form_type
:
str
)
->
Optional
[
Form
]
:
forms
=
node
.
getTags
(
'x'
,
namespace
=
Namespace
.
DATA
)
if
not
forms
:
return
None
...
...
@@ -354,7 +363,7 @@ def get_form(stanza, form_type):
return
None
def
validate_stream_header
(
stanza
,
domain
,
is_websocket
)
:
def
validate_stream_header
(
stanza
:
Protocol
,
domain
:
JID
,
is_websocket
:
bool
)
->
str
:
attrs
=
stanza
.
getAttrs
()
if
attrs
.
get
(
'from'
)
!=
domain
:
raise
StanzaMalformed
(
'Invalid from attr in stream header'
)
...
...
@@ -376,18 +385,18 @@ def validate_stream_header(stanza, domain, is_websocket):
return
stream_id
def
get_stream_header
(
domain
,
lang
,
is_websocket
)
:
def
get_stream_header
(
domain
:
JID
,
lang
:
str
,
is_websocket
:
bool
)
->
str
:
if
is_websocket
:
return
WebsocketOpenHeader
(
domain
,
lang
)
header
=
StreamHeader
(
domain
,
lang
)
return
"<?xml version='1.0'?>%s>"
%
str
(
header
)[:
-
3
]
def
get_stanza_id
():
def
get_stanza_id
()
->
str
:
return
str
(
uuid
.
uuid4
())
def
utf8_decode
(
data
)
:
def
utf8_decode
(
data
:
bytes
)
->
Tuple
[
str
,
bytes
]
:
'''
Decodes utf8 byte string to unicode string
Does handle invalid utf8 sequences by splitting
...
...
@@ -406,11 +415,11 @@ def utf8_decode(data):
raise
def
get_rand_number
():
def
get_rand_number
()
->
int
:
return
int
(
binascii
.
hexlify
(
os
.
urandom
(
6
)),
16
)
def
get_invalid_xml_regex
():
def
get_invalid_xml_regex
()
->
Pattern
:
# \ufddo -> \ufdef range
c
=
'
\ufdd0
'
r
=
c
...
...
@@ -449,7 +458,7 @@ def convert_tls_error_flags(flags):
return
set
(
filter
(
lambda
error
:
error
&
flags
,
GIO_TLS_ERRORS
.
keys
()))
def
get_websocket_close_string
(
websocket
):
def
get_websocket_close_string
(
websocket
)
->
str
:
data
=
websocket
.
get_close_data
()
code
=
websocket
.
get_close_code
()
...
...
@@ -458,29 +467,31 @@ def get_websocket_close_string(websocket):
return
' Data: %s Code: %s'
%
(
data
,
code
)
def
is_websocket_close
(
stanza
)
:
def
is_websocket_close
(
stanza
:
Protocol
)
->
bool
:
return
(
stanza
.
getName
()
==
'close'
and
stanza
.
getNamespace
()
==
Namespace
.
FRAMING
)
def
is_websocket_stream_error
(
stanza
)
:
def
is_websocket_stream_error
(
stanza
:
Protocol
)
->
bool
:
return
(
stanza
.
getName
()
==
'error'
and
stanza
.
getNamespace
()
==
Namespace
.
STREAMS
)
class
Observable
:
def
__init__
(
self
,
log_
):
_callbacks
:
Dict
[
str
,
List
[
Callable
[...,
None
]]]
def
__init__
(
self
,
log_
:
Any
)
->
None
:
self
.
_log
=
log_
self
.
_frozen
=
False
self
.
_callbacks
=
defaultdict
(
list
)
def
remove_subscriptions
(
self
):
def
remove_subscriptions
(
self
)
->
None
:
self
.
_callbacks
=
defaultdict
(
list
)
def
subscribe
(
self
,
signal_name
,
func
)
:
def
subscribe
(
self
,
signal_name
:
str
,
func
:
Callable
[...,
None
])
->
None
:
self
.
_callbacks
[
signal_name
].
append
(
func
)
def
notify
(
self
,
signal_name
,
*
args
,
**
kwargs
)
:
def
notify
(
self
,
signal_name
:
str
,
*
args
:
Sequence
[
Any
],
**
kwargs
:
Dict
[
Any
,
Any
])
->
None
:
if
self
.
_frozen
:
self
.
_frozen
=
False
return
...
...
@@ -493,8 +504,8 @@ class Observable:
class
LogAdapter
(
LoggerAdapter
):
def
set_context
(
self
,
context
):
def
set_context
(
self
,
context
)
->
None
:
self
.
extra
[
'context'
]
=
context
def
process
(
self
,
msg
,
kwargs
)
:
def
process
(
self
,
msg
:
str
,
kwargs
:
Dict
[
Any
,
Any
])
->
Tuple
[
str
,
Dict
[
Any
,
Any
]]
:
return
'(%s) %s'
%
(
self
.
extra
[
'context'
],
msg
),
kwargs
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment