5.6 ENCRYPT THEN MAC

EXERCISE 5.6 ENCRYPT THEN MAC

Update the code from the beginning of the chapter to compute a proper MAC by replacing the plain SHA-256 hash with HMAC or CMAC. Use two independent keys: one for encryption and one for the MAC.


Defining the core classes

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import cmac, hashes, hmac
from enum import Enum

class MessageWasTamperedWithError(Exception):
    '''
    Raised when the ciphertext has been tampered with 
    by Eve.
    '''

    def __init__(self, expected_mac: str, actual_mac: str):
        self.expected_mac = expected_mac
        self.actual_mac = actual_mac

    def __str__(self):
        return f"\nExpected: {self.expected_mac}.\nBut got: {self.actual_mac}"

class UnknownMacException(Exception):
    '''Raised when a ``Mac`` value is not one of the supported MAC algorithms.'''

class Mac(Enum):
    '''
    MAC algorithms that can authenticate the ciphertext.

    Tag length depends on the choice:
      * HMAC_WITH_SHA256 -> 32-byte tag (SHA-256 digest size).
      * CMAC             -> 16-byte tag (AES block size).
    '''
    HMAC_WITH_SHA256 = 1  # HMAC keyed with SHA-256
    CMAC = 2  # AES-based CMAC

def get_hasher(mac_type: Mac, mac_key: bytes) -> hmac.HMAC | cmac.CMAC:
    '''
    Build a fresh, keyed MAC context for ``mac_type``.

    HMAC_WITH_SHA256 yields an HMAC over SHA-256 keyed with ``mac_key``.
    CMAC yields an AES-CMAC, with ``mac_key`` used as the AES key.
    Any other value raises UnknownMacException.
    '''
    if mac_type == Mac.HMAC_WITH_SHA256:
        return hmac.HMAC(mac_key, hashes.SHA256(), backend=default_backend())

    if mac_type == Mac.CMAC:
        return cmac.CMAC(algorithms.AES(mac_key), backend=default_backend())

    raise UnknownMacException()

class Encryptor:
    '''
    Encrypt-then-MAC: AES in CTR mode for confidentiality, followed by a MAC
    (HMAC-SHA256 or AES-CMAC, chosen by ``mac_type``) over the ciphertext.

    ``encryption_key`` and ``mac_key`` are separate parameters so two
    independent keys can be used.

    Usage: concatenate the outputs of ``update_encryptor`` (one call per
    plaintext chunk), then append the output of ``finalize_encryptor``,
    which ends with the MAC tag.
    '''

    def __init__(
        self, 
        encryption_key: bytes, 
        nonce: bytes, 
        mac_type: Mac, 
        mac_key: bytes,
    ):
        '''
        Args:
            encryption_key: AES key for CTR-mode encryption.
            nonce: CTR nonce (initial counter block).
            mac_type: which MAC algorithm authenticates the ciphertext.
            mac_key: key for the MAC, independent of ``encryption_key``.
        '''
        aes_context = Cipher(
            algorithms.AES(encryption_key),
            modes.CTR(nonce),
            backend=default_backend()
        )
        self.encryptor = aes_context.encryptor()
        self.hasher = get_hasher(mac_type, mac_key)

    def update_encryptor(self, plaintext: bytes) -> bytes:
        '''Encrypt a plaintext chunk, feed the ciphertext to the MAC, and return it.'''
        ciphertext = self.encryptor.update(plaintext)
        self.hasher.update(ciphertext)
        return ciphertext

    def finalize_encryptor(self) -> bytes:
        '''Finish encryption; return any remaining ciphertext followed by the MAC tag.'''
        # Anything the cipher emits on finalize is ciphertext as well, so it
        # must be authenticated before the tag is computed. CTR is a stream
        # mode and is expected to emit nothing here. Feeding it anyway keeps
        # the MAC correct should the mode change.
        final_ciphertext = self.encryptor.finalize()
        self.hasher.update(final_ciphertext)
        return final_ciphertext + self.hasher.finalize()


class Decryptor:
    '''
    Counterpart of ``Encryptor``. It decrypts AES-CTR ciphertext and
    recomputes the MAC (HMAC-SHA256 or AES-CMAC, per ``mac_type``) over that
    ciphertext.

    ``digest`` is the MAC tag that arrived with the message (see ``get_mac``).
    The tag is checked only in ``finalize_decryptor``. Plaintext returned by
    ``update_decryptor`` is unauthenticated until that call succeeds.
    '''

    def __init__(
        self, 
        encryption_key: bytes, 
        nonce: bytes, 
        mac_type: Mac, 
        mac_key: bytes, 
        digest: bytes,
    ):
        '''
        Note that in Symmetric ciphers such as AES, the encryption and decryption
        keys are the same.

        Args:
            encryption_key: AES key used by the matching Encryptor.
            nonce: CTR nonce used by the matching Encryptor.
            mac_type: MAC algorithm used by the matching Encryptor.
            mac_key: MAC key used by the matching Encryptor.
            digest: the received MAC tag to verify against.
        '''
        aes_context = Cipher(
            algorithms.AES(encryption_key),
            modes.CTR(nonce),
            backend=default_backend()
        )
        self.decryptor = aes_context.decryptor()
        self.digest = digest
        self.hasher = get_hasher(mac_type, mac_key)

    def update_decryptor(self, ciphertext: bytes) -> bytes:
        '''
        Decrypt a ciphertext chunk and feed it to the MAC.

        Make sure that ``ciphertext`` doesn't include the MAC tag.
        '''
        plaintext = self.decryptor.update(ciphertext)
        self.hasher.update(ciphertext)
        return plaintext

    def finalize_decryptor(self) -> bytes:
        '''
        Verify the MAC, then finish decryption and return any remaining plaintext.

        Raises:
            MessageWasTamperedWithError: if the MAC computed over the
                ciphertext does not match ``digest``.
        '''
        from secrets import compare_digest

        expected_mac = self.hasher.finalize()

        # Use a constant-time comparison. A plain ``!=`` on bytes may stop at
        # the first differing byte. That timing difference can let an
        # attacker recover a valid tag byte by byte.
        if not compare_digest(expected_mac, self.digest):
            # NOTE(review): expected_mac is the *valid* tag for the received
            # (possibly forged) ciphertext. Never let this error's text reach
            # an untrusted party, or it becomes a forgery oracle.
            raise MessageWasTamperedWithError(
                expected_mac=expected_mac,
                actual_mac=self.digest
            )
        return self.decryptor.finalize()

    @staticmethod
    def get_mac(ciphertext_with_mac: bytes, mac_type: Mac) -> bytes:
        '''
        Return the MAC tag at the end of ``ciphertext_with_mac``.

        HMAC_WITH_SHA256 tags are the last 32 bytes; CMAC tags are the last 16.

        Raises:
            UnknownMacException: for any other ``mac_type``.
        '''
        if mac_type == Mac.HMAC_WITH_SHA256:
            return ciphertext_with_mac[-32:]
        elif mac_type == Mac.CMAC: 
            return ciphertext_with_mac[-16:]
        else: 
            raise UnknownMacException()

Now that we have defined the classes we need, let’s get to the fun part:

Using HMAC…

from secrets import token_bytes

# Two independent random keys ("use two keys"): one for AES-CTR encryption,
# one for the MAC. `secrets` provides cryptographically secure random bytes.
encryption_key = token_bytes(32)  # 32-byte key -> AES-256
nonce = token_bytes(16)  # CTR initial counter block: one AES block (16 bytes)
# 64 bytes = SHA-256's block size (see EX5.5). HMAC keys longer than one
# block would first be hashed down, so there is no benefit in going larger.
mac_key = token_bytes(64) 
message = b"Hello Bob, how are you?"

# Show every input in hex so it can be compared with the outputs below.
print(f"encryption_key: {encryption_key.hex(' ')}")
print(f"nonce: {nonce.hex(' ')}")
print(f"mac_key: {mac_key.hex(' ')}")
print(f"message: {message.hex(' ')}")
encryption_key: 5d 9f 90 12 54 13 aa 8f b6 c7 52 6b 73 b7 a4 99 74 d2 1b c3 a1 b4 59 ac 6f 2f 83 65 8e f7 3d 40
nonce: 95 75 72 b3 dd 07 ab 67 76 e1 2f c6 18 63 fd c7
mac_key: c5 06 7e 07 e9 bb 2d 64 1f 5b 6f 63 af ea f4 e1 bd 52 ec 35 d5 60 30 3f 25 4a 4d f1 0a e5 ab 1a f6 0d 77 4d a4 be 2c f2 48 1f 18 12 64 a6 2e 8c a9 35 8a df 19 f4 07 50 73 50 9f b3 64 29 53 ab
message: 48 65 6c 6c 6f 20 42 6f 62 2c 20 68 6f 77 20 61 72 65 20 79 6f 75 3f
# Encrypt-then-MAC with HMAC-SHA256: the result is ciphertext followed by a 32-byte tag.
encryptor = Encryptor(
    encryption_key=encryption_key,
    nonce=nonce,
    mac_type=Mac.HMAC_WITH_SHA256,
    mac_key=mac_key,
)
ciphertext_with_mac = encryptor.update_encryptor(message)
ciphertext_with_mac += encryptor.finalize_encryptor()  # appends the MAC tag

print(ciphertext_with_mac.hex(' '))
30 f5 77 d0 34 66 08 b0 52 3b 87 00 d3 75 4d 10 a5 97 49 e1 ff d5 cc 2e d0 73 3b 67 9f c7 53 ae 89 c9 81 4c 61 5e 4f 05 66 3e 7e 13 9d 00 81 7a e0 91 51 67 32 75 82

Now let’s decrypt and verify:

# get_mac takes the trailing 32-byte HMAC-SHA256 tag; the rest is ciphertext.
decryptor = Decryptor(
    encryption_key=encryption_key,
    nonce=nonce,
    mac_type=Mac.HMAC_WITH_SHA256,
    mac_key=mac_key,
    digest=Decryptor.get_mac(ciphertext_with_mac, Mac.HMAC_WITH_SHA256)
)
plaintext = decryptor.update_decryptor(ciphertext_with_mac[:-32])  # ciphertext without the tag
plaintext += decryptor.finalize_decryptor()  # raises MessageWasTamperedWithError on a bad tag

Phew, that didn’t throw an exception…

# Should print the original message, now authenticated.
print(plaintext)
b'Hello Bob, how are you?'

Using CMAC…

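There is no need to redefine encryption_key, nonce & message here; we reuse them from the HMAC example. This is for demonstration only: in AES-CTR, encrypting two messages with the same key and nonce reuses the keystream. That is why the CMAC ciphertext below begins with exactly the same 23 bytes as the HMAC one. In real code, always use a fresh nonce for every message.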

# A fresh, independent MAC key; CMAC uses it as an AES key (32 bytes -> AES-256).
# NOTE: encryption_key and nonce are deliberately reused from the HMAC example.
# With AES-CTR, never encrypt two messages under the same (key, nonce) pair in
# real use: the keystream repeats and the ciphertexts leak information.
mac_key = token_bytes(32)

print(f"mac_key: {mac_key.hex(' ')}")
mac_key: 87 1b f3 2a a2 84 d5 50 a1 c7 04 7e 81 3f 43 e6 ed 9a ca 07 78 04 49 e3 1d 2d a0 29 56 77 e5 bf
# Same Encrypt-then-MAC flow, but the tag is now a 16-byte AES-CMAC.
encryptor = Encryptor(
    encryption_key=encryption_key,
    nonce=nonce,
    mac_type=Mac.CMAC,
    mac_key=mac_key,
)
ciphertext_with_mac = encryptor.update_encryptor(message)
ciphertext_with_mac += encryptor.finalize_encryptor()  # appends the 16-byte CMAC tag

print(ciphertext_with_mac.hex(' '))
30 f5 77 d0 34 66 08 b0 52 3b 87 00 d3 75 4d 10 a5 97 49 e1 ff d5 cc 3c cd da 7c 33 91 02 3c ac d5 cb fa cc 95 74 b6
# get_mac takes the trailing 16-byte CMAC tag to verify against.
decryptor = Decryptor(
    encryption_key=encryption_key,
    nonce=nonce,
    mac_type=Mac.CMAC,
    mac_key=mac_key,
    digest=Decryptor.get_mac(ciphertext_with_mac, Mac.CMAC)
)

Note that we strip the last 16 bytes (the tag) before decrypting, because a CMAC tag is 16 bytes long rather than the 32 bytes of HMAC-SHA256:

plaintext = decryptor.update_decryptor(ciphertext_with_mac[:-16])  # ciphertext without the tag
plaintext += decryptor.finalize_decryptor()  # raises MessageWasTamperedWithError on a bad tag
print(plaintext)
b'Hello Bob, how are you?'

🚀🚀🚀🚀