hparams.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import tensorflow as tf
  2. from text.symbols import symbols
  3. def create_hparams(hparams_string=None, verbose=False):
  4. """Create model hyperparameters. Parse nondefault from given string."""
  5. hparams = tf.contrib.training.HParams(
  6. ################################
  7. # Experiment Parameters #
  8. ################################
  9. epochs=50000,
  10. iters_per_checkpoint=500,
  11. seed=1234,
  12. dynamic_loss_scaling=True,
  13. fp16_run=False,
  14. distributed_run=False,
  15. dist_backend="nccl",
  16. dist_url="tcp://localhost:54321",
  17. cudnn_enabled=True,
  18. cudnn_benchmark=False,
  19. ignore_layers=['speaker_embedding.weight'],
  20. ################################
  21. # Data Parameters #
  22. ################################
  23. # training_files='filelists/ljs_audiopaths_text_sid_train_filelist.txt',
  24. # validation_files='filelists/ljs_audiopaths_text_sid_val_filelist.txt',
  25. training_files='/Users/tzld/mellotron/filelists/ljs_audiopaths_text_sid_train_filelist_new.txt',
  26. validation_files='/Users/tzld/mellotron/filelists/ljs_audiopaths_text_sid_val_filelist_new.txt',
  27. # training_files='/Users/tzld/mellotron/filelists/ljs_audiopaths_text_sid_train_filelist_new.txt',
  28. # validation_files='/Users/tzld/mellotron/filelists/ljs_audiopaths_text_sid_val_filelist_new.txt',
  29. text_cleaners=['english_cleaners'],
  30. p_arpabet=1.0,
  31. cmudict_path="data/cmu_dictionary",
  32. ################################
  33. # Audio Parameters #
  34. ################################
  35. max_wav_value=32768.0,
  36. sampling_rate=22050,
  37. filter_length=1024,
  38. hop_length=256,
  39. win_length=1024,
  40. n_mel_channels=80,
  41. mel_fmin=0.0,
  42. mel_fmax=8000.0,
  43. f0_min=80,
  44. f0_max=880,
  45. harm_thresh=0.25,
  46. ################################
  47. # Model Parameters #
  48. ################################
  49. n_symbols=len(symbols),
  50. symbols_embedding_dim=512,
  51. # Encoder parameters
  52. encoder_kernel_size=5,
  53. encoder_n_convolutions=3,
  54. encoder_embedding_dim=512,
  55. # Decoder parameters
  56. n_frames_per_step=1, # currently only 1 is supported
  57. decoder_rnn_dim=1024,
  58. prenet_dim=256,
  59. prenet_f0_n_layers=1,
  60. prenet_f0_dim=1,
  61. prenet_f0_kernel_size=1,
  62. prenet_rms_dim=0,
  63. prenet_rms_kernel_size=1,
  64. max_decoder_steps=1000,
  65. gate_threshold=0.5,
  66. p_attention_dropout=0.1,
  67. p_decoder_dropout=0.1,
  68. p_teacher_forcing=1.0,
  69. # Attention parameters
  70. attention_rnn_dim=1024,
  71. attention_dim=128,
  72. # Location Layer parameters
  73. attention_location_n_filters=32,
  74. attention_location_kernel_size=31,
  75. # Mel-post processing network parameters
  76. postnet_embedding_dim=512,
  77. postnet_kernel_size=5,
  78. postnet_n_convolutions=5,
  79. # Speaker embedding
  80. n_speakers=123,
  81. speaker_embedding_dim=128,
  82. # Reference encoder
  83. with_gst=True,
  84. ref_enc_filters=[32, 32, 64, 64, 128, 128],
  85. ref_enc_size=[3, 3],
  86. ref_enc_strides=[2, 2],
  87. ref_enc_pad=[1, 1],
  88. ref_enc_gru_size=128,
  89. # Style Token Layer
  90. token_embedding_size=256,
  91. token_num=10,
  92. num_heads=8,
  93. ################################
  94. # Optimization Hyperparameters #
  95. ################################
  96. use_saved_learning_rate=False,
  97. learning_rate=1e-3,
  98. learning_rate_min=1e-5,
  99. learning_rate_anneal=50000,
  100. weight_decay=1e-6,
  101. grad_clip_thresh=1.0,
  102. batch_size=32,
  103. mask_padding=True, # set model's padded outputs to padded values
  104. )
  105. if hparams_string:
  106. tf.compat.v1.logging.info('Parsing command line hparams: %s', hparams_string)
  107. hparams.parse(hparams_string)
  108. if verbose:
  109. tf.compat.v1.logging.info('Final parsed hparams: %s', hparams.values())
  110. return hparams