3.14 PARALLEL COUNTER MODE

EXERCISE 3.14: PARALLEL COUNTER MODE

Extend your counter mode implementation to use a thread pool to generate the key stream in parallel. Remember that to generate a block of key stream, all that is required is the starting IV and which block of key stream is being generated (e.g., 0 for the first 16-byte block, 1 for the second 16-byte block, etc.). Start by creating a function that can generate any particular block of key stream, perhaps something like keystream(IV, i). Next, parallelize the generation of a key stream up to n by dividing the counter sequence among independent processes any way you please, and have them all work on generating their key stream blocks independently.


The following is a parallel version of the counter-mode implementation from the previous exercise.

I used concurrent.futures.ThreadPoolExecutor to generate the key-stream blocks in parallel.

# ex3_14.py 

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import os 
import concurrent.futures


# AES always works on 16-byte (128-bit) blocks, whatever the key size.
AES_BLOCK_SIZE_IN_BYTES = 16
AES_BLOCK_SIZE_IN_BITS = AES_BLOCK_SIZE_IN_BYTES * 8
# Number of key-stream blocks produced each time the pre-generated key stream runs dry.
NUMBER_OF_KEYSTREAM_BLOCKS_TO_GENERATE_AT_ONCE = 30
# Upper bound on the worker threads used to produce one batch of key stream.
MAX_NUMBER_OF_WORKERS = 6

class MyOwnCTR:
    """Configuration for a hand-rolled AES-CTR cipher: an AES key plus the initial counter block.

    Mirrors the ``Cipher(...).encryptor()`` / ``.decryptor()`` shape of the
    ``cryptography`` API so its output can be compared with ``modes.CTR``.
    """

    def __init__(self, key: bytes, nonce: bytes):
        """Store the AES key and the initial counter block.

        Args:
            key: AES key bytes.
            nonce: Initial counter block; must be exactly one AES block long.

        Raises:
            ValueError: if ``nonce`` is not ``AES_BLOCK_SIZE_IN_BYTES`` long.
                This is an explicit check rather than ``assert`` so that it
                still runs under ``python -O``.
        """
        if len(nonce) != AES_BLOCK_SIZE_IN_BYTES:
            raise ValueError(
                f"nonce must be {AES_BLOCK_SIZE_IN_BYTES} bytes long, got {len(nonce)}"
            )
        self.key = key
        self.nonce = nonce

    def encryptor(self) -> "Encryptor":
        """Return a fresh streaming encryptor that starts at counter offset 0."""
        return Encryptor(config=self)

    def decryptor(self) -> "Encryptor":
        """Return a fresh streaming decryptor that starts at counter offset 0."""
        # In CTR, encryption and decryption are the same operation: XOR with
        # the same key stream.
        return Encryptor(config=self)
    
class Encryptor:
    """Streaming AES-CTR encryptor/decryptor whose key stream is generated in parallel.

    Key-stream block ``i`` is ``AES_k((nonce + i) mod 2**128)``, where the
    16-byte nonce is read as a big-endian integer. Blocks are produced in
    batches of ``NUMBER_OF_KEYSTREAM_BLOCKS_TO_GENERATE_AT_ONCE``. Each batch is
    split into contiguous counter ranges that worker threads encrypt
    independently, and the results are reassembled in counter order.
    """

    def __init__(self, config: MyOwnCTR):
        self.config = config
        # Pre-generated key-stream blocks not yet consumed, in counter order.
        self.keystream = []
        # Counter offset (relative to the nonce) of the next block to generate.
        self.current_index = 0
        # Unused tail of the key-stream block currently being consumed. The
        # name is historical: it holds E_k(nonce + i), not a nonce.
        self.e_current_nonce = b""
        # Kept for interface compatibility; update() consumes all of its input,
        # so this is empty between calls.
        self.buffer = b""

        # Only the Cipher (configuration) is shared. Each worker asks it for its
        # own encryption context, because a single CipherContext is not
        # documented as safe for concurrent update() calls.
        self._cipher = Cipher(
            algorithm=algorithms.AES(self.config.key),
            mode=modes.ECB(),
            backend=default_backend(),
        )

    def advance(self):
        """Move the next key-stream block into ``e_current_nonce``, generating a batch if needed."""
        if not self.keystream:
            self._refill_keystream()
        self.e_current_nonce = self.keystream.pop(0)

    def _refill_keystream(self):
        """Generate the next batch of key-stream blocks in parallel and append them in counter order."""
        total = max(1, NUMBER_OF_KEYSTREAM_BLOCKS_TO_GENERATE_AT_ONCE)
        workers = max(1, min(MAX_NUMBER_OF_WORKERS, total))
        per_worker = -(-total // workers)  # ceiling division

        # Split [current_index, current_index + total) into contiguous ranges.
        first = self.current_index
        starts = list(range(first, first + total, per_worker))
        counts = [min(per_worker, first + total - start) for start in starts]
        nonces = [self.config.nonce] * len(starts)

        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            # Executor.map yields results in submission (= counter) order, so no
            # re-sorting is needed; worker exceptions propagate from here.
            batch = b"".join(executor.map(self._keystream_blocks, nonces, starts, counts))

        self.current_index += total
        self.keystream.extend(
            batch[offset:offset + AES_BLOCK_SIZE_IN_BYTES]
            for offset in range(0, len(batch), AES_BLOCK_SIZE_IN_BYTES)
        )

    def update(self, plaintext: bytes) -> bytes:
        """XOR ``plaintext`` with the next ``len(plaintext)`` bytes of key stream.

        May be called repeatedly with chunks of any length. The unused tail of a
        partially consumed key-stream block carries over to the next call.
        """
        data = self.buffer + plaintext
        pieces = []
        pos = 0
        # Walk an index through the input instead of re-slicing it each time,
        # and join the pieces once, to keep the work linear in the input size.
        while pos < len(data):
            if not self.e_current_nonce:
                self.advance()
            k = min(len(data) - pos, len(self.e_current_nonce))
            pieces.append(
                xor_two_byte_strings(data[pos:pos + k], self.e_current_nonce[:k])
            )
            self.e_current_nonce = self.e_current_nonce[k:]
            pos += k
        self.buffer = b""
        return b"".join(pieces)

    def finalize(self):
        """Finish the stream. CTR needs no padding, so there is never trailing output."""
        assert len(self.buffer) == 0
        return b""

    def get_the_ith_keyblock(self, nonce: bytes, i: int) -> bytes:
        """Return key-stream block ``i``: ``AES_k((nonce + i) mod 2**128)``.

        This is the exercise's ``keystream(IV, i)``. It uses a private cipher
        context, so it may be called from any thread.
        """
        return self._keystream_blocks(nonce, i, 1)

    def _keystream_blocks(self, nonce: bytes, start: int, count: int) -> bytes:
        """Return key-stream blocks ``start .. start + count - 1`` concatenated.

        Uses its own ECB context, so concurrent calls never share cipher state.
        """
        base = int.from_bytes(nonce, "big")
        modulus = 1 << AES_BLOCK_SIZE_IN_BITS  # the counter wraps at 2**128
        counter_blocks = b"".join(
            ((base + j) % modulus).to_bytes(AES_BLOCK_SIZE_IN_BYTES, "big")
            for j in range(start, start + count)
        )
        context = self._cipher.encryptor()
        return context.update(counter_blocks) + context.finalize()

# the following function is taken from Exercise 3.9. 
def xor_two_byte_strings(x: bytes, y: bytes) -> bytes:
    """Return the byte-wise XOR of two equal-length byte strings.

    Raises:
        ValueError: if ``x`` and ``y`` differ in length. This is an explicit
            check rather than ``assert`` because assertions are stripped under
            ``python -O``, after which ``zip`` would silently truncate to the
            shorter input.
    """
    if len(x) != len(y):
        raise ValueError(
            f"cannot XOR byte strings of different lengths ({len(x)} != {len(y)})"
        )
    return bytes(a ^ b for a, b in zip(x, y))

if __name__ == '__main__':
    # Random 256-bit AES key and random 16-byte initial counter block.
    key = os.urandom(32)
    nonce = os.urandom(16)

    plaintext = b"This is a very secret and long plaintext..."

    # Round-trip through the hand-rolled, thread-parallel CTR implementation.
    my_ctr = MyOwnCTR(key=key, nonce=nonce)
    my_ctr_encryptor = my_ctr.encryptor()
    my_ctr_decryptor = my_ctr.decryptor()

    ciphertext1 = my_ctr_encryptor.update(plaintext)
    print(my_ctr_decryptor.update(ciphertext1))

    # Reference result from the library's own CTR mode.
    official_ctr = Cipher(
        algorithm=algorithms.AES(key),
        mode=modes.CTR(nonce),
        backend=default_backend(),
    )
    official_ctr_encryptor = official_ctr.encryptor()
    official_ctr_decryptor = official_ctr.decryptor()

    ciphertext2 = official_ctr_encryptor.update(plaintext)
    print(official_ctr_decryptor.update(ciphertext2))

    # The short plaintext above needs only 3 key-stream blocks, all from the
    # first batch, so it never exercises batch refills or chunked update()
    # calls. Also check a multi-batch message fed in 37-byte chunks (not a
    # multiple of the 16-byte block size).
    chunk_size = 37
    long_plaintext = os.urandom(
        3 * NUMBER_OF_KEYSTREAM_BLOCKS_TO_GENERATE_AT_ONCE * AES_BLOCK_SIZE_IN_BYTES + 7
    )
    chunked_encryptor = MyOwnCTR(key=key, nonce=nonce).encryptor()
    long_ciphertext1 = b"".join(
        chunked_encryptor.update(long_plaintext[start:start + chunk_size])
        for start in range(0, len(long_plaintext), chunk_size)
    )
    long_ciphertext2 = official_ctr.encryptor().update(long_plaintext)
    long_roundtrip = MyOwnCTR(key=key, nonce=nonce).decryptor().update(long_ciphertext1)

    passed = (
        ciphertext1 == ciphertext2
        and long_ciphertext1 == long_ciphertext2
        and long_roundtrip == long_plaintext
    )
    print(f"Passed: {passed}")

Running the above code gives the following result: