scrypt.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import sys
  5. from cryptography import utils
  6. from cryptography.exceptions import (
  7. AlreadyFinalized,
  8. InvalidKey,
  9. UnsupportedAlgorithm,
  10. _Reasons,
  11. )
  12. from cryptography.hazmat.backends import _get_backend
  13. from cryptography.hazmat.backends.interfaces import ScryptBackend
  14. from cryptography.hazmat.primitives import constant_time
  15. from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
  16. # This is used by the scrypt tests to skip tests that require more memory
  17. # than the MEM_LIMIT
  18. _MEM_LIMIT = sys.maxsize // 2
  19. class Scrypt(KeyDerivationFunction):
  20. def __init__(
  21. self, salt: bytes, length: int, n: int, r: int, p: int, backend=None
  22. ):
  23. backend = _get_backend(backend)
  24. if not isinstance(backend, ScryptBackend):
  25. raise UnsupportedAlgorithm(
  26. "Backend object does not implement ScryptBackend.",
  27. _Reasons.BACKEND_MISSING_INTERFACE,
  28. )
  29. self._length = length
  30. utils._check_bytes("salt", salt)
  31. if n < 2 or (n & (n - 1)) != 0:
  32. raise ValueError("n must be greater than 1 and be a power of 2.")
  33. if r < 1:
  34. raise ValueError("r must be greater than or equal to 1.")
  35. if p < 1:
  36. raise ValueError("p must be greater than or equal to 1.")
  37. self._used = False
  38. self._salt = salt
  39. self._n = n
  40. self._r = r
  41. self._p = p
  42. self._backend = backend
  43. def derive(self, key_material: bytes) -> bytes:
  44. if self._used:
  45. raise AlreadyFinalized("Scrypt instances can only be used once.")
  46. self._used = True
  47. utils._check_byteslike("key_material", key_material)
  48. return self._backend.derive_scrypt(
  49. key_material, self._salt, self._length, self._n, self._r, self._p
  50. )
  51. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  52. derived_key = self.derive(key_material)
  53. if not constant_time.bytes_eq(derived_key, expected_key):
  54. raise InvalidKey("Keys do not match.")