Commit 1732199d authored by Philipp Hörist's avatar Philipp Hörist
Browse files

asd

parent 25787e8a
Pipeline #8575 failed with stages
in 23 seconds
import gi
from .protocol import *
gi.require_version('Soup', '2.4')
__version__ = "3.0.0-dev1"
......@@ -17,9 +17,13 @@
from __future__ import annotations
from typing import Iterator
from typing import NamedTuple
from typing import Optional
import logging
from collections import namedtuple
from nbxmpp.structs import ProxyData
from nbxmpp.util import Observable
from nbxmpp.resolver import GioResolver
from nbxmpp.const import ConnectionType
......@@ -29,10 +33,14 @@ from nbxmpp.const import ConnectionProtocol
log = logging.getLogger('nbxmpp.addresses')
class ServerAddress(namedtuple('ServerAddress', 'domain service host uri '
'protocol type proxy')):
__slots__ = []
class ServerAddress(NamedTuple):
domain: str
service: Optional[str]
host: Optional[str]
uri: Optional[str]
protocol: ConnectionProtocol
type: ConnectionType
proxy: Optional[ProxyData]
@property
def is_service(self):
......@@ -58,7 +66,7 @@ class ServerAddresses(Observable):
'''
def __init__(self, domain):
def __init__(self, domain: str):
Observable.__init__(self, log)
self._domain = domain
......@@ -66,7 +74,7 @@ class ServerAddresses(Observable):
self._proxy = None
self._is_resolved = False
self._addresses = [
self._addresses: list[ServerAddress] = [
ServerAddress(domain=self._domain,
service='xmpps-client',
host=None,
......@@ -111,11 +119,11 @@ class ServerAddresses(Observable):
]
@property
def domain(self):
def domain(self) -> str:
return self._domain
@property
def is_resolved(self):
def is_resolved(self) -> bool:
return self._is_resolved
def resolve(self):
......@@ -138,7 +146,9 @@ class ServerAddresses(Observable):
def cancel_resolve(self):
self.remove_subscriptions()
def set_custom_host(self, address):
def set_custom_host(self, address: Optional[tuple[str,
ConnectionProtocol,
ConnectionType]]):
# Set a custom host, overwrites all other addresses
self._custom_host = address
if address is None:
......@@ -160,10 +170,10 @@ class ServerAddresses(Observable):
type=type_,
proxy=None)]
def set_proxy(self, proxy):
def set_proxy(self, proxy: Optional[ProxyData]):
self._proxy = proxy
def _on_alternatives_result(self, uri):
def _on_alternatives_result(self, uri: Optional[str]):
if uri is None:
self._on_request_resolved()
return
......@@ -184,6 +194,7 @@ class ServerAddresses(Observable):
protocol=ConnectionProtocol.WEBSOCKET,
type=type_,
proxy=None)
self._addresses.append(addr)
self._on_request_resolved()
......@@ -194,8 +205,8 @@ class ServerAddresses(Observable):
self.remove_subscriptions()
def get_next_address(self,
allowed_types,
allowed_protocols):
allowed_types: list[ConnectionType],
allowed_protocols: list[ConnectionProtocol]) -> Iterator[ServerAddress]:
'''
Selects next address
'''
......@@ -212,7 +223,7 @@ class ServerAddresses(Observable):
raise NoMoreAddresses
def _assure_proxy(self, addr):
def _assure_proxy(self, addr: ServerAddress):
if self._proxy is None:
return addr
......@@ -221,17 +232,22 @@ class ServerAddresses(Observable):
return addr
def _filter_allowed(self, addresses, allowed_types, allowed_protocols):
def _filter_allowed(self,
addresses: list[ServerAddress],
allowed_types: list[ConnectionType],
allowed_protocols: list[ConnectionProtocol]) -> list[ServerAddress]:
if self._proxy is not None:
addresses = filter(lambda addr: addr.host is not None, addresses)
addresses = list(filter(lambda addr: addr.host is not None,
addresses))
addresses = filter(lambda addr: addr.type in allowed_types,
addresses)
addresses = filter(lambda addr: addr.protocol in allowed_protocols,
addresses)
addresses = list(filter(lambda addr: addr.type in allowed_types,
addresses))
addresses = list(filter(lambda addr: addr.protocol in allowed_protocols,
addresses))
return addresses
def __str__(self):
def __str__(self) -> str:
addresses = self._addresses + self._fallback_addresses
return '\n'.join([str(addr) for addr in addresses])
......
......@@ -51,9 +51,9 @@ def make_sasl_element(name: str,
payload: Optional[str] = None):
if mechanism is None:
element = E(name, namespace=Namespace.SASL)
element = E(name, namespace=Namespace.XMPP_SASL)
else:
element = E(name, namespace=Namespace.SASL, mechanism=mechanism)
element = E(name, namespace=Namespace.XMPP_SASL, mechanism=mechanism)
element.text = payload
return element
......@@ -87,7 +87,7 @@ class SASL:
return self._password
def delegate(self, stanza):
if stanza.namespace != Namespace.SASL:
if stanza.namespace != Namespace.XMPP_SASL:
return
if stanza.localname == 'challenge':
self._on_challenge(stanza)
......
......@@ -25,6 +25,8 @@ import logging
from lxml import etree
from gi.repository import GLib
from gi.repository import Gio
from nbxmpp.connection import Connection
from nbxmpp.namespaces import Namespace
from nbxmpp.exceptions import StanzaMalformed
......@@ -34,6 +36,7 @@ from nbxmpp.errors import StanzaError
from nbxmpp.errors import CancelledError
from nbxmpp.addresses import ServerAddresses
from nbxmpp.addresses import NoMoreAddresses
from nbxmpp.structs import ProxyData
from nbxmpp.tcp import TCPConnection
from nbxmpp.websocket import WebsocketConnection
from nbxmpp.smacks import Smacks
......@@ -110,7 +113,7 @@ class Client(Observable):
self._ping_task = None
self._error = None, None, None
self._ignored_tls_errors = set()
self._ignored_tls_errors: set[Gio.TlsCertificateFlags] = set()
self._ignore_tls_errors = False
self._accepted_certificates = []
self._peer_certificate = None
......@@ -184,10 +187,16 @@ class Client(Observable):
self._mode = mode
@property
def custom_host(self):
def custom_host(self) -> Optional[tuple[str,
ConnectionProtocol,
ConnectionType]]:
return self._custom_host
def set_custom_host(self, host_or_uri, protocol, type_):
def set_custom_host(self,
host_or_uri: str,
protocol: ConnectionProtocol,
type_: ConnectionType):
if self._domain is None:
raise ValueError('Call set_domain() first before set_custom_host()')
self._custom_host = (host_or_uri, protocol, type_)
......@@ -196,26 +205,26 @@ class Client(Observable):
self._accepted_certificates = certificates
@property
def ignored_tls_errors(self):
def ignored_tls_errors(self) -> set[Gio.TlsCertificateFlags]:
return self._ignored_tls_errors
def set_ignored_tls_errors(self, errors):
def set_ignored_tls_errors(self, errors: set[Gio.TlsCertificateFlags]):
if errors is None:
errors = set()
errors: set[Gio.TlsCertificateFlags] = set()
self._ignored_tls_errors = errors
@property
def ignore_tls_errors(self):
def ignore_tls_errors(self) -> bool:
return self._ignore_tls_errors
def set_ignore_tls_errors(self, ignore):
def set_ignore_tls_errors(self, ignore: bool):
self._ignore_tls_errors = ignore
def set_password(self, password):
def set_password(self, password: str):
self._sasl.set_password(password)
@property
def password(self):
def password(self) -> Optional[str]:
return self._sasl.password
@property
......@@ -227,15 +236,15 @@ class Client(Observable):
return self._current_address
@property
def current_connection_type(self):
def current_connection_type(self) -> ConnectionType:
return self._current_address.type
@property
def is_websocket(self):
def is_websocket(self) -> bool:
return self._current_address.protocol == ConnectionProtocol.WEBSOCKET
@property
def stream_id(self):
def stream_id(self) -> Optional[str]:
return self._stream_id
@property
......@@ -269,32 +278,32 @@ class Client(Observable):
return self._remote_address
@property
def connection_types(self):
def connection_types(self) -> list[ConnectionType]:
if self._custom_host is not None:
return [self._custom_host[2]]
return list(self._allowed_con_types or [ConnectionType.DIRECT_TLS,
ConnectionType.START_TLS])
def set_connection_types(self, con_types):
def set_connection_types(self, con_types: list[ConnectionType]):
self._allowed_con_types = con_types
@property
def mechs(self):
def mechs(self) -> set[str]:
return set(self._allowed_mechs or set(['SCRAM-SHA-256',
'SCRAM-SHA-1',
'PLAIN']))
def set_mechs(self, mechs):
def set_mechs(self, mechs: list[str]):
self._allowed_mechs = mechs
@property
def protocols(self):
def protocols(self) -> list[ConnectionProtocol]:
if self._custom_host is not None:
return [self._custom_host[1]]
return list(self._allowed_protocols or [ConnectionProtocol.TCP,
ConnectionProtocol.WEBSOCKET])
def set_protocols(self, protocols):
def set_protocols(self, protocols: list[ConnectionProtocol]):
self._allowed_protocols = protocols
def set_sm_disabled(self, value: bool):
......@@ -312,15 +321,15 @@ class Client(Observable):
self._client_cert = client_cert
self._client_cert_pass = client_cert_pass
def set_proxy(self, proxy):
def set_proxy(self, proxy: ProxyData):
self._proxy = proxy
self._dispatcher.get_module('Muclumbus').set_proxy(proxy)
@property
def proxy(self):
def proxy(self) -> Optional[ProxyData]:
return self._proxy
def get_bound_jid(self) -> JID:
def get_bound_jid(self) -> Optional[JID]:
return self._jid
def _set_bound_jid(self, jid: str):
......@@ -336,7 +345,11 @@ class Client(Observable):
def _reset_error(self):
self._error = None, None, None
def _set_error(self, domain, error, text=None):
def _set_error(self,
domain: StreamError,
error: str,
text: Optional[str] = None):
self._log.info('Set error: %s, %s, %s', domain, error, text)
self._error = domain, error, text
......@@ -391,7 +404,7 @@ class Client(Observable):
self._addresses.subscribe('resolved', self._on_addresses_resolved)
self._addresses.resolve()
def _on_addresses_resolved(self, _addresses, _signal_name):
def _on_addresses_resolved(self, _addresses, _signal_name: str):
self._log.info('Domain resolved')
self._log.info(self._addresses)
self.state = StreamState.RESOLVED
......@@ -401,7 +414,7 @@ class Client(Observable):
self._try_next_ip()
def _try_next_ip(self, *args):
def _try_next_ip(self):
try:
self._current_address = next(self._address_generator)
except NoMoreAddresses:
......@@ -417,7 +430,7 @@ class Client(Observable):
self._log.info('Current address: %s', self._current_address)
self._connect()
def disconnect(self, immediate=False):
def disconnect(self, immediate: bool = False):
if self._state == StreamState.RESOLVE:
self._addresses.cancel_resolve()
self.state = StreamState.DISCONNECTED
......@@ -448,16 +461,12 @@ class Client(Observable):
else:
self._con.disconnect()
def send(self, stanza, *args, **kwargs):
# Alias for backwards compat
return self.send_stanza(stanza)
def _on_connected(self, connection, _signal_name):
def _on_connected(self, connection: Connection, _signal_name: str):
self.set_state(StreamState.CONNECTED)
self._local_address = connection.local_address
self._remote_address = connection.remote_address
def _on_disconnected(self, _connection, _signal_name):
def _on_disconnected(self, _connection: Connection, _signal_name: str):
self.state = StreamState.DISCONNECTED
for task in self._tasks:
task.cancel()
......@@ -466,7 +475,7 @@ class Client(Observable):
self._reset_stream()
self.notify('disconnected')
def _on_connection_failed(self, _connection, _signal_name):
def _on_connection_failed(self, _connection: Connection, _signal_name: str):
self.state = StreamState.DISCONNECTED
self._reset_stream()
if not self._connect_successful:
......@@ -478,17 +487,28 @@ class Client(Observable):
'successful address: {self._current_address}'))
self.notify('connection-failed')
def _disconnect_with_error(self, error_domain, error, text=None):
def _disconnect_with_error(self,
error_domain: StreamError,
error: str,
text: Optional[str] = None):
self._set_error(error_domain, error, text)
self.disconnect()
def _on_parsing_error(self, _dispatcher, _signal_name, error):
def _on_parsing_error(self,
_dispatcher: StanzaDispatcher,
_signal_name: str,
error: str):
if self._state == StreamState.DISCONNECTING:
# Don't notify about parsing errors if we already ended the stream
return
self._disconnect_with_error(StreamError.PARSING, 'parsing-error', error)
def _on_stream_end(self, _dispatcher, _signal_name, error):
def _on_stream_end(self,
_dispatcher: StanzaDispatcher,
_signal_name: str,
error: str):
if not self.has_error:
self._set_error(StreamError.STREAM, error or 'stream-end')
......@@ -515,12 +535,12 @@ class Client(Observable):
def get_module(self, name: str):
return self._dispatcher.get_module(name)
def _on_bad_certificate(self, connection, _signal_name):
def _on_bad_certificate(self, connection: Connection, _signal_name: str):
self._peer_certificate, self._peer_certificate_errors = \
connection.peer_certificate
self._set_error(StreamError.BAD_CERTIFICATE, 'bad certificate')
def _on_certificate_set(self, connection, _signal_name):
def _on_certificate_set(self, connection: Connection, _signal_name: str):
self._peer_certificate, self._peer_certificate_errors = \
connection.peer_certificate
......@@ -529,13 +549,25 @@ class Client(Observable):
self._accepted_certificates.append(self._peer_certificate)
self._connect()
def _on_data_sent(self, _connection, _signal_name, data):
def _on_data_sent(self,
_connection: Connection,
_signal_name: str,
data: Any):
self.notify('stanza-sent', data)
def _on_before_dispatch(self, _dispatcher, _signal_name, data):
def _on_before_dispatch(self,
_dispatcher: StanzaDispatcher,
_signal_name: str,
data: Any):
self.notify('stanza-received', data)
def _on_data_received(self, _connection, _signal_name, data):
def _on_data_received(self,
_connection: Connection,
_signal_name: str,
data: Any):
self._dispatcher.process_data(data)
self._reset_ping_timer()
......@@ -586,13 +618,10 @@ class Client(Observable):
self._smacks.save_in_queue(stanza)
return id_
def SendAndCallForResponse(self, stanza, callback, user_data=None):
self.send_stanza(stanza, callback=callback, user_data=user_data)
def send_nonza(self, nonza: types.Base, now: bool = False):
self._con.send(nonza, now)
def _xmpp_state_machine(self, stanza=None):
def _xmpp_state_machine(self, stanza: Optional[types.Base] = None):
self._log.info('Execute state machine')
if stanza is not None:
if stanza.localname == 'error':
......@@ -791,8 +820,8 @@ class Client(Observable):
self.send_stanza(bind_request)
self.state = StreamState.WAIT_FOR_BIND
def _on_bind(self, stanza):
if not isResultNode(stanza):
def _on_bind(self, stanza: types.Iq):
if not stanza.is_result():
self._disconnect_with_error(StreamError.BIND,
stanza.getError(),
stanza.getErrorMsg())
......@@ -811,8 +840,8 @@ class Client(Observable):
self.send_stanza(session_request)
self.state = StreamState.WAIT_FOR_SESSION
def _on_session(self, stanza):
if isResultNode(stanza):
def _on_session(self, stanza: types.Iq):
if stanza.is_result():
self._log.info('Successfully started session')
self.set_state(StreamState.BIND_SUCCESSFUL)
else:
......
......@@ -17,13 +17,16 @@
from __future__ import annotations
from typing import Any, Optional, cast
import logging
from gi.repository import Gio
from nbxmpp.const import TCPState
from nbxmpp.const import ConnectionType, TCPState
from nbxmpp.util import Observable
from nbxmpp.util import LogAdapter
from nbxmpp.addresses import ServerAddress
log = logging.getLogger('nbxmpp.connection')
......@@ -42,11 +45,11 @@ class Connection(Observable):
disconnected
'''
def __init__(self,
log_context,
address,
log_context: str,
address: ServerAddress,
accepted_certificates,
ignore_tls_errors,
ignored_tls_errors,
ignore_tls_errors: bool,
ignored_tls_errors: list[Gio.TlsCertificateFlags],
client_cert):
self._log = LogAdapter(log, {'context': log_context})
......@@ -57,42 +60,42 @@ class Connection(Observable):
self._address = address
self._local_address = None
self._remote_address = None
self._state = None
self._state = TCPState.DISCONNECTED
self._peer_certificate = None
self._peer_certificate_errors = None
self._peer_certificate_errors: set[Gio.TlsCertificateFlags] = None
self._accepted_certificates = accepted_certificates
self._ignore_tls_errors = ignore_tls_errors
self._ignored_tls_errors = ignored_tls_errors
self._ignored_tls_errors: list[Gio.TlsCertificateFlags] = ignored_tls_errors
@property
def local_address(self):
def local_address(self) -> Optional[str]:
return self._local_address
@property
def remote_address(self):
def remote_address(self) -> Optional[str]:
return self._remote_address
@property
def peer_certificate(self):
def peer_certificate(self) -> tuple[Optional[Any],
Optional[set[Gio.TlsCertificateFlags]]]:
return (self._peer_certificate, self._peer_certificate_errors)
@property
def connection_type(self):
def connection_type(self) -> ConnectionType:
return self._address.type
@property
def state(self):
def state(self) -> TCPState:
return self._state
@state.setter
def state(self, value):
def state(self, value: TCPState):
self._log.info('Set Connection State: %s', value)
self._state = value
def _accept_certificate(self):
def _accept_certificate(self) -> bool:
if not self._peer_certificate_errors:
return True
......@@ -119,33 +122,33 @@ class Connection(Observable):
return True
return False
def disconnect(self):
def disconnect(self) -> None:
raise NotImplementedError
def connect(self):
def connect(self) -> None:
raise NotImplementedError
def send(self, stanza, now=False):
def send(self, stanza: Any, now: bool = False) -> None:
raise NotImplementedError
def _log_stanza(self, data, received=True):
def _log_stanza(self, data: Any, received: bool = True):
if isinstance(data, bytes):
data = data.decode()
direction = 'RECEIVED' if received else 'SENT'
message = ('::::: DATA %s ::::\n\n%s\n')
self._log.info(message, direction, data)
def start_tls_negotiation(self):
def start_tls_negotiation(self) -> None:
raise NotImplementedError
def shutdown_output(self):
def shutdown_output(self) -> None: