concatkdf.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import struct
  5. import typing
  6. from cryptography import utils
  7. from cryptography.exceptions import (
  8. AlreadyFinalized,
  9. InvalidKey,
  10. UnsupportedAlgorithm,
  11. _Reasons,
  12. )
  13. from cryptography.hazmat.backends import _get_backend
  14. from cryptography.hazmat.backends.interfaces import HMACBackend
  15. from cryptography.hazmat.backends.interfaces import HashBackend
  16. from cryptography.hazmat.primitives import constant_time, hashes, hmac
  17. from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
  18. def _int_to_u32be(n: int) -> bytes:
  19. return struct.pack(">I", n)
  20. def _common_args_checks(
  21. algorithm: hashes.HashAlgorithm,
  22. length: int,
  23. otherinfo: typing.Optional[bytes],
  24. ):
  25. max_length = algorithm.digest_size * (2 ** 32 - 1)
  26. if length > max_length:
  27. raise ValueError(
  28. "Can not derive keys larger than {} bits.".format(max_length)
  29. )
  30. if otherinfo is not None:
  31. utils._check_bytes("otherinfo", otherinfo)
  32. def _concatkdf_derive(
  33. key_material: bytes,
  34. length: int,
  35. auxfn: typing.Callable[[], hashes.HashContext],
  36. otherinfo: bytes,
  37. ) -> bytes:
  38. utils._check_byteslike("key_material", key_material)
  39. output = [b""]
  40. outlen = 0
  41. counter = 1
  42. while length > outlen:
  43. h = auxfn()
  44. h.update(_int_to_u32be(counter))
  45. h.update(key_material)
  46. h.update(otherinfo)
  47. output.append(h.finalize())
  48. outlen += len(output[-1])
  49. counter += 1
  50. return b"".join(output)[:length]
  51. class ConcatKDFHash(KeyDerivationFunction):
  52. def __init__(
  53. self,
  54. algorithm: hashes.HashAlgorithm,
  55. length: int,
  56. otherinfo: typing.Optional[bytes],
  57. backend=None,
  58. ):
  59. backend = _get_backend(backend)
  60. _common_args_checks(algorithm, length, otherinfo)
  61. self._algorithm = algorithm
  62. self._length = length
  63. self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
  64. if not isinstance(backend, HashBackend):
  65. raise UnsupportedAlgorithm(
  66. "Backend object does not implement HashBackend.",
  67. _Reasons.BACKEND_MISSING_INTERFACE,
  68. )
  69. self._backend = backend
  70. self._used = False
  71. def _hash(self) -> hashes.Hash:
  72. return hashes.Hash(self._algorithm, self._backend)
  73. def derive(self, key_material: bytes) -> bytes:
  74. if self._used:
  75. raise AlreadyFinalized
  76. self._used = True
  77. return _concatkdf_derive(
  78. key_material, self._length, self._hash, self._otherinfo
  79. )
  80. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  81. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  82. raise InvalidKey
  83. class ConcatKDFHMAC(KeyDerivationFunction):
  84. def __init__(
  85. self,
  86. algorithm: hashes.HashAlgorithm,
  87. length: int,
  88. salt: typing.Optional[bytes],
  89. otherinfo: typing.Optional[bytes],
  90. backend=None,
  91. ):
  92. backend = _get_backend(backend)
  93. _common_args_checks(algorithm, length, otherinfo)
  94. self._algorithm = algorithm
  95. self._length = length
  96. self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
  97. if algorithm.block_size is None:
  98. raise TypeError(
  99. "{} is unsupported for ConcatKDF".format(algorithm.name)
  100. )
  101. if salt is None:
  102. salt = b"\x00" * algorithm.block_size
  103. else:
  104. utils._check_bytes("salt", salt)
  105. self._salt = salt
  106. if not isinstance(backend, HMACBackend):
  107. raise UnsupportedAlgorithm(
  108. "Backend object does not implement HMACBackend.",
  109. _Reasons.BACKEND_MISSING_INTERFACE,
  110. )
  111. self._backend = backend
  112. self._used = False
  113. def _hmac(self) -> hmac.HMAC:
  114. return hmac.HMAC(self._salt, self._algorithm, self._backend)
  115. def derive(self, key_material: bytes) -> bytes:
  116. if self._used:
  117. raise AlreadyFinalized
  118. self._used = True
  119. return _concatkdf_derive(
  120. key_material, self._length, self._hmac, self._otherinfo
  121. )
  122. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  123. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  124. raise InvalidKey