basic_util.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. # -*- coding: utf-8 -*-
  2. """基本方法
  3. 创建中文数字系统 方法
  4. 中文字符串 <=> 数字串 方法
  5. 数字串 <=> 中文字符串 方法
  6. """
  7. __author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
  8. __data__ = "2019-05-02"
  9. from fish_speech.text.chn_text_norm.basic_class import *
  10. from fish_speech.text.chn_text_norm.basic_constant import *
  11. def create_system(numbering_type=NUMBERING_TYPES[1]):
  12. """
  13. 根据数字系统类型返回创建相应的数字系统,默认为 mid
  14. NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
  15. low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
  16. mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
  17. high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
  18. 返回对应的数字系统
  19. """
  20. # chinese number units of '亿' and larger
  21. all_larger_units = zip(
  22. LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
  23. LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
  24. )
  25. larger_units = [
  26. CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)
  27. ]
  28. # chinese number units of '十, 百, 千, 万'
  29. all_smaller_units = zip(
  30. SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
  31. SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
  32. )
  33. smaller_units = [
  34. CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)
  35. ]
  36. # digis
  37. chinese_digis = zip(
  38. CHINESE_DIGIS,
  39. CHINESE_DIGIS,
  40. BIG_CHINESE_DIGIS_SIMPLIFIED,
  41. BIG_CHINESE_DIGIS_TRADITIONAL,
  42. )
  43. digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
  44. digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
  45. digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
  46. digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
  47. # symbols
  48. positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
  49. negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
  50. point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
  51. # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
  52. system = NumberSystem()
  53. system.units = smaller_units + larger_units
  54. system.digits = digits
  55. system.math = MathSymbol(positive_cn, negative_cn, point_cn)
  56. # system.symbols = OtherSymbol(sil_cn)
  57. return system
  58. def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
  59. def get_symbol(char, system):
  60. for u in system.units:
  61. if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
  62. return u
  63. for d in system.digits:
  64. if char in [
  65. d.traditional,
  66. d.simplified,
  67. d.big_s,
  68. d.big_t,
  69. d.alt_s,
  70. d.alt_t,
  71. ]:
  72. return d
  73. for m in system.math:
  74. if char in [m.traditional, m.simplified]:
  75. return m
  76. def string2symbols(chinese_string, system):
  77. int_string, dec_string = chinese_string, ""
  78. for p in [system.math.point.simplified, system.math.point.traditional]:
  79. if p in chinese_string:
  80. int_string, dec_string = chinese_string.split(p)
  81. break
  82. return [get_symbol(c, system) for c in int_string], [
  83. get_symbol(c, system) for c in dec_string
  84. ]
  85. def correct_symbols(integer_symbols, system):
  86. """
  87. 一百八 to 一百八十
  88. 一亿一千三百万 to 一亿 一千万 三百万
  89. """
  90. if integer_symbols and isinstance(integer_symbols[0], CNU):
  91. if integer_symbols[0].power == 1:
  92. integer_symbols = [system.digits[1]] + integer_symbols
  93. if len(integer_symbols) > 1:
  94. if isinstance(integer_symbols[-1], CND) and isinstance(
  95. integer_symbols[-2], CNU
  96. ):
  97. integer_symbols.append(
  98. CNU(integer_symbols[-2].power - 1, None, None, None, None)
  99. )
  100. result = []
  101. unit_count = 0
  102. for s in integer_symbols:
  103. if isinstance(s, CND):
  104. result.append(s)
  105. unit_count = 0
  106. elif isinstance(s, CNU):
  107. current_unit = CNU(s.power, None, None, None, None)
  108. unit_count += 1
  109. if unit_count == 1:
  110. result.append(current_unit)
  111. elif unit_count > 1:
  112. for i in range(len(result)):
  113. if (
  114. isinstance(result[-i - 1], CNU)
  115. and result[-i - 1].power < current_unit.power
  116. ):
  117. result[-i - 1] = CNU(
  118. result[-i - 1].power + current_unit.power,
  119. None,
  120. None,
  121. None,
  122. None,
  123. )
  124. return result
  125. def compute_value(integer_symbols):
  126. """
  127. Compute the value.
  128. When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
  129. e.g. '两千万' = 2000 * 10000 not 2000 + 10000
  130. """
  131. value = [0]
  132. last_power = 0
  133. for s in integer_symbols:
  134. if isinstance(s, CND):
  135. value[-1] = s.value
  136. elif isinstance(s, CNU):
  137. value[-1] *= pow(10, s.power)
  138. if s.power > last_power:
  139. value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
  140. last_power = s.power
  141. value.append(0)
  142. return sum(value)
  143. system = create_system(numbering_type)
  144. int_part, dec_part = string2symbols(chinese_string, system)
  145. int_part = correct_symbols(int_part, system)
  146. int_str = str(compute_value(int_part))
  147. dec_str = "".join([str(d.value) for d in dec_part])
  148. if dec_part:
  149. return "{0}.{1}".format(int_str, dec_str)
  150. else:
  151. return int_str
  152. def num2chn(
  153. number_string,
  154. numbering_type=NUMBERING_TYPES[1],
  155. big=False,
  156. traditional=False,
  157. alt_zero=False,
  158. alt_one=False,
  159. alt_two=True,
  160. use_zeros=True,
  161. use_units=True,
  162. ):
  163. def get_value(value_string, use_zeros=True):
  164. striped_string = value_string.lstrip("0")
  165. # record nothing if all zeros
  166. if not striped_string:
  167. return []
  168. # record one digits
  169. elif len(striped_string) == 1:
  170. if use_zeros and len(value_string) != len(striped_string):
  171. return [system.digits[0], system.digits[int(striped_string)]]
  172. else:
  173. return [system.digits[int(striped_string)]]
  174. # recursively record multiple digits
  175. else:
  176. result_unit = next(
  177. u for u in reversed(system.units) if u.power < len(striped_string)
  178. )
  179. result_string = value_string[: -result_unit.power]
  180. return (
  181. get_value(result_string)
  182. + [result_unit]
  183. + get_value(striped_string[-result_unit.power :])
  184. )
  185. system = create_system(numbering_type)
  186. int_dec = number_string.split(".")
  187. if len(int_dec) == 1:
  188. int_string = int_dec[0]
  189. dec_string = ""
  190. elif len(int_dec) == 2:
  191. int_string = int_dec[0]
  192. dec_string = int_dec[1]
  193. else:
  194. raise ValueError(
  195. "invalid input num string with more than one dot: {}".format(number_string)
  196. )
  197. if use_units and len(int_string) > 1:
  198. result_symbols = get_value(int_string)
  199. else:
  200. result_symbols = [system.digits[int(c)] for c in int_string]
  201. dec_symbols = [system.digits[int(c)] for c in dec_string]
  202. if dec_string:
  203. result_symbols += [system.math.point] + dec_symbols
  204. if alt_two:
  205. liang = CND(
  206. 2,
  207. system.digits[2].alt_s,
  208. system.digits[2].alt_t,
  209. system.digits[2].big_s,
  210. system.digits[2].big_t,
  211. )
  212. for i, v in enumerate(result_symbols):
  213. if isinstance(v, CND) and v.value == 2:
  214. next_symbol = (
  215. result_symbols[i + 1] if i < len(result_symbols) - 1 else None
  216. )
  217. previous_symbol = result_symbols[i - 1] if i > 0 else None
  218. if isinstance(next_symbol, CNU) and isinstance(
  219. previous_symbol, (CNU, type(None))
  220. ):
  221. if next_symbol.power != 1 and (
  222. (previous_symbol is None) or (previous_symbol.power != 1)
  223. ):
  224. result_symbols[i] = liang
  225. # if big is True, '两' will not be used and `alt_two` has no impact on output
  226. if big:
  227. attr_name = "big_"
  228. if traditional:
  229. attr_name += "t"
  230. else:
  231. attr_name += "s"
  232. else:
  233. if traditional:
  234. attr_name = "traditional"
  235. else:
  236. attr_name = "simplified"
  237. result = "".join([getattr(s, attr_name) for s in result_symbols])
  238. # if not use_zeros:
  239. # result = result.strip(getattr(system.digits[0], attr_name))
  240. if alt_zero:
  241. result = result.replace(
  242. getattr(system.digits[0], attr_name), system.digits[0].alt_s
  243. )
  244. if alt_one:
  245. result = result.replace(
  246. getattr(system.digits[1], attr_name), system.digits[1].alt_s
  247. )
  248. for i, p in enumerate(POINT):
  249. if result.startswith(p):
  250. return CHINESE_DIGIS[0] + result
  251. # ^10, 11, .., 19
  252. if (
  253. len(result) >= 2
  254. and result[1]
  255. in [
  256. SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
  257. SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
  258. ]
  259. and result[0]
  260. in [
  261. CHINESE_DIGIS[1],
  262. BIG_CHINESE_DIGIS_SIMPLIFIED[1],
  263. BIG_CHINESE_DIGIS_TRADITIONAL[1],
  264. ]
  265. ):
  266. result = result[1:]
  267. return result
  268. if __name__ == "__main__":
  269. # 测试程序
  270. all_chinese_number_string = (
  271. CHINESE_DIGIS
  272. + BIG_CHINESE_DIGIS_SIMPLIFIED
  273. + BIG_CHINESE_DIGIS_TRADITIONAL
  274. + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
  275. + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
  276. + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
  277. + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
  278. + ZERO_ALT
  279. + ONE_ALT
  280. + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
  281. )
  282. print("num:", chn2num("一万零四百零三点八零五"))
  283. print("num:", chn2num("一亿六点三"))
  284. print("num:", chn2num("一亿零六点三"))
  285. print("num:", chn2num("两千零一亿六点三"))
  286. # print('num:', chn2num('一零零八六'))
  287. print("txt:", num2chn("10260.03", alt_zero=True))
  288. print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
  289. print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
  290. print(
  291. "txt:",
  292. num2chn(
  293. "059523810880",
  294. alt_one=True,
  295. alt_two=False,
  296. use_lzeros=True,
  297. use_rzeros=True,
  298. use_units=False,
  299. ),
  300. )
  301. print(all_chinese_number_string)