# mellotron_utils.py — convert MusicXML scores into Mellotron model inputs.
import re
import numpy as np
import music21 as m21
import torch
import torch.nn.functional as F
from text import text_to_sequence, get_arpabet, cmudict

# Path to the CMU pronouncing dictionary used for word -> ARPAbet lookup.
CMUDICT_PATH = "data/cmu_dictionary"
CMUDICT = cmudict.CMUDict(CMUDICT_PATH)
# Map from ARPAbet phoneme to the grapheme substrings that may spell it.
# Used in events2eventsarpabet to align a word's letters to its transcription.
PHONEME2GRAPHEME = {
    'AA': ['a', 'o', 'ah'],
    'AE': ['a', 'e'],
    'AH': ['u', 'e', 'a', 'h', 'o'],
    'AO': ['o', 'u', 'au'],
    'AW': ['ou', 'ow'],
    'AX': ['a'],
    'AXR': ['er'],
    'AY': ['i'],
    'EH': ['e', 'ae'],
    'EY': ['a', 'ai', 'ei', 'e', 'y'],
    'IH': ['i', 'e', 'y'],
    'IX': ['e', 'i'],
    'IY': ['ea', 'ey', 'y', 'i'],
    'OW': ['oa', 'o'],
    'OY': ['oy'],
    'UH': ['oo'],
    'UW': ['oo', 'u', 'o'],
    'UX': ['u'],
    'B': ['b'],
    'CH': ['ch', 'tch'],
    'D': ['d', 'e', 'de'],
    'DH': ['th'],
    'DX': ['tt'],
    'EL': ['le'],
    'EM': ['m'],
    'EN': ['on'],
    'ER': ['i', 'er'],
    'F': ['f'],
    'G': ['g'],
    'HH': ['h'],
    'JH': ['j'],
    'K': ['k', 'c', 'ch'],
    'KS': ['x'],
    'L': ['ll', 'l'],
    'M': ['m'],
    'N': ['n', 'gn'],
    'NG': ['ng'],
    'NX': ['nn'],
    'P': ['p', 'pp'],
    'Q': ['-'],
    'R': ['wr', 'r'],
    'S': ['s', 'ce'],
    'SH': ['sh'],
    'T': ['t'],
    'TH': ['th'],
    'V': ['v', 'f', 'e'],
    'W': ['w'],
    'WH': ['wh'],
    'Y': ['y', 'j'],
    'Z': ['z', 's'],
    'ZH': ['s']
}
########################
# CONSONANT DURATION #
########################
# Fixed duration (seconds) assigned to each consonant phoneme; membership in
# this dict is also how the code distinguishes consonants from vowels.
PHONEMEDURATION = {
    'B': 0.05,
    'CH': 0.1,
    'D': 0.075,
    'DH': 0.05,
    'DX': 0.05,
    'EL': 0.05,
    'EM': 0.05,
    'EN': 0.05,
    'F': 0.1,
    'G': 0.05,
    'HH': 0.05,
    'JH': 0.05,
    'K': 0.05,
    'L': 0.05,
    'M': 0.15,
    'N': 0.15,
    'NG': 0.15,
    'NX': 0.05,
    'P': 0.05,
    'Q': 0.075,
    'R': 0.05,
    'S': 0.1,
    'SH': 0.05,
    'T': 0.075,
    'TH': 0.1,
    'V': 0.05,
    'Y': 0.05,
    'W': 0.05,
    'WH': 0.05,
    'Z': 0.05,
    'ZH': 0.05
}
  99. def add_space_between_events(events, connect=False):
  100. new_events = []
  101. for i in range(1, len(events)):
  102. token_a, freq_a, start_time_a, end_time_a = events[i-1][-1]
  103. token_b, freq_b, start_time_b, end_time_b = events[i][0]
  104. if token_a in (' ', '') and len(events[i-1]) == 1:
  105. new_events.append(events[i-1])
  106. elif token_a not in (' ', '') and token_b not in (' ', ''):
  107. new_events.append(events[i-1])
  108. if connect:
  109. new_events.append([[' ', 0, end_time_a, start_time_b]])
  110. else:
  111. new_events.append([[' ', 0, end_time_a, end_time_a]])
  112. else:
  113. new_events.append(events[i-1])
  114. if new_events[-1][0][0] != ' ':
  115. new_events.append([[' ', 0, end_time_a, end_time_a]])
  116. new_events.append(events[-1])
  117. return new_events
  118. def adjust_words(events):
  119. new_events = []
  120. for event in events:
  121. if len(event) == 1 and event[0][0] == ' ':
  122. new_events.append(event)
  123. else:
  124. for e in event:
  125. if e[0][0].isupper():
  126. new_events.append([e])
  127. else:
  128. new_events[-1].extend([e])
  129. return new_events
def adjust_extensions(events, phoneme_durations):
    """Resolve extension tokens ('_') within one word's event list.

    An extension inherits the token of the last vowel seen. When consonants
    sit between that vowel and its extension, the events are reordered so the
    extension directly follows the vowel and the consonants move after it.

    Args:
        events: list of [token, freq, start_time, end_time] for one word.
        phoneme_durations: dict of consonant phoneme -> duration; only key
            membership is used here, to tell consonants from vowels.

    Returns:
        The reordered event list (entries are also modified in place).
    """
    if len(events) == 1:
        return events

    idx_last_vowel = None
    n_consonants_after_last_vowel = 0
    # target_ids[i] is the output position of events[i] after reordering
    target_ids = np.arange(len(events))
    for i in range(len(events)):
        # strip stress digits and curly braces to get the bare phoneme
        token = re.sub('[0-9{}]', '', events[i][0])
        if idx_last_vowel is None and token not in phoneme_durations:
            # first vowel of the word
            idx_last_vowel = i
            n_consonants_after_last_vowel = 0
        else:
            if token == '_' and not n_consonants_after_last_vowel:
                # extension immediately after a vowel: copy the vowel token
                events[i][0] = events[idx_last_vowel][0]
            elif token == '_' and n_consonants_after_last_vowel:
                # extension after intervening consonants: copy the vowel token
                # and shift positions so the extension precedes the consonants
                events[i][0] = events[idx_last_vowel][0]
                start = idx_last_vowel + 1
                target_ids[start:start+n_consonants_after_last_vowel] += 1
                target_ids[i] -= n_consonants_after_last_vowel
            elif token in phoneme_durations:
                # consonant: count it until the next vowel or extension
                n_consonants_after_last_vowel += 1
            else:
                # new vowel: reset the consonant run
                n_consonants_after_last_vowel = 0
                idx_last_vowel = i

    # scatter events into their target positions
    new_events = [0] * len(events)
    for i in range(len(events)):
        new_events[target_ids[i]] = events[i]

    # adjust time of consonants that were repositioned
    for i in range(1, len(new_events)):
        if new_events[i][2] < new_events[i-1][2]:
            new_events[i][2] = new_events[i-1][2]
            new_events[i][3] = new_events[i-1][3]

    return new_events
def adjust_consonant_lengths(events, phoneme_durations):
    """Give consonants their fixed durations within one same-start group.

    Vowels keep their scored end times; each consonant found after a vowel is
    carved out of that vowel's tail (the vowel's end time shrinks and any
    events in between shift earlier), so the group's total span is preserved.

    Args:
        events: list of [token, freq, start_time, end_time], modified in place.
        phoneme_durations: dict of consonant phoneme -> duration in seconds.

    Returns:
        The same event list, with adjusted start/end times.
    """
    t_init = events[0][2]

    idx_last_vowel = None
    for i in range(len(events)):
        # strip stress digits and braces to look up the bare phoneme
        task = re.sub('[0-9{}]', '', events[i][0])
        if task in phoneme_durations:
            duration = phoneme_durations[task]
            if idx_last_vowel is None:  # consonant comes before any vowel
                events[i][2] = t_init
                events[i][3] = t_init + duration
            else:  # consonant comes after a vowel, must offset
                events[idx_last_vowel][3] -= duration
                for k in range(idx_last_vowel+1, i):
                    events[k][2] -= duration
                    events[k][3] -= duration
                events[i][2] = events[i-1][3]
                events[i][3] = events[i-1][3] + duration
        else:
            # vowel: starts where the previous event ended, keeps its end time
            events[i][2] = t_init
            events[i][3] = events[i][3]
            t_init = events[i][3]
            idx_last_vowel = i
        t_init = events[i][3]
    return events
  187. def adjust_consonants(events, phoneme_durations):
  188. if len(events) == 1:
  189. return events
  190. start = 0
  191. split_ids = []
  192. t_init = events[0][2]
  193. # get each substring group
  194. for i in range(1, len(events)):
  195. if events[i][2] != t_init:
  196. split_ids.append((start, i))
  197. start = i
  198. t_init = events[i][2]
  199. split_ids.append((start, len(events)))
  200. for (start, end) in split_ids:
  201. events[start:end] = adjust_consonant_lengths(
  202. events[start:end], phoneme_durations)
  203. return events
  204. def adjust_event(event, hop_length=256, sampling_rate=22050):
  205. tokens, freq, start_time, end_time = event
  206. if tokens == ' ':
  207. return [event] if freq == 0 else [['_', freq, start_time, end_time]]
  208. return [[token, freq, start_time, end_time] for token in tokens]
def musicxml2score(filepath, bpm=60):
    """Parse a MusicXML file into per-part lists of timed lyric events.

    Each event is [token, freq, start_time, end_time] with times in seconds,
    assuming one quarter note per beat at the given tempo.

    Args:
        filepath: path to the MusicXML file.
        bpm: tempo in beats per minute used to convert beats to seconds.

    Returns:
        dict mapping part name -> list of events.

    Raises:
        Exception: if a lyric token is the reserved extension marker '_'.
    """
    track = {}
    beat_length_seconds = 60/bpm
    data = m21.converter.parse(filepath)
    for i in range(len(data.parts)):
        part = data.parts[i].flat
        events = []
        for k in range(len(part.notesAndRests)):
            event = part.notesAndRests[k]
            if isinstance(event, m21.note.Note):
                freq = event.pitch.frequency
                # first lyric syllable is the token; ' ' when no lyric is set
                token = event.lyrics[0].text if len(event.lyrics) > 0 else ' '
                start_time = event.offset * beat_length_seconds
                end_time = start_time + event.duration.quarterLength * beat_length_seconds
                event = [token, freq, start_time, end_time]
            elif isinstance(event, m21.note.Rest):
                freq = 0
                token = ' '
                start_time = event.offset * beat_length_seconds
                end_time = start_time + event.duration.quarterLength * beat_length_seconds
                event = [token, freq, start_time, end_time]
            # NOTE(review): if the element is neither a Note nor a Rest (e.g.
            # a Chord), token/freq keep their values from the previous
            # iteration (or are unbound on the first one) and the raw music21
            # object is appended — presumably inputs contain only notes and
            # rests; confirm before feeding other scores.
            if token == '_':
                raise Exception("Unexpected token {}".format(token))
            if len(events) == 0:
                events.append(event)
            else:
                if token == ' ':
                    if freq == 0:
                        # merge consecutive rests into one longer rest
                        if events[-1][1] == 0:
                            events[-1][3] = end_time
                        else:
                            events.append(event)
                    elif freq == events[-1][1]:  # is event duration extension ?
                        events[-1][-1] = end_time
                    else:  # must be different note on same syllable
                        events.append(event)
                else:
                    events.append(event)
        track[part.partName] = events
    return track
  249. def track2events(track):
  250. events = []
  251. for e in track:
  252. events.extend(adjust_event(e))
  253. group_ids = [i for i in range(len(events))
  254. if events[i][0] in [' '] or events[i][0].isupper()]
  255. events_grouped = []
  256. for i in range(1, len(group_ids)):
  257. start, end = group_ids[i-1], group_ids[i]
  258. events_grouped.append(events[start:end])
  259. if events[-1][0] != ' ':
  260. events_grouped.append(events[group_ids[-1]:])
  261. return events_grouped
def events2eventsarpabet(event):
    """Convert one word's character events into ARPAbet phoneme events.

    Looks the word up in CMUDICT and greedily aligns its letters to the
    phoneme sequence using PHONEME2GRAPHEME (trying two-letter graphemes,
    then single letters, then two-phoneme compounds). Space and extension
    ('_') events pass through unchanged.

    Args:
        event: list of [token, freq, start_time, end_time] for one word.

    Returns:
        The original list when the word has no ARPAbet entry (or is a lone
        space), otherwise a new list of phoneme-level events.
    """
    if event[0][0] == ' ':
        return event

    # get word and word arpabet
    word = ''.join([e[0] for e in event if e[0] not in('_', ' ')])
    word_arpabet = get_arpabet(word, CMUDICT)
    # get_arpabet returns the word unchanged when it is not in the dictionary
    if word_arpabet[0] != '{':
        return event
    word_arpabet = word_arpabet.split()

    # align tokens to arpabet
    i, k = 0, 0
    new_events = []
    while i < len(event) and k < len(word_arpabet):
        # single token
        token_a, freq_a, start_time_a, end_time_a = event[i]
        if token_a == ' ':
            new_events.append([token_a, freq_a, start_time_a, end_time_a])
            i += 1
            continue
        if token_a == '_':
            new_events.append([token_a, freq_a, start_time_a, end_time_a])
            i += 1
            continue

        # two tokens: look ahead past any extension markers to the next letter
        if i < len(event) - 1:
            j = i + 1
            token_b, freq_b, start_time_b, end_time_b = event[j]
            between_events = []
            while j < len(event) and event[j][0] == '_':
                between_events.append([token_b, freq_b, start_time_b, end_time_b])
                j += 1
                if j < len(event):
                    token_b, freq_b, start_time_b, end_time_b = event[j]
            token_compound_2 = (token_a + token_b).lower()

        # single arpabet (stress digits and braces stripped)
        arpabet = re.sub('[0-9{}]', '', word_arpabet[k])
        if k < len(word_arpabet) - 1:
            arpabet_compound_2 = ''.join(word_arpabet[k:k+2])
            arpabet_compound_2 = re.sub('[0-9{}]', '', arpabet_compound_2)

        # NOTE(review): token_compound_2 / arpabet_compound_2 may carry values
        # from an earlier iteration when the guards above are skipped — the
        # branch conditions appear to rely on that; confirmed only by usage.
        if i < len(event) - 1 and token_compound_2 in PHONEME2GRAPHEME[arpabet]:
            # two letters spell one phoneme
            new_events.append([word_arpabet[k], freq_a, start_time_a, end_time_a])
            if len(between_events):
                new_events.extend(between_events)
            if start_time_a != start_time_b:
                # letters are on different notes: repeat the phoneme there
                new_events.append([word_arpabet[k], freq_b, start_time_b, end_time_b])
            i += 2 + len(between_events)
            k += 1
        elif token_a.lower() in PHONEME2GRAPHEME[arpabet]:
            # one letter spells one phoneme
            new_events.append([word_arpabet[k], freq_a, start_time_a, end_time_a])
            i += 1
            k += 1
        elif arpabet_compound_2 in PHONEME2GRAPHEME and token_a.lower() in PHONEME2GRAPHEME[arpabet_compound_2]:
            # one letter spells two phonemes (e.g. 'x' -> K S)
            new_events.append([word_arpabet[k], freq_a, start_time_a, end_time_a])
            new_events.append([word_arpabet[k+1], freq_a, start_time_a, end_time_a])
            i += 1
            k += 2
        else:
            # no match: skip this phoneme
            k += 1

    # add extensions and pauses at end of words
    while i < len(event):
        token_a, freq_a, start_time_a, end_time_a = event[i]
        if token_a in (' ', '_'):
            new_events.append([token_a, freq_a, start_time_a, end_time_a])
        i += 1

    return new_events
  327. def event2alignment(events, hop_length=256, sampling_rate=22050):
  328. frame_length = float(hop_length) / float(sampling_rate)
  329. n_frames = int(events[-1][-1][-1] / frame_length)
  330. n_tokens = np.sum([len(e) for e in events])
  331. alignment = np.zeros((n_tokens, n_frames))
  332. cur_event = -1
  333. for event in events:
  334. for i in range(len(event)):
  335. if len(event) == 1 or cur_event == -1 or event[i][0] != event[i-1][0]:
  336. cur_event += 1
  337. token, freq, start_time, end_time = event[i]
  338. alignment[cur_event, int(start_time/frame_length):int(end_time/frame_length)] = 1
  339. return alignment[:cur_event+1]
  340. def event2f0(events, hop_length=256, sampling_rate=22050):
  341. frame_length = float(hop_length) / float(sampling_rate)
  342. n_frames = int(events[-1][-1][-1] / frame_length)
  343. f0s = np.zeros((1, n_frames))
  344. for event in events:
  345. for i in range(len(event)):
  346. token, freq, start_time, end_time = event[i]
  347. f0s[0, int(start_time/frame_length):int(end_time/frame_length)] = freq
  348. return f0s
def event2text(events, convert_stress, cmudict=None):
    """Render grouped ARPAbet events as curly-brace text and encode it.

    Args:
        events: grouped [token, freq, start, end] events.
        convert_stress: if True, collapse all ARPAbet stress digits to 1.
        cmudict: optional dictionary forwarded to text_to_sequence. NOTE:
            this parameter shadows the module-level `cmudict` import.

    Returns:
        (text_encoded, text_clean): the id sequence and the cleaned string.
    """
    text_clean = ''
    for event in events:
        for i in range(len(event)):
            # skip repeated tokens (duration extensions share one symbol)
            if i > 0 and event[i][0] == event[i-1][0]:
                continue
            if event[i][0] == ' ' and len(event) > 1:
                # word-internal pause: close the brace group, open a new one
                if text_clean[-1] != "}":
                    text_clean = text_clean[:-1] + '} {'
                else:
                    text_clean += ' {'
            else:
                if event[i][0][-1] in ('}', ' '):
                    text_clean += event[i][0]
                else:
                    text_clean += event[i][0] + ' '

    if convert_stress:
        text_clean = re.sub('[0-9]', '1', text_clean)

    text_encoded = text_to_sequence(text_clean, [], cmudict)
    return text_encoded, text_clean
  369. def remove_excess_frames(alignment, f0s):
  370. excess_frames = np.sum(alignment.sum(0) == 0)
  371. alignment = alignment[:, :-excess_frames] if excess_frames > 0 else alignment
  372. f0s = f0s[:, :-excess_frames] if excess_frames > 0 else f0s
  373. return alignment, f0s
def get_data_from_musicxml(filepath, bpm, phoneme_durations=None,
                           convert_stress=False):
    """Convert a MusicXML score into Mellotron model inputs, per track.

    Args:
        filepath: path to the MusicXML file.
        bpm: tempo in beats per minute.
        phoneme_durations: consonant phoneme -> duration (seconds);
            defaults to PHONEMEDURATION.
        convert_stress: if True, collapse ARPAbet stress digits to 1.

    Returns:
        dict mapping part name to {'rhythm': alignment tensor,
        'pitch_contour': f0 tensor, 'text_encoded': token-id LongTensor}.
    """
    if phoneme_durations is None:
        phoneme_durations = PHONEMEDURATION
    score = musicxml2score(filepath, bpm)
    data = {}
    for k, v in score.items():
        # ignore empty tracks
        if len(v) == 1 and v[0][0] == ' ':
            continue

        events = track2events(v)
        events = adjust_words(events)
        events_arpabet = [events2eventsarpabet(e) for e in events]

        # make adjustments
        events_arpabet = [adjust_extensions(e, phoneme_durations)
                          for e in events_arpabet]
        events_arpabet = [adjust_consonants(e, phoneme_durations)
                          for e in events_arpabet]
        events_arpabet = add_space_between_events(events_arpabet)

        # convert data to alignment, f0 and text encoded
        alignment = event2alignment(events_arpabet)
        f0s = event2f0(events_arpabet)
        alignment, f0s = remove_excess_frames(alignment, f0s)
        text_encoded, text_clean = event2text(events_arpabet, convert_stress)

        # convert data to torch: alignment (frames, 1, tokens),
        # f0 (1, 1, frames), text ids (1, tokens)
        alignment = torch.from_numpy(alignment).permute(1, 0)[:, None].float()
        f0s = torch.from_numpy(f0s)[None].float()
        text_encoded = torch.LongTensor(text_encoded)[None]

        data[k] = {'rhythm': alignment,
                   'pitch_contour': f0s,
                   'text_encoded': text_encoded}

    return data
  406. if __name__ == "__main__":
  407. import argparse
  408. # Get defaults so it can work with no Sacred
  409. parser = argparse.ArgumentParser()
  410. parser.add_argument('-f', "--filepath", required=True)
  411. args = parser.parse_args()
  412. get_data_from_musicxml(args.filepath, 60)