ssh.py 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import binascii
  5. import os
  6. import re
  7. import struct
  8. import typing
  9. from base64 import encodebytes as _base64_encode
  10. from cryptography import utils
  11. from cryptography.exceptions import UnsupportedAlgorithm
  12. from cryptography.hazmat.backends import _get_backend
  13. from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa
  14. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  15. from cryptography.hazmat.primitives.serialization import (
  16. Encoding,
  17. NoEncryption,
  18. PrivateFormat,
  19. PublicFormat,
  20. )
  21. try:
  22. from bcrypt import kdf as _bcrypt_kdf
  23. _bcrypt_supported = True
  24. except ImportError:
  25. _bcrypt_supported = False
  26. def _bcrypt_kdf(
  27. password: bytes,
  28. salt: bytes,
  29. desired_key_bytes: int,
  30. rounds: int,
  31. ignore_few_rounds: bool = False,
  32. ) -> bytes:
  33. raise UnsupportedAlgorithm("Need bcrypt module")
  34. _SSH_ED25519 = b"ssh-ed25519"
  35. _SSH_RSA = b"ssh-rsa"
  36. _SSH_DSA = b"ssh-dss"
  37. _ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
  38. _ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
  39. _ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
  40. _CERT_SUFFIX = b"-cert-v01@openssh.com"
  41. _SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)")
  42. _SK_MAGIC = b"openssh-key-v1\0"
  43. _SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
  44. _SK_END = b"-----END OPENSSH PRIVATE KEY-----"
  45. _BCRYPT = b"bcrypt"
  46. _NONE = b"none"
  47. _DEFAULT_CIPHER = b"aes256-ctr"
  48. _DEFAULT_ROUNDS = 16
  49. _MAX_PASSWORD = 72
  50. # re is only way to work on bytes-like data
  51. _PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)
  52. # padding for max blocksize
  53. _PADDING = memoryview(bytearray(range(1, 1 + 16)))
  # Ciphers supported for wrapping private keys, keyed by OpenSSH name.
  # Each value is (algorithm class, key length in bytes, mode class,
  # IV length in bytes); the last item is also used as the cipher block size
  # when checking / producing padding.
  _SSH_CIPHERS = {
      b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
      b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
  }

  # cryptography curve name -> OpenSSH ECDSA key type.
  _ECDSA_KEY_TYPE = dict(
      secp256r1=_ECDSA_NISTP256,
      secp384r1=_ECDSA_NISTP384,
      secp521r1=_ECDSA_NISTP521,
  )

  # Precompiled big-endian codecs for SSH uint32 / uint64 fields.
  _U32, _U64 = struct.Struct(b">I"), struct.Struct(b">Q")
  def _ecdsa_key_type(public_key):
      """Return the OpenSSH key type name for an elliptic-curve public key.

      The curve is looked up by ``public_key.curve.name`` in
      ``_ECDSA_KEY_TYPE`` (e.g. ``"secp256r1"`` -> ``b"ecdsa-sha2-nistp256"``).
      Only the key type is returned; the SSH curve identifier is supplied by
      the matching ``_SSHFormatECDSA`` handler. Both the public- and
      private-key serializers call this with a *public* key.

      :raises ValueError: the curve has no OpenSSH key type.
      """
      curve_name = public_key.curve.name
      try:
          return _ECDSA_KEY_TYPE[curve_name]
      except KeyError:
          # NOTE: the message says "private key" even when serializing a
          # public key; kept verbatim for callers that match on it.
          raise ValueError(
              "Unsupported curve for ssh private key: %r" % curve_name
          ) from None
  def _ssh_pem_encode(data, prefix=_SK_START + b"\n", suffix=_SK_END + b"\n"):
      """Armor ``data`` between ``prefix`` and ``suffix`` lines.

      The body is ``base64.encodebytes`` output, which is line-wrapped and
      ends with a newline, so it sits directly between the armor lines.
      """
      pieces = (prefix, _base64_encode(data), suffix)
      return b"".join(pieces)
  77. def _check_block_size(data, block_len):
  78. """Require data to be full blocks"""
  79. if not data or len(data) % block_len != 0:
  80. raise ValueError("Corrupt data: missing padding")
  81. def _check_empty(data):
  82. """All data should have been parsed."""
  83. if data:
  84. raise ValueError("Corrupt data: unparsed data")
  85. def _init_cipher(ciphername, password, salt, rounds, backend):
  86. """Generate key + iv and return cipher."""
  87. if not password:
  88. raise ValueError("Key is password-protected.")
  89. algo, key_len, mode, iv_len = _SSH_CIPHERS[ciphername]
  90. seed = _bcrypt_kdf(password, salt, key_len + iv_len, rounds, True)
  91. return Cipher(algo(seed[:key_len]), mode(seed[key_len:]), backend)
  92. def _get_u32(data):
  93. """Uint32"""
  94. if len(data) < 4:
  95. raise ValueError("Invalid data")
  96. return _U32.unpack(data[:4])[0], data[4:]
  97. def _get_u64(data):
  98. """Uint64"""
  99. if len(data) < 8:
  100. raise ValueError("Invalid data")
  101. return _U64.unpack(data[:8])[0], data[8:]
  def _get_sshstr(data):
      """Split a uint32-length-prefixed byte string off the front of ``data``.

      Returns ``(payload, rest)`` as slices of ``data`` (no copy when ``data``
      is a memoryview); raises ValueError if the declared length overruns.
      """
      length, rest = _get_u32(data)
      if length > len(rest):
          raise ValueError("Invalid data")
      return rest[:length], rest[length:]
  def _get_mpint(data):
      """Split an SSH ``mpint`` off the front of ``data``.

      Returns ``(value, rest)``. An empty payload decodes to 0; a payload
      whose first byte has the high bit set (a negative number) is rejected
      with ValueError.
      """
      raw, rest = _get_sshstr(data)
      if raw and raw[0] >= 0x80:
          raise ValueError("Invalid data")
      value = int.from_bytes(raw, "big")
      return value, rest
  114. def _to_mpint(val):
  115. """Storage format for signed bigint."""
  116. if val < 0:
  117. raise ValueError("negative mpint not allowed")
  118. if not val:
  119. return b""
  120. nbytes = (val.bit_length() + 8) // 8
  121. return utils.int_to_bytes(val, nbytes)
  122. class _FragList(object):
  123. """Build recursive structure without data copy."""
  124. def __init__(self, init=None):
  125. self.flist = []
  126. if init:
  127. self.flist.extend(init)
  128. def put_raw(self, val):
  129. """Add plain bytes"""
  130. self.flist.append(val)
  131. def put_u32(self, val):
  132. """Big-endian uint32"""
  133. self.flist.append(_U32.pack(val))
  134. def put_sshstr(self, val):
  135. """Bytes prefixed with u32 length"""
  136. if isinstance(val, (bytes, memoryview, bytearray)):
  137. self.put_u32(len(val))
  138. self.flist.append(val)
  139. else:
  140. self.put_u32(val.size())
  141. self.flist.extend(val.flist)
  142. def put_mpint(self, val):
  143. """Big-endian bigint prefixed with u32 length"""
  144. self.put_sshstr(_to_mpint(val))
  145. def size(self):
  146. """Current number of bytes"""
  147. return sum(map(len, self.flist))
  148. def render(self, dstbuf, pos=0):
  149. """Write into bytearray"""
  150. for frag in self.flist:
  151. flen = len(frag)
  152. start, pos = pos, pos + flen
  153. dstbuf[start:pos] = frag
  154. return pos
  155. def tobytes(self):
  156. """Return as bytes"""
  157. buf = memoryview(bytearray(self.size()))
  158. self.render(buf)
  159. return buf.tobytes()
  class _SSHFormatRSA(object):
      """Wire format for RSA keys.

      Public:
          mpint e, n
      Private:
          mpint n, e, d, iqmp, p, q
      """

      def get_public(self, data):
          """Parse ``e`` and ``n``; return ``((e, n), rest)``."""
          e, rest = _get_mpint(data)
          n, rest = _get_mpint(rest)
          return (e, n), rest

      def load_public(self, key_type, data, backend):
          """Build an RSA public key from ``data``; return ``(key, rest)``."""
          (e, n), rest = self.get_public(data)
          key = rsa.RSAPublicNumbers(e, n).public_key(backend)
          return key, rest

      def load_private(self, data, pubfields, backend):
          """Build an RSA private key from the private section.

          ``pubfields`` is the ``(e, n)`` tuple from the public section; a
          mismatch raises ValueError. CRT exponents are recomputed from
          ``d``, ``p`` and ``q``.
          """
          rest = data
          values = []
          for _ in range(6):
              value, rest = _get_mpint(rest)
              values.append(value)
          n, e, d, iqmp, p, q = values
          if (e, n) != pubfields:
              raise ValueError("Corrupt data: rsa field mismatch")
          dmp1 = rsa.rsa_crt_dmp1(d, p)
          dmq1 = rsa.rsa_crt_dmq1(d, q)
          private_numbers = rsa.RSAPrivateNumbers(
              p, q, d, dmp1, dmq1, iqmp, rsa.RSAPublicNumbers(e, n)
          )
          return private_numbers.private_key(backend), rest

      def encode_public(self, public_key, f_pub):
          """Append ``e`` and ``n`` to ``f_pub``."""
          numbers = public_key.public_numbers()
          for value in (numbers.e, numbers.n):
              f_pub.put_mpint(value)

      def encode_private(self, private_key, f_priv):
          """Append ``n, e, d, iqmp, p, q`` to ``f_priv``."""
          priv = private_key.private_numbers()
          pub = priv.public_numbers
          for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
              f_priv.put_mpint(value)
  class _SSHFormatDSA(object):
      """Wire format for DSA keys.

      Public:
          mpint p, q, g, y
      Private:
          mpint p, q, g, y, x

      Keys whose ``p`` is not exactly 1024 bits are rejected when loading
      and when encoding.
      """

      def get_public(self, data):
          """Parse ``p, q, g, y``; return ``((p, q, g, y), rest)``."""
          values = []
          for _ in range(4):
              value, data = _get_mpint(data)
              values.append(value)
          return tuple(values), data

      def load_public(self, key_type, data, backend):
          """Build a DSA public key from ``data``; return ``(key, rest)``."""
          (p, q, g, y), rest = self.get_public(data)
          public_numbers = dsa.DSAPublicNumbers(
              y, dsa.DSAParameterNumbers(p, q, g)
          )
          self._validate(public_numbers)
          return public_numbers.public_key(backend), rest

      def load_private(self, data, pubfields, backend):
          """Build a DSA private key; ``pubfields`` must equal the parsed
          ``(p, q, g, y)`` or ValueError is raised."""
          pub_values, rest = self.get_public(data)
          x, rest = _get_mpint(rest)
          if pub_values != pubfields:
              raise ValueError("Corrupt data: dsa field mismatch")
          p, q, g, y = pub_values
          public_numbers = dsa.DSAPublicNumbers(
              y, dsa.DSAParameterNumbers(p, q, g)
          )
          self._validate(public_numbers)
          private_key = dsa.DSAPrivateNumbers(x, public_numbers).private_key(
              backend
          )
          return private_key, rest

      def encode_public(self, public_key, f_pub):
          """Append ``p, q, g, y`` to ``f_pub`` (after size validation)."""
          numbers = public_key.public_numbers()
          params = numbers.parameter_numbers
          self._validate(numbers)
          for value in (params.p, params.q, params.g, numbers.y):
              f_pub.put_mpint(value)

      def encode_private(self, private_key, f_priv):
          """Append the public fields followed by ``x`` to ``f_priv``."""
          self.encode_public(private_key.public_key(), f_priv)
          f_priv.put_mpint(private_key.private_numbers().x)

      def _validate(self, public_numbers):
          """Raise ValueError unless ``p`` is exactly 1024 bits long."""
          if public_numbers.parameter_numbers.p.bit_length() != 1024:
              raise ValueError("SSH supports only 1024 bit DSA keys")
  class _SSHFormatECDSA(object):
      """Wire format for ECDSA keys on a single curve.

      Public:
          str curve
          bytes point
      Private:
          str curve
          bytes point
          mpint secret

      ``ssh_curve_name`` is the SSH curve identifier expected in the data
      (e.g. ``b"nistp256"``); ``curve`` is the matching cryptography curve
      instance.
      """

      def __init__(self, ssh_curve_name, curve):
          self.ssh_curve_name = ssh_curve_name
          self.curve = curve

      def get_public(self, data):
          """Parse curve name and point; return ``((curve, point), rest)``.

          :raises ValueError: curve name mismatch, or an empty point.
          :raises NotImplementedError: point is not uncompressed (0x04).
          """
          curve, data = _get_sshstr(data)
          point, data = _get_sshstr(data)
          if curve != self.ssh_curve_name:
              raise ValueError("Curve name mismatch")
          # Without this guard an empty point string raised IndexError on
          # ``point[0]`` instead of the ValueError used for malformed input.
          if not point:
              raise ValueError("Invalid data")
          if point[0] != 4:
              raise NotImplementedError("Need uncompressed point")
          return (curve, point), data

      def load_public(self, key_type, data, backend):
          """Build an EC public key from ``data``; return ``(key, rest)``."""
          (curve_name, point), data = self.get_public(data)
          public_key = ec.EllipticCurvePublicKey.from_encoded_point(
              self.curve, point.tobytes()
          )
          return public_key, data

      def load_private(self, data, pubfields, backend):
          """Build an EC private key from the secret scalar.

          ``pubfields`` (curve name, point) from the public section must match
          the copy in the private section, else ValueError.
          """
          (curve_name, point), data = self.get_public(data)
          secret, data = _get_mpint(data)
          if (curve_name, point) != pubfields:
              raise ValueError("Corrupt data: ecdsa field mismatch")
          private_key = ec.derive_private_key(secret, self.curve, backend)
          return private_key, data

      def encode_public(self, public_key, f_pub):
          """Append curve name and uncompressed X9.62 point to ``f_pub``."""
          point = public_key.public_bytes(
              Encoding.X962, PublicFormat.UncompressedPoint
          )
          f_pub.put_sshstr(self.ssh_curve_name)
          f_pub.put_sshstr(point)

      def encode_private(self, private_key, f_priv):
          """Append the public fields followed by the secret scalar."""
          public_key = private_key.public_key()
          private_numbers = private_key.private_numbers()
          self.encode_public(public_key, f_priv)
          f_priv.put_mpint(private_numbers.private_value)
  class _SSHFormatEd25519(object):
      """Wire format for Ed25519 keys.

      Public:
          bytes point
      Private:
          bytes point
          bytes secret_and_point   (32-byte secret followed by the point)
      """

      def get_public(self, data):
          """Parse the public point; return ``((point,), rest)``."""
          point, rest = _get_sshstr(data)
          return (point,), rest

      def load_public(self, key_type, data, backend):
          """Build an Ed25519 public key; return ``(key, rest)``."""
          (point,), rest = self.get_public(data)
          key = ed25519.Ed25519PublicKey.from_public_bytes(point.tobytes())
          return key, rest

      def load_private(self, data, pubfields, backend):
          """Build an Ed25519 private key.

          Both the separate point and the point embedded after the secret
          must equal the public-section point, else ValueError.
          """
          (point,), rest = self.get_public(data)
          keypair, rest = _get_sshstr(rest)
          secret, embedded_point = keypair[:32], keypair[32:]
          if point != embedded_point or (point,) != pubfields:
              raise ValueError("Corrupt data: ed25519 field mismatch")
          key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
          return key, rest

      def encode_public(self, public_key, f_pub):
          """Append the raw 32-byte public key to ``f_pub``."""
          f_pub.put_sshstr(
              public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
          )

      def encode_private(self, private_key, f_priv):
          """Append the point, then secret+point as one string."""
          public_key = private_key.public_key()
          secret = private_key.private_bytes(
              Encoding.Raw, PrivateFormat.Raw, NoEncryption()
          )
          point = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
          self.encode_public(public_key, f_priv)
          f_priv.put_sshstr(_FragList([secret, point]))
  # OpenSSH key type -> wire-format handler.
  _KEY_FORMATS = {
      _SSH_RSA: _SSHFormatRSA(),
      _SSH_DSA: _SSHFormatDSA(),
      _SSH_ED25519: _SSHFormatEd25519(),
  }
  _KEY_FORMATS.update(
      (key_type, _SSHFormatECDSA(curve_id, curve))
      for key_type, curve_id, curve in (
          (_ECDSA_NISTP256, b"nistp256", ec.SECP256R1()),
          (_ECDSA_NISTP384, b"nistp384", ec.SECP384R1()),
          (_ECDSA_NISTP521, b"nistp521", ec.SECP521R1()),
      )
  )
  def _lookup_kformat(key_type):
      """Return the handler for ``key_type`` (bytes-like).

      :raises UnsupportedAlgorithm: no handler is registered for it.
      """
      if not isinstance(key_type, bytes):
          key_type = memoryview(key_type).tobytes()
      kformat = _KEY_FORMATS.get(key_type)
      if kformat is None:
          raise UnsupportedAlgorithm("Unsupported key type: %r" % key_type)
      return kformat
  # Private key classes this module can serialize and load; in this module
  # it appears only in type annotations.
  _SSH_PRIVATE_KEY_TYPES = typing.Union[
      ec.EllipticCurvePrivateKey,
      rsa.RSAPrivateKey,
      dsa.DSAPrivateKey,
      ed25519.Ed25519PrivateKey,
  ]
  def load_ssh_private_key(
      data: bytes, password: typing.Optional[bytes], backend=None
  ) -> _SSH_PRIVATE_KEY_TYPES:
      """Load a private key from OpenSSH's ``openssh-key-v1`` encoding.

      :param data: bytes-like object containing a
          ``-----BEGIN OPENSSH PRIVATE KEY-----`` armored block.
      :param password: passphrase for an encrypted key, or None. It is not
          consulted when the key is stored unencrypted (cipher and KDF are
          both ``none``).
      :param backend: optional backend, resolved via ``_get_backend``.
      :return: the RSA, DSA, ECDSA or Ed25519 private key.
      :raises ValueError: not an OpenSSH private key, corrupt data, more than
          one key, or an encrypted key with no password. A wrong password
          normally surfaces as ``"Corrupt data: broken checksum"``.
      :raises UnsupportedAlgorithm: unknown key type, cipher or KDF, or the
          bcrypt module is unavailable.
      """
      utils._check_byteslike("data", data)
      backend = _get_backend(backend)
      if password is not None:
          utils._check_bytes("password", password)

      # Locate the armored body and base64-decode it.
      m = _PEM_RC.search(data)
      if not m:
          raise ValueError("Not OpenSSH private key format")
      p1 = m.start(1)
      p2 = m.end(1)
      data = binascii.a2b_base64(memoryview(data)[p1:p2])
      if not data.startswith(_SK_MAGIC):
          raise ValueError("Not OpenSSH private key format")
      data = memoryview(data)[len(_SK_MAGIC) :]

      # Header: cipher name, KDF name, KDF options, key count.
      ciphername, data = _get_sshstr(data)
      kdfname, data = _get_sshstr(data)
      kdfoptions, data = _get_sshstr(data)
      nkeys, data = _get_u32(data)
      if nkeys != 1:
          raise ValueError("Only one key supported")

      # Unencrypted public key; its fields are cross-checked against the
      # copy inside the private section below.
      pubdata, data = _get_sshstr(data)
      pub_key_type, pubdata = _get_sshstr(pubdata)
      kformat = _lookup_kformat(pub_key_type)
      pubfields, pubdata = kformat.get_public(pubdata)
      _check_empty(pubdata)

      # Private section (possibly encrypted).
      edata, data = _get_sshstr(data)
      _check_empty(data)

      if (ciphername, kdfname) != (_NONE, _NONE):
          # Convert to bytes so lookups work and error messages show the
          # name instead of a memoryview repr like "<memory at 0x...>".
          ciphername = ciphername.tobytes()
          kdfname = kdfname.tobytes()
          if ciphername not in _SSH_CIPHERS:
              raise UnsupportedAlgorithm("Unsupported cipher: %r" % ciphername)
          if kdfname != _BCRYPT:
              raise UnsupportedAlgorithm("Unsupported KDF: %r" % kdfname)
          blklen = _SSH_CIPHERS[ciphername][3]
          _check_block_size(edata, blklen)
          salt, kbuf = _get_sshstr(kdfoptions)
          rounds, kbuf = _get_u32(kbuf)
          _check_empty(kbuf)
          ciph = _init_cipher(
              ciphername, password, salt.tobytes(), rounds, backend
          )
          edata = memoryview(ciph.decryptor().update(edata))
      else:
          # Unencrypted key: ``password`` is ignored.
          blklen = 8
          _check_block_size(edata, blklen)

      # Two equal check words; a mismatch means corrupt data or bad decryption.
      ck1, edata = _get_u32(edata)
      ck2, edata = _get_u32(edata)
      if ck1 != ck2:
          raise ValueError("Corrupt data: broken checksum")

      # load per-key struct
      key_type, edata = _get_sshstr(edata)
      if key_type != pub_key_type:
          raise ValueError("Corrupt data: key type mismatch")
      private_key, edata = kformat.load_private(edata, pubfields, backend)
      comment, edata = _get_sshstr(edata)

      # yes, SSH does padding check *after* all other parsing is done.
      # need to follow as it writes zero-byte padding too.
      if edata != _PADDING[: len(edata)]:
          raise ValueError("Corrupt data: invalid padding")

      return private_key
  def serialize_ssh_private_key(
      private_key: _SSH_PRIVATE_KEY_TYPES,
      password: typing.Optional[bytes] = None,
  ):
      """Serialize ``private_key`` in OpenSSH ``openssh-key-v1`` PEM form.

      A non-empty ``password`` encrypts the private section with
      ``_DEFAULT_CIPHER``, keyed by bcrypt-pbkdf over a random 16-byte salt.
      ``None`` or ``b""`` writes the key unencrypted. Returns bytes.

      :raises ValueError: password longer than 72 bytes, or an unsupported
          key type / EC curve.
      """
      if password is not None:
          utils._check_bytes("password", password)
      if password and len(password) > _MAX_PASSWORD:
          raise ValueError(
              "Passwords longer than 72 bytes are not supported by "
              "OpenSSH private key format"
          )

      if isinstance(private_key, ec.EllipticCurvePrivateKey):
          key_type = _ecdsa_key_type(private_key.public_key())
      else:
          for key_cls, ssh_name in (
              (rsa.RSAPrivateKey, _SSH_RSA),
              (dsa.DSAPrivateKey, _SSH_DSA),
              (ed25519.Ed25519PrivateKey, _SSH_ED25519),
          ):
              if isinstance(private_key, key_cls):
                  key_type = ssh_name
                  break
          else:
              raise ValueError("Unsupported key type")
      kformat = _lookup_kformat(key_type)

      # Cipher / KDF parameters (only populated when a password is given).
      kdf_options = _FragList()
      if password:
          ciphername, kdfname = _DEFAULT_CIPHER, _BCRYPT
          block_len = _SSH_CIPHERS[ciphername][3]
          salt = os.urandom(16)
          kdf_options.put_sshstr(salt)
          kdf_options.put_u32(_DEFAULT_ROUNDS)
          cipher = _init_cipher(
              ciphername, password, salt, _DEFAULT_ROUNDS, _get_backend(None)
          )
      else:
          ciphername = kdfname = _NONE
          block_len = 8
          cipher = None

      # Duplicated check word lets the loader detect a bad decryption.
      checkval = os.urandom(4)

      public_part = _FragList()
      public_part.put_sshstr(key_type)
      kformat.encode_public(private_key.public_key(), public_part)

      secret_part = _FragList([checkval, checkval])
      secret_part.put_sshstr(key_type)
      kformat.encode_private(private_key, secret_part)
      secret_part.put_sshstr(b"")  # empty comment
      pad_len = block_len - secret_part.size() % block_len
      secret_part.put_raw(_PADDING[:pad_len])

      container = _FragList()
      container.put_raw(_SK_MAGIC)
      container.put_sshstr(ciphername)
      container.put_sshstr(kdfname)
      container.put_sshstr(kdf_options)
      container.put_u32(1)  # number of keys
      container.put_sshstr(public_part)
      container.put_sshstr(secret_part)

      # The private section is the tail of the container. Render into one
      # buffer with block_len bytes of slack (update_into needs an output
      # buffer larger than its input) and encrypt that tail in place.
      secret_len = secret_part.size()
      total_len = container.size()
      buf = memoryview(bytearray(total_len + block_len))
      container.render(buf)
      secret_start = total_len - secret_len
      if cipher is not None:
          cipher.encryptor().update_into(
              buf[secret_start:total_len], buf[secret_start:]
          )
      pem = _ssh_pem_encode(buf[:total_len])
      # Zero the private section of the scratch buffer once PEM is produced.
      # Ignore the following type because mypy wants
      # Sequence[bytes] but what we're passing is fine.
      # https://github.com/python/mypy/issues/9999
      buf[secret_start:total_len] = bytearray(secret_len)  # type: ignore
      return pem
  # Public key classes this module can serialize and load; in this module it
  # appears only in type annotations.
  _SSH_PUBLIC_KEY_TYPES = typing.Union[
      ec.EllipticCurvePublicKey,
      rsa.RSAPublicKey,
      dsa.DSAPublicKey,
      ed25519.Ed25519PublicKey,
  ]
  def load_ssh_public_key(data: bytes, backend=None) -> _SSH_PUBLIC_KEY_TYPES:
      """Parse a one-line OpenSSH public key (``<type> <base64> ...``).

      Plain keys and ``*-cert-v01@openssh.com`` certificates are accepted.
      For a certificate only the embedded public key is returned; the other
      certificate fields are read to validate the layout and then discarded.

      :raises ValueError: malformed line, base64 or key blob, or a type name
          that differs between the line and the blob.
      :raises UnsupportedAlgorithm: unknown key type.
      """
      backend = _get_backend(backend)
      utils._check_byteslike("data", data)

      match = _SSH_PUBKEY_RC.match(data)
      if not match:
          raise ValueError("Invalid line format")
      orig_key_type, key_body = match.group(1, 2)

      key_type = orig_key_type
      with_cert = key_type[-len(_CERT_SUFFIX) :] == _CERT_SUFFIX
      if with_cert:
          key_type = key_type[: -len(_CERT_SUFFIX)]
      kformat = _lookup_kformat(key_type)

      try:
          blob = memoryview(binascii.a2b_base64(key_body))
      except (TypeError, binascii.Error):
          raise ValueError("Invalid key format")

      inner_key_type, blob = _get_sshstr(blob)
      if inner_key_type != orig_key_type:
          raise ValueError("Invalid key format")

      if with_cert:
          _nonce, blob = _get_sshstr(blob)
      public_key, blob = kformat.load_public(key_type, blob, backend)

      if with_cert:
          # serial, cert type, key id, principals, valid after, valid before,
          # critical options, extensions, reserved, signature key, signature
          cert_field_readers = (
              _get_u64,
              _get_u32,
              _get_sshstr,
              _get_sshstr,
              _get_u64,
              _get_u64,
              _get_sshstr,
              _get_sshstr,
              _get_sshstr,
              _get_sshstr,
              _get_sshstr,
          )
          for read_field in cert_field_readers:
              _, blob = read_field(blob)

      _check_empty(blob)
      return public_key
  def serialize_ssh_public_key(public_key: _SSH_PUBLIC_KEY_TYPES) -> bytes:
      """Return ``public_key`` as an OpenSSH ``<type> <base64>`` line
      (no comment, no trailing newline).

      :raises ValueError: unsupported key class or EC curve.
      """
      if isinstance(public_key, ec.EllipticCurvePublicKey):
          key_type = _ecdsa_key_type(public_key)
      else:
          for key_cls, ssh_name in (
              (rsa.RSAPublicKey, _SSH_RSA),
              (dsa.DSAPublicKey, _SSH_DSA),
              (ed25519.Ed25519PublicKey, _SSH_ED25519),
          ):
              if isinstance(public_key, key_cls):
                  key_type = ssh_name
                  break
          else:
              raise ValueError("Unsupported key type")
      kformat = _lookup_kformat(key_type)

      blob = _FragList()
      blob.put_sshstr(key_type)
      kformat.encode_public(public_key, blob)
      encoded = binascii.b2a_base64(blob.tobytes()).strip()
      return key_type + b" " + encoded