import sys
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.hmac import HMAC
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.padding import PKCS7
from .crypto_provider import CryptoProvider, HashFunction
from .. import aead
if sys.version_info >= (3, 11):
from typing import assert_never
else:
from typing_extensions import assert_never
# Public API of this module: only the provider implementation is exported.
__all__ = ["CryptoProviderImpl"]
def get_hash_algorithm(hash_function: HashFunction) -> hashes.HashAlgorithm:
    """
    Resolve a :class:`HashFunction` identifier to its cryptography implementation.

    Args:
        hash_function: The hash function identifier to resolve.

    Returns:
        A newly constructed cryptography
        :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm` object that
        implements the requested hash function.

    Raises:
        AssertionError: Raised by ``assert_never`` when ``hash_function`` is not one of
            the handled members.
    """
    # The identity checks are kept as an explicit branch chain (rather than a lookup
    # table) so that static type checkers can narrow the enum and verify, through
    # ``assert_never``, that every member is handled.
    if hash_function is HashFunction.SHA_256:
        algorithm: hashes.HashAlgorithm = hashes.SHA256()
    elif hash_function is HashFunction.SHA_512:
        algorithm = hashes.SHA512()
    elif hash_function is HashFunction.SHA_512_256:
        algorithm = hashes.SHA512_256()
    else:
        return assert_never(hash_function)
    return algorithm
class CryptoProviderImpl(CryptoProvider):
    """
    Cryptography provider implemented on top of the Python package
    `cryptography <https://github.com/pyca/cryptography>`_.
    """

    @staticmethod
    async def hkdf_derive(
        hash_function: HashFunction,
        length: int,
        salt: bytes,
        info: bytes,
        key_material: bytes
    ) -> bytes:
        # Configure HKDF for the requested hash, then derive ``length`` bytes from
        # ``key_material`` in a single call.
        kdf = HKDF(
            algorithm=get_hash_algorithm(hash_function),
            length=length,
            salt=salt,
            info=info,
            backend=default_backend()
        )
        return kdf.derive(key_material)

    @staticmethod
    async def hmac_calculate(key: bytes, hash_function: HashFunction, data: bytes) -> bytes:
        # One-shot HMAC: feed all of ``data`` at once and return the resulting tag.
        # (The local is named ``mac`` to avoid shadowing the stdlib ``hmac`` module.)
        mac = HMAC(key, get_hash_algorithm(hash_function), backend=default_backend())
        mac.update(data)
        return mac.finalize()

    @staticmethod
    async def aes_cbc_encrypt(key: bytes, initialization_vector: bytes, plaintext: bytes) -> bytes:
        # CBC mode operates on whole 128-bit AES blocks, so apply PKCS#7 padding first.
        padding_ctx = PKCS7(128).padder()
        padded = padding_ctx.update(plaintext) + padding_ctx.finalize()

        # Encrypt the padded plaintext with AES-CBC under the given key and IV.
        encryptor = Cipher(
            algorithms.AES(key),
            modes.CBC(initialization_vector),
            backend=default_backend()
        ).encryptor()
        return encryptor.update(padded) + encryptor.finalize()

    @staticmethod
    async def aes_cbc_decrypt(key: bytes, initialization_vector: bytes, ciphertext: bytes) -> bytes:
        # Step 1: AES-CBC decryption. cryptography signals problems here (e.g. an
        # invalid key/IV size or a ciphertext that is not a whole number of blocks)
        # with ValueError, which is mapped to the AEAD-level failure exception.
        try:
            decryptor = Cipher(
                algorithms.AES(key),
                modes.CBC(initialization_vector),
                backend=default_backend()
            ).decryptor()
            padded = decryptor.update(ciphertext) + decryptor.finalize()
        except ValueError as err:
            raise aead.DecryptionFailedException("Decryption failed.") from err

        # Step 2: strip the PKCS#7 padding; malformed padding also raises ValueError.
        # NOTE(review): the distinct message for bad padding could act as a padding
        # oracle if unauthenticated ciphertext reaches this method; presumably callers
        # verify a MAC before decrypting — confirm.
        try:
            unpadding_ctx = PKCS7(128).unpadder()
            plaintext = unpadding_ctx.update(padded) + unpadding_ctx.finalize()
        except ValueError as err:
            raise aead.DecryptionFailedException("Plaintext padded incorrectly.") from err

        return plaintext