core.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. from . import idnadata
  2. import bisect
  3. import unicodedata
  4. import re
  5. from typing import Union, Optional
  6. from .intranges import intranges_contain
  7. _virama_combining_class = 9
  8. _alabel_prefix = b'xn--'
  9. _unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
  10. class IDNAError(UnicodeError):
  11. """ Base exception for all IDNA-encoding related problems """
  12. pass
  13. class IDNABidiError(IDNAError):
  14. """ Exception when bidirectional requirements are not satisfied """
  15. pass
  16. class InvalidCodepoint(IDNAError):
  17. """ Exception when a disallowed or unallocated codepoint is used """
  18. pass
  19. class InvalidCodepointContext(IDNAError):
  20. """ Exception when the codepoint is not valid in the context it is used """
  21. pass
  22. def _combining_class(cp):
  23. # type: (int) -> int
  24. v = unicodedata.combining(chr(cp))
  25. if v == 0:
  26. if not unicodedata.name(chr(cp)):
  27. raise ValueError('Unknown character in unicodedata')
  28. return v
  29. def _is_script(cp, script):
  30. # type: (str, str) -> bool
  31. return intranges_contain(ord(cp), idnadata.scripts[script])
  32. def _punycode(s):
  33. # type: (str) -> bytes
  34. return s.encode('punycode')
  35. def _unot(s):
  36. # type: (int) -> str
  37. return 'U+{:04X}'.format(s)
  38. def valid_label_length(label):
  39. # type: (Union[bytes, str]) -> bool
  40. if len(label) > 63:
  41. return False
  42. return True
  43. def valid_string_length(label, trailing_dot):
  44. # type: (Union[bytes, str], bool) -> bool
  45. if len(label) > (254 if trailing_dot else 253):
  46. return False
  47. return True
  48. def check_bidi(label, check_ltr=False):
  49. # type: (str, bool) -> bool
  50. # Bidi rules should only be applied if string contains RTL characters
  51. bidi_label = False
  52. for (idx, cp) in enumerate(label, 1):
  53. direction = unicodedata.bidirectional(cp)
  54. if direction == '':
  55. # String likely comes from a newer version of Unicode
  56. raise IDNABidiError('Unknown directionality in label {} at position {}'.format(repr(label), idx))
  57. if direction in ['R', 'AL', 'AN']:
  58. bidi_label = True
  59. if not bidi_label and not check_ltr:
  60. return True
  61. # Bidi rule 1
  62. direction = unicodedata.bidirectional(label[0])
  63. if direction in ['R', 'AL']:
  64. rtl = True
  65. elif direction == 'L':
  66. rtl = False
  67. else:
  68. raise IDNABidiError('First codepoint in label {} must be directionality L, R or AL'.format(repr(label)))
  69. valid_ending = False
  70. number_type = None # type: Optional[str]
  71. for (idx, cp) in enumerate(label, 1):
  72. direction = unicodedata.bidirectional(cp)
  73. if rtl:
  74. # Bidi rule 2
  75. if not direction in ['R', 'AL', 'AN', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM']:
  76. raise IDNABidiError('Invalid direction for codepoint at position {} in a right-to-left label'.format(idx))
  77. # Bidi rule 3
  78. if direction in ['R', 'AL', 'EN', 'AN']:
  79. valid_ending = True
  80. elif direction != 'NSM':
  81. valid_ending = False
  82. # Bidi rule 4
  83. if direction in ['AN', 'EN']:
  84. if not number_type:
  85. number_type = direction
  86. else:
  87. if number_type != direction:
  88. raise IDNABidiError('Can not mix numeral types in a right-to-left label')
  89. else:
  90. # Bidi rule 5
  91. if not direction in ['L', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM']:
  92. raise IDNABidiError('Invalid direction for codepoint at position {} in a left-to-right label'.format(idx))
  93. # Bidi rule 6
  94. if direction in ['L', 'EN']:
  95. valid_ending = True
  96. elif direction != 'NSM':
  97. valid_ending = False
  98. if not valid_ending:
  99. raise IDNABidiError('Label ends with illegal codepoint directionality')
  100. return True
  101. def check_initial_combiner(label):
  102. # type: (str) -> bool
  103. if unicodedata.category(label[0])[0] == 'M':
  104. raise IDNAError('Label begins with an illegal combining character')
  105. return True
  106. def check_hyphen_ok(label):
  107. # type: (str) -> bool
  108. if label[2:4] == '--':
  109. raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
  110. if label[0] == '-' or label[-1] == '-':
  111. raise IDNAError('Label must not start or end with a hyphen')
  112. return True
  113. def check_nfc(label):
  114. # type: (str) -> None
  115. if unicodedata.normalize('NFC', label) != label:
  116. raise IDNAError('Label must be in Normalization Form C')
  117. def valid_contextj(label, pos):
  118. # type: (str, int) -> bool
  119. cp_value = ord(label[pos])
  120. if cp_value == 0x200c:
  121. if pos > 0:
  122. if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
  123. return True
  124. ok = False
  125. for i in range(pos-1, -1, -1):
  126. joining_type = idnadata.joining_types.get(ord(label[i]))
  127. if joining_type == ord('T'):
  128. continue
  129. if joining_type in [ord('L'), ord('D')]:
  130. ok = True
  131. break
  132. if not ok:
  133. return False
  134. ok = False
  135. for i in range(pos+1, len(label)):
  136. joining_type = idnadata.joining_types.get(ord(label[i]))
  137. if joining_type == ord('T'):
  138. continue
  139. if joining_type in [ord('R'), ord('D')]:
  140. ok = True
  141. break
  142. return ok
  143. if cp_value == 0x200d:
  144. if pos > 0:
  145. if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
  146. return True
  147. return False
  148. else:
  149. return False
  150. def valid_contexto(label, pos, exception=False):
  151. # type: (str, int, bool) -> bool
  152. cp_value = ord(label[pos])
  153. if cp_value == 0x00b7:
  154. if 0 < pos < len(label)-1:
  155. if ord(label[pos - 1]) == 0x006c and ord(label[pos + 1]) == 0x006c:
  156. return True
  157. return False
  158. elif cp_value == 0x0375:
  159. if pos < len(label)-1 and len(label) > 1:
  160. return _is_script(label[pos + 1], 'Greek')
  161. return False
  162. elif cp_value == 0x05f3 or cp_value == 0x05f4:
  163. if pos > 0:
  164. return _is_script(label[pos - 1], 'Hebrew')
  165. return False
  166. elif cp_value == 0x30fb:
  167. for cp in label:
  168. if cp == '\u30fb':
  169. continue
  170. if _is_script(cp, 'Hiragana') or _is_script(cp, 'Katakana') or _is_script(cp, 'Han'):
  171. return True
  172. return False
  173. elif 0x660 <= cp_value <= 0x669:
  174. for cp in label:
  175. if 0x6f0 <= ord(cp) <= 0x06f9:
  176. return False
  177. return True
  178. elif 0x6f0 <= cp_value <= 0x6f9:
  179. for cp in label:
  180. if 0x660 <= ord(cp) <= 0x0669:
  181. return False
  182. return True
  183. return False
  184. def check_label(label):
  185. # type: (Union[str, bytes, bytearray]) -> None
  186. if isinstance(label, (bytes, bytearray)):
  187. label = label.decode('utf-8')
  188. if len(label) == 0:
  189. raise IDNAError('Empty Label')
  190. check_nfc(label)
  191. check_hyphen_ok(label)
  192. check_initial_combiner(label)
  193. for (pos, cp) in enumerate(label):
  194. cp_value = ord(cp)
  195. if intranges_contain(cp_value, idnadata.codepoint_classes['PVALID']):
  196. continue
  197. elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTJ']):
  198. try:
  199. if not valid_contextj(label, pos):
  200. raise InvalidCodepointContext('Joiner {} not allowed at position {} in {}'.format(
  201. _unot(cp_value), pos+1, repr(label)))
  202. except ValueError:
  203. raise IDNAError('Unknown codepoint adjacent to joiner {} at position {} in {}'.format(
  204. _unot(cp_value), pos+1, repr(label)))
  205. elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTO']):
  206. if not valid_contexto(label, pos):
  207. raise InvalidCodepointContext('Codepoint {} not allowed at position {} in {}'.format(_unot(cp_value), pos+1, repr(label)))
  208. else:
  209. raise InvalidCodepoint('Codepoint {} at position {} of {} not allowed'.format(_unot(cp_value), pos+1, repr(label)))
  210. check_bidi(label)
  211. def alabel(label):
  212. # type: (str) -> bytes
  213. try:
  214. label_bytes = label.encode('ascii')
  215. ulabel(label_bytes)
  216. if not valid_label_length(label_bytes):
  217. raise IDNAError('Label too long')
  218. return label_bytes
  219. except UnicodeEncodeError:
  220. pass
  221. if not label:
  222. raise IDNAError('No Input')
  223. label = str(label)
  224. check_label(label)
  225. label_bytes = _punycode(label)
  226. label_bytes = _alabel_prefix + label_bytes
  227. if not valid_label_length(label_bytes):
  228. raise IDNAError('Label too long')
  229. return label_bytes
  230. def ulabel(label):
  231. # type: (Union[str, bytes, bytearray]) -> str
  232. if not isinstance(label, (bytes, bytearray)):
  233. try:
  234. label_bytes = label.encode('ascii')
  235. except UnicodeEncodeError:
  236. check_label(label)
  237. return label
  238. else:
  239. label_bytes = label
  240. label_bytes = label_bytes.lower()
  241. if label_bytes.startswith(_alabel_prefix):
  242. label_bytes = label_bytes[len(_alabel_prefix):]
  243. if not label_bytes:
  244. raise IDNAError('Malformed A-label, no Punycode eligible content found')
  245. if label_bytes.decode('ascii')[-1] == '-':
  246. raise IDNAError('A-label must not end with a hyphen')
  247. else:
  248. check_label(label_bytes)
  249. return label_bytes.decode('ascii')
  250. label = label_bytes.decode('punycode')
  251. check_label(label)
  252. return label
  253. def uts46_remap(domain, std3_rules=True, transitional=False):
  254. # type: (str, bool, bool) -> str
  255. """Re-map the characters in the string according to UTS46 processing."""
  256. from .uts46data import uts46data
  257. output = ''
  258. for pos, char in enumerate(domain):
  259. code_point = ord(char)
  260. try:
  261. uts46row = uts46data[code_point if code_point < 256 else
  262. bisect.bisect_left(uts46data, (code_point, 'Z')) - 1]
  263. status = uts46row[1]
  264. replacement = None # type: Optional[str]
  265. if len(uts46row) == 3:
  266. replacement = uts46row[2] # type: ignore
  267. if (status == 'V' or
  268. (status == 'D' and not transitional) or
  269. (status == '3' and not std3_rules and replacement is None)):
  270. output += char
  271. elif replacement is not None and (status == 'M' or
  272. (status == '3' and not std3_rules) or
  273. (status == 'D' and transitional)):
  274. output += replacement
  275. elif status != 'I':
  276. raise IndexError()
  277. except IndexError:
  278. raise InvalidCodepoint(
  279. 'Codepoint {} not allowed at position {} in {}'.format(
  280. _unot(code_point), pos + 1, repr(domain)))
  281. return unicodedata.normalize('NFC', output)
  282. def encode(s, strict=False, uts46=False, std3_rules=False, transitional=False):
  283. # type: (Union[str, bytes, bytearray], bool, bool, bool, bool) -> bytes
  284. if isinstance(s, (bytes, bytearray)):
  285. s = s.decode('ascii')
  286. if uts46:
  287. s = uts46_remap(s, std3_rules, transitional)
  288. trailing_dot = False
  289. result = []
  290. if strict:
  291. labels = s.split('.')
  292. else:
  293. labels = _unicode_dots_re.split(s)
  294. if not labels or labels == ['']:
  295. raise IDNAError('Empty domain')
  296. if labels[-1] == '':
  297. del labels[-1]
  298. trailing_dot = True
  299. for label in labels:
  300. s = alabel(label)
  301. if s:
  302. result.append(s)
  303. else:
  304. raise IDNAError('Empty label')
  305. if trailing_dot:
  306. result.append(b'')
  307. s = b'.'.join(result)
  308. if not valid_string_length(s, trailing_dot):
  309. raise IDNAError('Domain too long')
  310. return s
  311. def decode(s, strict=False, uts46=False, std3_rules=False):
  312. # type: (Union[str, bytes, bytearray], bool, bool, bool) -> str
  313. if isinstance(s, (bytes, bytearray)):
  314. s = s.decode('ascii')
  315. if uts46:
  316. s = uts46_remap(s, std3_rules, False)
  317. trailing_dot = False
  318. result = []
  319. if not strict:
  320. labels = _unicode_dots_re.split(s)
  321. else:
  322. labels = s.split('.')
  323. if not labels or labels == ['']:
  324. raise IDNAError('Empty domain')
  325. if not labels[-1]:
  326. del labels[-1]
  327. trailing_dot = True
  328. for label in labels:
  329. s = ulabel(label)
  330. if s:
  331. result.append(s)
  332. else:
  333. raise IDNAError('Empty label')
  334. if trailing_dot:
  335. result.append('')
  336. return '.'.join(result)