Commit a58ff4f6 authored by Philipp Hörist's avatar Philipp Hörist

[omemo] Port python-omemo changes from master

parent d38fe92d
......@@ -18,6 +18,7 @@
#
import sys
import logging
log = logging.getLogger('gajim.plugin_system.omemo')
try:
......@@ -35,7 +36,11 @@ def encrypt(key, iv, plaintext):
def decrypt(key, iv, ciphertext):
return aes_decrypt(key, iv, ciphertext)
plaintext = aes_decrypt(key, iv, ciphertext).decode('utf-8')
if sys.version_info < (3, 0):
return unicode(plaintext)
else:
return plaintext
class NoValidSessions(Exception):
......
......@@ -29,11 +29,14 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import logging
from struct import pack, unpack
from Crypto.Cipher import AES
from Crypto.Util import strxor
log = logging.getLogger('gajim.plugin_system.omemo')
def gcm_rightshift(vec):
for x in range(15, 0, -1):
......@@ -140,13 +143,20 @@ def gcm_encrypt(k, iv, plaintext, auth_data):
def aes_encrypt(key, nonce, plaintext):
""" Use AES128 GCM with the given key and iv to encrypt the payload. """
c, t = gcm_encrypt(key, nonce, plaintext, '')
result = c + t
return result
return gcm_encrypt(key, nonce, plaintext, '')
def aes_decrypt(key, nonce, payload):
def aes_decrypt(_key, nonce, payload):
""" Use AES128 GCM with the given key and iv to decrypt the payload. """
ciphertext = payload[:-16]
mac = payload[-16:]
if len(_key) >= 32:
# XEP-0384
log.debug('XEP Compliant Key/Tag')
ciphertext = payload
key = _key[:16]
mac = _key[16:]
else:
# Legacy
log.debug('Legacy Key/Tag')
ciphertext = payload[:-16]
key = _key
mac = payload[-16:]
return gcm_decrypt(key, nonce, ciphertext, '', mac)
......@@ -19,6 +19,7 @@
import os
import logging
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers import algorithms
from cryptography.hazmat.primitives.ciphers.modes import GCM
......@@ -32,11 +33,22 @@ if os.name == 'nt':
else:
from cryptography.hazmat.backends import default_backend
log = logging.getLogger('gajim.plugin_system.omemo')
def aes_decrypt(key, iv, payload):
def aes_decrypt(_key, iv, payload):
""" Use AES128 GCM with the given key and iv to decrypt the payload. """
data = payload[:-16]
tag = payload[-16:]
if len(_key) >= 32:
# XEP-0384
log.debug('XEP Compliant Key/Tag')
data = payload
key = _key[:16]
tag = _key[16:]
else:
# Legacy
log.debug('Legacy Key/Tag')
data = payload[:-16]
key = _key
tag = payload[-16:]
if os.name == 'nt':
_backend = backend
else:
......@@ -58,4 +70,4 @@ def aes_encrypt(key, iv, plaintext):
algorithms.AES(key),
GCM(iv),
backend=_backend).encryptor()
return encryptor.update(plaintext) + encryptor.finalize() + encryptor.tag
return encryptor.update(plaintext) + encryptor.finalize(), encryptor.tag
......@@ -83,10 +83,16 @@ class LiteAxolotlStore(AxolotlStore):
def saveIdentity(self, recepientId, identityKey):
self.identityKeyStore.saveIdentity(recepientId, identityKey)
def deleteIdentity(self, recipientId, identityKey):
self.identityKeyStore.deleteIdentity(recipientId, identityKey)
def isTrustedIdentity(self, recepientId, identityKey):
return self.identityKeyStore.isTrustedIdentity(recepientId,
identityKey)
def setTrust(self, identityKey, trust):
return self.identityKeyStore.setTrust(identityKey, trust)
def getTrustedFingerprints(self, jid):
return self.identityKeyStore.getTrustedFingerprints(jid)
......@@ -127,6 +133,9 @@ class LiteAxolotlStore(AxolotlStore):
# TODO Reuse this
return self.sessionStore.getSubDeviceSessions(recepientId)
def getJidFromDevice(self, device_id):
return self.sessionStore.getJidFromDevice(device_id)
def storeSession(self, recepientId, deviceId, sessionRecord):
self.sessionStore.storeSession(recepientId, deviceId, sessionRecord)
......@@ -139,6 +148,15 @@ class LiteAxolotlStore(AxolotlStore):
def deleteAllSessions(self, recepientId):
self.sessionStore.deleteAllSessions(recepientId)
def getSessionsFromJid(self, recipientId):
return self.sessionStore.getSessionsFromJid(recipientId)
def getSessionsFromJids(self, recipientId):
return self.sessionStore.getSessionsFromJids(recipientId)
def getAllSessions(self):
return self.sessionStore.getAllSessions()
def loadSignedPreKey(self, signedPreKeyId):
return self.signedPreKeyStore.loadSignedPreKey(signedPreKeyId)
......
......@@ -86,6 +86,13 @@ class LiteIdentityKeyStore(IdentityKeyStore):
return result is not None
def deleteIdentity(self, recipientId, identityKey):
q = "DELETE FROM identities WHERE recipient_id = ? AND public_key = ?"
c = self.dbConn.cursor()
c.execute(q, (recipientId,
identityKey.getPublicKey().serialize()))
self.dbConn.commit()
def isTrustedIdentity(self, recipientId, identityKey):
q = "SELECT trust FROM identities WHERE recipient_id = ? " \
"AND public_key = ?"
......@@ -160,8 +167,8 @@ class LiteIdentityKeyStore(IdentityKeyStore):
c.execute(q, fingerprints)
self.dbConn.commit()
def setTrust(self, _id, trust):
q = "UPDATE identities SET trust = ? WHERE _id = ?"
def setTrust(self, identityKey, trust):
q = "UPDATE identities SET trust = ? WHERE public_key = ?"
c = self.dbConn.cursor()
c.execute(q, (trust, _id))
c.execute(q, (trust, identityKey.getPublicKey().serialize()))
self.dbConn.commit()
......@@ -48,6 +48,14 @@ class LiteSessionStore(SessionStore):
deviceIds = [r[0] for r in result]
return deviceIds
def getJidFromDevice(self, device_id):
q = "SELECT recipient_id from sessions WHERE device_id = ?"
c = self.dbConn.cursor()
c.execute(q, (device_id, ))
result = c.fetchone()
return result[0].decode('utf-8')
def getActiveDeviceTuples(self):
q = "SELECT recipient_id, device_id FROM sessions WHERE active = 1"
c = self.dbConn.cursor()
......@@ -82,6 +90,33 @@ class LiteSessionStore(SessionStore):
self.dbConn.cursor().execute(q, (recipientId, ))
self.dbConn.commit()
def getAllSessions(self):
q = "SELECT _id, recipient_id, device_id, record, active from sessions"
c = self.dbConn.cursor()
result = []
for row in c.execute(q):
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
return result
def getSessionsFromJid(self, recipientId):
q = "SELECT _id, recipient_id, device_id, record, active from sessions" \
" WHERE recipient_id = ?"
c = self.dbConn.cursor()
result = []
for row in c.execute(q, (recipientId,)):
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
return result
def getSessionsFromJids(self, recipientId):
q = "SELECT _id, recipient_id, device_id, record, active from sessions" \
" WHERE recipient_id IN ({})" \
.format(', '.join(['?'] * len(recipientId)))
c = self.dbConn.cursor()
result = []
for row in c.execute(q, recipientId):
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
return result
def setActiveState(self, deviceList, jid):
c = self.dbConn.cursor()
......@@ -96,28 +131,6 @@ class LiteSessionStore(SessionStore):
c.execute(q, deviceList)
self.dbConn.commit()
def getActiveSessionsKeys(self, recipientId):
q = "SELECT record FROM sessions WHERE active = 1 AND recipient_id = ?"
c = self.dbConn.cursor()
result = []
for row in c.execute(q, (recipientId,)):
public_key = (SessionRecord(serialized=row[0]).
getSessionState().getRemoteIdentityKey().
getPublicKey())
result.append(public_key.serialize())
return result
def getAllActiveSessionsKeys(self):
q = "SELECT record FROM sessions WHERE active = 1"
c = self.dbConn.cursor()
result = []
for row in c.execute(q):
public_key = (SessionRecord(serialized=row[0]).
getSessionState().getRemoteIdentityKey().
getPublicKey())
result.append(public_key.serialize())
return result
def getInactiveSessionsKeys(self, recipientId):
q = "SELECT record FROM sessions WHERE active = 0 AND recipient_id = ?"
c = self.dbConn.cursor()
......
......@@ -29,6 +29,14 @@ class SQLDatabase():
self.dbConn = dbConn
self.createDb()
self.migrateDb()
c = self.dbConn.cursor()
c.execute("PRAGMA synchronous=NORMAL;")
c.execute("PRAGMA journal_mode;")
mode = c.fetchone()[0]
# WAL is a persistent DB mode, dont override it if user has set it
if mode != 'wal':
c.execute("PRAGMA journal_mode=MEMORY;")
self.dbConn.commit()
def createDb(self):
if user_version(self.dbConn) == 0:
......
......@@ -200,8 +200,8 @@ class OmemoState:
key = self.handleWhisperMessage(sender_jid, sid, encrypted_key)
except (NoSessionException, InvalidMessageException) as e:
log.warning('No Session found ' + e.message)
log.warning('sender_jid => ' + str(sender_jid) +
' sid =>' + sid)
log.warning('sender_jid => ' + str(sender_jid) + ' sid =>' +
str(sid))
return
except (DuplicateMessageException) as e:
log.warning('Duplicate message found ' + str(e.args))
......@@ -211,7 +211,7 @@ class OmemoState:
log.warning('Duplicate message found ' + str(e.args))
return
result = decrypt(key, iv, payload).decode('utf-8')
result = decrypt(key, iv, payload)
log.debug("Decrypted Message => " + result)
return result
......@@ -226,24 +226,97 @@ class OmemoState:
log.error('No known devices')
return
for dev in devices_list:
self.get_session_cipher(jid, dev)
session_ciphers = self.session_ciphers[jid]
if not session_ciphers:
log.warning('No session ciphers for ' + jid)
return
payload, tag = encrypt(key, iv, plaintext)
# for XEP-384 Compliance uncomment
# key += tag
payload += tag
# Encrypt the message key with for each of receivers devices
for rid, cipher in session_ciphers.items():
for device in devices_list:
try:
if self.isTrusted(cipher) == TRUSTED:
encrypted_keys[rid] = cipher.encrypt(key).serialize()
if self.isTrusted(jid, device) == TRUSTED:
cipher = self.get_session_cipher(jid, device)
cipher_key = cipher.encrypt(key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage)
encrypted_keys[device] = (cipher_key.serialize(), prekey)
else:
log.debug('Skipped Device because Trust is: ' +
str(self.isTrusted(cipher)))
str(self.isTrusted(jid, device)))
except:
log.warning('Failed to find key for device ' + str(rid))
log.warning('Failed to find key for device ' + str(device))
if len(encrypted_keys) == 0:
log.error('Encrypted keys empty')
raise NoValidSessions('Encrypted keys empty')
my_other_devices = set(self.own_devices) - set({self.own_device_id})
# Encrypt the message key with for each of our own devices
for device in my_other_devices:
try:
if self.isTrusted(from_jid, device) == TRUSTED:
cipher = self.get_session_cipher(from_jid, device)
cipher_key = cipher.encrypt(key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage)
encrypted_keys[device] = (cipher_key.serialize(), prekey)
else:
log.debug('Skipped own Device because Trust is: ' +
str(self.isTrusted(from_jid, device)))
except:
log.warning('Failed to find key for device ' + str(device))
result = {'sid': self.own_device_id,
'keys': encrypted_keys,
'jid': jid,
'iv': iv,
'payload': payload}
log.debug('Finished encrypting message')
return result
def create_gc_msg(self, from_jid, jid, plaintext):
key = get_random_bytes(16)
iv = get_random_bytes(16)
encrypted_keys = {}
room = jid
encrypted_jids = []
devices_list = self.device_list_for(jid, True)
if len(devices_list) == 0:
log.error('No known devices')
return
payload, tag = encrypt(key, iv, plaintext)
# for XEP-384 Compliance uncomment
# key += tag
payload += tag
for tup in devices_list:
self.get_session_cipher(tup[0], tup[1])
# Encrypt the message key with for each of receivers devices
for nick in self.plugin.groupchat[room]:
jid_to = self.plugin.groupchat[room][nick]
if jid_to == self.own_jid:
continue
if jid_to in encrypted_jids: # We already encrypted to this JID
continue
for rid, cipher in self.session_ciphers[jid_to].items():
try:
if self.isTrusted(jid_to, rid) == TRUSTED:
cipher_key = cipher.encrypt(key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage)
encrypted_keys[rid] = (cipher_key.serialize(), prekey)
else:
log.debug('Skipped Device because Trust is: ' +
str(self.isTrusted(jid_to, rid)))
except:
log.exception('ERROR:')
log.warning('Failed to find key for device ' +
str(rid))
encrypted_jids.append(jid_to)
if len(encrypted_keys) == 0:
log_msg = 'Encrypted keys empty'
log.error(log_msg)
......@@ -254,16 +327,17 @@ class OmemoState:
for dev in my_other_devices:
try:
cipher = self.get_session_cipher(from_jid, dev)
if self.isTrusted(cipher) == TRUSTED:
encrypted_keys[dev] = cipher.encrypt(key).serialize()
if self.isTrusted(from_jid, dev) == TRUSTED:
cipher_key = cipher.encrypt(key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage)
encrypted_keys[dev] = (cipher_key.serialize(), prekey)
else:
log.debug('Skipped own Device because Trust is: ' +
str(self.isTrusted(cipher)))
str(self.isTrusted(from_jid, dev)))
except:
log.exception('ERROR:')
log.warning('Failed to find key for device ' + str(dev))
payload = encrypt(key, iv, plaintext)
result = {'sid': self.own_device_id,
'keys': encrypted_keys,
'jid': jid,
......@@ -273,14 +347,36 @@ class OmemoState:
log.debug('Finished encrypting message')
return result
def isTrusted(self, cipher):
self.cipher = cipher
self.state = self.cipher.sessionStore. \
loadSession(self.cipher.recipientId, self.cipher.deviceId). \
getSessionState()
self.key = self.state.getRemoteIdentityKey()
return self.store.identityKeyStore. \
isTrustedIdentity(self.cipher.recipientId, self.key)
def device_list_for(self, jid, gc=False):
""" Return a list of known device ids for the specified jid.
Parameters
----------
jid : string
The contacts jid
gc : bool
Groupchat Message
"""
if gc:
room = jid
devicelist = []
for nick in self.plugin.groupchat[room]:
jid_to = self.plugin.groupchat[room][nick]
if jid_to == self.own_jid:
continue
for device in self.device_ids[jid_to]:
devicelist.append((jid_to, device))
return devicelist
if jid == self.own_jid:
return set(self.own_devices) - set({self.own_device_id})
if jid not in self.device_ids:
return set()
return set(self.device_ids[jid])
def isTrusted(self, recipient_id, device_id):
record = self.store.loadSession(recipient_id, device_id)
identity_key = record.getSessionState().getRemoteIdentityKey()
return self.store.isTrustedIdentity(recipient_id, identity_key)
def getTrustedFingerprints(self, recipient_id):
inactive = self.store.getInactiveSessionsKeys(recipient_id)
......@@ -296,20 +392,6 @@ class OmemoState:
return undecided
def device_list_for(self, jid):
""" Return a list of known device ids for the specified jid.
Parameters
----------
jid : string
The contacts jid
"""
if jid == self.own_jid:
return set(self.own_devices) - set({self.own_device_id})
if jid not in self.device_ids:
return set()
return set(self.device_ids[jid])
def devices_without_sessions(self, jid):
""" List device_ids for the given jid which have no axolotl session.
......@@ -364,10 +446,10 @@ class OmemoState:
def handleWhisperMessage(self, recipient_id, device_id, key):
whisperMessage = WhisperMessage(serialized=key)
sessionCipher = self.get_session_cipher(recipient_id, device_id)
log.debug(self.account + " => Received WhisperMessage from " +
recipient_id)
if self.isTrusted(sessionCipher) >= TRUSTED:
if self.isTrusted(recipient_id, device_id):
sessionCipher = self.get_session_cipher(recipient_id, device_id)
key = sessionCipher.decryptMsg(whisperMessage, textMsg=False)
return key
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment