common.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # -*- coding: utf-8 -*-
  2. #
  3. # SelfTest/Hash/common.py: Common code for Crypto.SelfTest.Hash
  4. #
  5. # Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
  6. #
  7. # ===================================================================
  8. # The contents of this file are dedicated to the public domain. To
  9. # the extent that dedication to the public domain is not available,
  10. # everyone is granted a worldwide, perpetual, royalty-free,
  11. # non-exclusive license to exercise all rights associated with the
  12. # contents of this file for any purpose whatsoever.
  13. # No rights are reserved.
  14. #
  15. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  16. # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  17. # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  18. # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  19. # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  20. # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  21. # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  22. # SOFTWARE.
  23. # ===================================================================
"""Self-testing for PyCrypto cipher modules"""
  25. import unittest
  26. from binascii import a2b_hex, b2a_hex, hexlify
  27. from Crypto.Util.py3compat import b
  28. from Crypto.Util.strxor import strxor_c
  29. class _NoDefault: pass # sentinel object
  30. def _extract(d, k, default=_NoDefault):
  31. """Get an item from a dictionary, and remove it from the dictionary."""
  32. try:
  33. retval = d[k]
  34. except KeyError:
  35. if default is _NoDefault:
  36. raise
  37. return default
  38. del d[k]
  39. return retval
  40. # Generic cipher test case
  41. class CipherSelfTest(unittest.TestCase):
  42. def __init__(self, module, params):
  43. unittest.TestCase.__init__(self)
  44. self.module = module
  45. # Extract the parameters
  46. params = params.copy()
  47. self.description = _extract(params, 'description')
  48. self.key = b(_extract(params, 'key'))
  49. self.plaintext = b(_extract(params, 'plaintext'))
  50. self.ciphertext = b(_extract(params, 'ciphertext'))
  51. self.module_name = _extract(params, 'module_name', None)
  52. self.assoc_data = _extract(params, 'assoc_data', None)
  53. self.mac = _extract(params, 'mac', None)
  54. if self.assoc_data:
  55. self.mac = b(self.mac)
  56. mode = _extract(params, 'mode', None)
  57. self.mode_name = str(mode)
  58. if mode is not None:
  59. # Block cipher
  60. self.mode = getattr(self.module, "MODE_" + mode)
  61. self.iv = _extract(params, 'iv', None)
  62. if self.iv is None:
  63. self.iv = _extract(params, 'nonce', None)
  64. if self.iv is not None:
  65. self.iv = b(self.iv)
  66. else:
  67. # Stream cipher
  68. self.mode = None
  69. self.iv = _extract(params, 'iv', None)
  70. if self.iv is not None:
  71. self.iv = b(self.iv)
  72. self.extra_params = params
  73. def shortDescription(self):
  74. return self.description
  75. def _new(self):
  76. params = self.extra_params.copy()
  77. key = a2b_hex(self.key)
  78. old_style = []
  79. if self.mode is not None:
  80. old_style = [ self.mode ]
  81. if self.iv is not None:
  82. old_style += [ a2b_hex(self.iv) ]
  83. return self.module.new(key, *old_style, **params)
  84. def isMode(self, name):
  85. if not hasattr(self.module, "MODE_"+name):
  86. return False
  87. return self.mode == getattr(self.module, "MODE_"+name)
  88. def runTest(self):
  89. plaintext = a2b_hex(self.plaintext)
  90. ciphertext = a2b_hex(self.ciphertext)
  91. assoc_data = []
  92. if self.assoc_data:
  93. assoc_data = [ a2b_hex(b(x)) for x in self.assoc_data]
  94. ct = None
  95. pt = None
  96. #
  97. # Repeat the same encryption or decryption twice and verify
  98. # that the result is always the same
  99. #
  100. for i in range(2):
  101. cipher = self._new()
  102. decipher = self._new()
  103. # Only AEAD modes
  104. for comp in assoc_data:
  105. cipher.update(comp)
  106. decipher.update(comp)
  107. ctX = b2a_hex(cipher.encrypt(plaintext))
  108. ptX = b2a_hex(decipher.decrypt(ciphertext))
  109. if ct:
  110. self.assertEqual(ct, ctX)
  111. self.assertEqual(pt, ptX)
  112. ct, pt = ctX, ptX
  113. self.assertEqual(self.ciphertext, ct) # encrypt
  114. self.assertEqual(self.plaintext, pt) # decrypt
  115. if self.mac:
  116. mac = b2a_hex(cipher.digest())
  117. self.assertEqual(self.mac, mac)
  118. decipher.verify(a2b_hex(self.mac))
  119. class CipherStreamingSelfTest(CipherSelfTest):
  120. def shortDescription(self):
  121. desc = self.module_name
  122. if self.mode is not None:
  123. desc += " in %s mode" % (self.mode_name,)
  124. return "%s should behave like a stream cipher" % (desc,)
  125. def runTest(self):
  126. plaintext = a2b_hex(self.plaintext)
  127. ciphertext = a2b_hex(self.ciphertext)
  128. # The cipher should work like a stream cipher
  129. # Test counter mode encryption, 3 bytes at a time
  130. ct3 = []
  131. cipher = self._new()
  132. for i in range(0, len(plaintext), 3):
  133. ct3.append(cipher.encrypt(plaintext[i:i+3]))
  134. ct3 = b2a_hex(b("").join(ct3))
  135. self.assertEqual(self.ciphertext, ct3) # encryption (3 bytes at a time)
  136. # Test counter mode decryption, 3 bytes at a time
  137. pt3 = []
  138. cipher = self._new()
  139. for i in range(0, len(ciphertext), 3):
  140. pt3.append(cipher.encrypt(ciphertext[i:i+3]))
  141. # PY3K: This is meant to be text, do not change to bytes (data)
  142. pt3 = b2a_hex(b("").join(pt3))
  143. self.assertEqual(self.plaintext, pt3) # decryption (3 bytes at a time)
  144. class RoundtripTest(unittest.TestCase):
  145. def __init__(self, module, params):
  146. from Crypto import Random
  147. unittest.TestCase.__init__(self)
  148. self.module = module
  149. self.iv = Random.get_random_bytes(module.block_size)
  150. self.key = b(params['key'])
  151. self.plaintext = 100 * b(params['plaintext'])
  152. self.module_name = params.get('module_name', None)
  153. def shortDescription(self):
  154. return """%s .decrypt() output of .encrypt() should not be garbled""" % (self.module_name,)
  155. def runTest(self):
  156. ## ECB mode
  157. mode = self.module.MODE_ECB
  158. encryption_cipher = self.module.new(a2b_hex(self.key), mode)
  159. ciphertext = encryption_cipher.encrypt(self.plaintext)
  160. decryption_cipher = self.module.new(a2b_hex(self.key), mode)
  161. decrypted_plaintext = decryption_cipher.decrypt(ciphertext)
  162. self.assertEqual(self.plaintext, decrypted_plaintext)
  163. class IVLengthTest(unittest.TestCase):
  164. def __init__(self, module, params):
  165. unittest.TestCase.__init__(self)
  166. self.module = module
  167. self.key = b(params['key'])
  168. def shortDescription(self):
  169. return "Check that all modes except MODE_ECB and MODE_CTR require an IV of the proper length"
  170. def runTest(self):
  171. self.assertRaises(TypeError, self.module.new, a2b_hex(self.key),
  172. self.module.MODE_ECB, b(""))
  173. def _dummy_counter(self):
  174. return "\0" * self.module.block_size
  175. class NoDefaultECBTest(unittest.TestCase):
  176. def __init__(self, module, params):
  177. unittest.TestCase.__init__(self)
  178. self.module = module
  179. self.key = b(params['key'])
  180. def runTest(self):
  181. self.assertRaises(TypeError, self.module.new, a2b_hex(self.key))
  182. class BlockSizeTest(unittest.TestCase):
  183. def __init__(self, module, params):
  184. unittest.TestCase.__init__(self)
  185. self.module = module
  186. self.key = a2b_hex(b(params['key']))
  187. def runTest(self):
  188. cipher = self.module.new(self.key, self.module.MODE_ECB)
  189. self.assertEqual(cipher.block_size, self.module.block_size)
  190. class ByteArrayTest(unittest.TestCase):
  191. """Verify we can use bytearray's for encrypting and decrypting"""
  192. def __init__(self, module, params):
  193. unittest.TestCase.__init__(self)
  194. self.module = module
  195. # Extract the parameters
  196. params = params.copy()
  197. self.description = _extract(params, 'description')
  198. self.key = b(_extract(params, 'key'))
  199. self.plaintext = b(_extract(params, 'plaintext'))
  200. self.ciphertext = b(_extract(params, 'ciphertext'))
  201. self.module_name = _extract(params, 'module_name', None)
  202. self.assoc_data = _extract(params, 'assoc_data', None)
  203. self.mac = _extract(params, 'mac', None)
  204. if self.assoc_data:
  205. self.mac = b(self.mac)
  206. mode = _extract(params, 'mode', None)
  207. self.mode_name = str(mode)
  208. if mode is not None:
  209. # Block cipher
  210. self.mode = getattr(self.module, "MODE_" + mode)
  211. self.iv = _extract(params, 'iv', None)
  212. if self.iv is None:
  213. self.iv = _extract(params, 'nonce', None)
  214. if self.iv is not None:
  215. self.iv = b(self.iv)
  216. else:
  217. # Stream cipher
  218. self.mode = None
  219. self.iv = _extract(params, 'iv', None)
  220. if self.iv is not None:
  221. self.iv = b(self.iv)
  222. self.extra_params = params
  223. def _new(self):
  224. params = self.extra_params.copy()
  225. key = a2b_hex(self.key)
  226. old_style = []
  227. if self.mode is not None:
  228. old_style = [ self.mode ]
  229. if self.iv is not None:
  230. old_style += [ a2b_hex(self.iv) ]
  231. return self.module.new(key, *old_style, **params)
  232. def runTest(self):
  233. plaintext = a2b_hex(self.plaintext)
  234. ciphertext = a2b_hex(self.ciphertext)
  235. assoc_data = []
  236. if self.assoc_data:
  237. assoc_data = [ bytearray(a2b_hex(b(x))) for x in self.assoc_data]
  238. cipher = self._new()
  239. decipher = self._new()
  240. # Only AEAD modes
  241. for comp in assoc_data:
  242. cipher.update(comp)
  243. decipher.update(comp)
  244. ct = b2a_hex(cipher.encrypt(bytearray(plaintext)))
  245. pt = b2a_hex(decipher.decrypt(bytearray(ciphertext)))
  246. self.assertEqual(self.ciphertext, ct) # encrypt
  247. self.assertEqual(self.plaintext, pt) # decrypt
  248. if self.mac:
  249. mac = b2a_hex(cipher.digest())
  250. self.assertEqual(self.mac, mac)
  251. decipher.verify(bytearray(a2b_hex(self.mac)))
  252. class MemoryviewTest(unittest.TestCase):
  253. """Verify we can use memoryviews for encrypting and decrypting"""
  254. def __init__(self, module, params):
  255. unittest.TestCase.__init__(self)
  256. self.module = module
  257. # Extract the parameters
  258. params = params.copy()
  259. self.description = _extract(params, 'description')
  260. self.key = b(_extract(params, 'key'))
  261. self.plaintext = b(_extract(params, 'plaintext'))
  262. self.ciphertext = b(_extract(params, 'ciphertext'))
  263. self.module_name = _extract(params, 'module_name', None)
  264. self.assoc_data = _extract(params, 'assoc_data', None)
  265. self.mac = _extract(params, 'mac', None)
  266. if self.assoc_data:
  267. self.mac = b(self.mac)
  268. mode = _extract(params, 'mode', None)
  269. self.mode_name = str(mode)
  270. if mode is not None:
  271. # Block cipher
  272. self.mode = getattr(self.module, "MODE_" + mode)
  273. self.iv = _extract(params, 'iv', None)
  274. if self.iv is None:
  275. self.iv = _extract(params, 'nonce', None)
  276. if self.iv is not None:
  277. self.iv = b(self.iv)
  278. else:
  279. # Stream cipher
  280. self.mode = None
  281. self.iv = _extract(params, 'iv', None)
  282. if self.iv is not None:
  283. self.iv = b(self.iv)
  284. self.extra_params = params
  285. def _new(self):
  286. params = self.extra_params.copy()
  287. key = a2b_hex(self.key)
  288. old_style = []
  289. if self.mode is not None:
  290. old_style = [ self.mode ]
  291. if self.iv is not None:
  292. old_style += [ a2b_hex(self.iv) ]
  293. return self.module.new(key, *old_style, **params)
  294. def runTest(self):
  295. plaintext = a2b_hex(self.plaintext)
  296. ciphertext = a2b_hex(self.ciphertext)
  297. assoc_data = []
  298. if self.assoc_data:
  299. assoc_data = [ memoryview(a2b_hex(b(x))) for x in self.assoc_data]
  300. cipher = self._new()
  301. decipher = self._new()
  302. # Only AEAD modes
  303. for comp in assoc_data:
  304. cipher.update(comp)
  305. decipher.update(comp)
  306. ct = b2a_hex(cipher.encrypt(memoryview(plaintext)))
  307. pt = b2a_hex(decipher.decrypt(memoryview(ciphertext)))
  308. self.assertEqual(self.ciphertext, ct) # encrypt
  309. self.assertEqual(self.plaintext, pt) # decrypt
  310. if self.mac:
  311. mac = b2a_hex(cipher.digest())
  312. self.assertEqual(self.mac, mac)
  313. decipher.verify(memoryview(a2b_hex(self.mac)))
  314. def make_block_tests(module, module_name, test_data, additional_params=dict()):
  315. tests = []
  316. extra_tests_added = False
  317. for i in range(len(test_data)):
  318. row = test_data[i]
  319. # Build the "params" dictionary with
  320. # - plaintext
  321. # - ciphertext
  322. # - key
  323. # - mode (default is ECB)
  324. # - (optionally) description
  325. # - (optionally) any other parameter that this cipher mode requires
  326. params = {}
  327. if len(row) == 3:
  328. (params['plaintext'], params['ciphertext'], params['key']) = row
  329. elif len(row) == 4:
  330. (params['plaintext'], params['ciphertext'], params['key'], params['description']) = row
  331. elif len(row) == 5:
  332. (params['plaintext'], params['ciphertext'], params['key'], params['description'], extra_params) = row
  333. params.update(extra_params)
  334. else:
  335. raise AssertionError("Unsupported tuple size %d" % (len(row),))
  336. if not "mode" in params:
  337. params["mode"] = "ECB"
  338. # Build the display-name for the test
  339. p2 = params.copy()
  340. p_key = _extract(p2, 'key')
  341. p_plaintext = _extract(p2, 'plaintext')
  342. p_ciphertext = _extract(p2, 'ciphertext')
  343. p_mode = _extract(p2, 'mode')
  344. p_description = _extract(p2, 'description', None)
  345. if p_description is not None:
  346. description = p_description
  347. elif p_mode == 'ECB' and not p2:
  348. description = "p=%s, k=%s" % (p_plaintext, p_key)
  349. else:
  350. description = "p=%s, k=%s, %r" % (p_plaintext, p_key, p2)
  351. name = "%s #%d: %s" % (module_name, i+1, description)
  352. params['description'] = name
  353. params['module_name'] = module_name
  354. params.update(additional_params)
  355. # Add extra test(s) to the test suite before the current test
  356. if not extra_tests_added:
  357. tests += [
  358. RoundtripTest(module, params),
  359. IVLengthTest(module, params),
  360. NoDefaultECBTest(module, params),
  361. ByteArrayTest(module, params),
  362. BlockSizeTest(module, params),
  363. ]
  364. extra_tests_added = True
  365. # Add the current test to the test suite
  366. tests.append(CipherSelfTest(module, params))
  367. return tests
  368. def make_stream_tests(module, module_name, test_data):
  369. tests = []
  370. extra_tests_added = False
  371. for i in range(len(test_data)):
  372. row = test_data[i]
  373. # Build the "params" dictionary
  374. params = {}
  375. if len(row) == 3:
  376. (params['plaintext'], params['ciphertext'], params['key']) = row
  377. elif len(row) == 4:
  378. (params['plaintext'], params['ciphertext'], params['key'], params['description']) = row
  379. elif len(row) == 5:
  380. (params['plaintext'], params['ciphertext'], params['key'], params['description'], extra_params) = row
  381. params.update(extra_params)
  382. else:
  383. raise AssertionError("Unsupported tuple size %d" % (len(row),))
  384. # Build the display-name for the test
  385. p2 = params.copy()
  386. p_key = _extract(p2, 'key')
  387. p_plaintext = _extract(p2, 'plaintext')
  388. p_ciphertext = _extract(p2, 'ciphertext')
  389. p_description = _extract(p2, 'description', None)
  390. if p_description is not None:
  391. description = p_description
  392. elif not p2:
  393. description = "p=%s, k=%s" % (p_plaintext, p_key)
  394. else:
  395. description = "p=%s, k=%s, %r" % (p_plaintext, p_key, p2)
  396. name = "%s #%d: %s" % (module_name, i+1, description)
  397. params['description'] = name
  398. params['module_name'] = module_name
  399. # Add extra test(s) to the test suite before the current test
  400. if not extra_tests_added:
  401. tests += [
  402. ByteArrayTest(module, params),
  403. ]
  404. import sys
  405. if sys.version[:3] != '2.6':
  406. tests.append(MemoryviewTest(module, params))
  407. extra_tests_added = True
  408. # Add the test to the test suite
  409. tests.append(CipherSelfTest(module, params))
  410. tests.append(CipherStreamingSelfTest(module, params))
  411. return tests
  412. # vim:set ts=4 sw=4 sts=4 expandtab: