inference.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  16. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  17. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  18. # ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import os
  28. from scipy.io.wavfile import write
  29. import torch
  30. from mel2samp import files_to_list, MAX_WAV_VALUE
  31. from denoiser import Denoiser
  32. def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
  33. denoiser_strength):
  34. mel_files = files_to_list(mel_files)
  35. waveglow = torch.load(waveglow_path)['model']
  36. waveglow = waveglow.remove_weightnorm(waveglow)
  37. waveglow.cuda().eval()
  38. if is_fp16:
  39. from apex import amp
  40. waveglow, _ = amp.initialize(waveglow, [], opt_level="O3")
  41. if denoiser_strength > 0:
  42. denoiser = Denoiser(waveglow).cuda()
  43. for i, file_path in enumerate(mel_files):
  44. file_name = os.path.splitext(os.path.basename(file_path))[0]
  45. mel = torch.load(file_path)
  46. mel = torch.autograd.Variable(mel.cuda())
  47. mel = torch.unsqueeze(mel, 0)
  48. mel = mel.half() if is_fp16 else mel
  49. with torch.no_grad():
  50. audio = waveglow.infer(mel, sigma=sigma)
  51. if denoiser_strength > 0:
  52. audio = denoiser(audio, denoiser_strength)
  53. audio = audio * MAX_WAV_VALUE
  54. audio = audio.squeeze()
  55. audio = audio.cpu().numpy()
  56. audio = audio.astype('int16')
  57. audio_path = os.path.join(
  58. output_dir, "{}_synthesis.wav".format(file_name))
  59. write(audio_path, sampling_rate, audio)
  60. print(audio_path)
  61. if __name__ == "__main__":
  62. import argparse
  63. parser = argparse.ArgumentParser()
  64. parser.add_argument('-f', "--filelist_path", required=True)
  65. parser.add_argument('-w', '--waveglow_path',
  66. help='Path to waveglow decoder checkpoint with model')
  67. parser.add_argument('-o', "--output_dir", required=True)
  68. parser.add_argument("-s", "--sigma", default=1.0, type=float)
  69. parser.add_argument("--sampling_rate", default=22050, type=int)
  70. parser.add_argument("--is_fp16", action="store_true")
  71. parser.add_argument("-d", "--denoiser_strength", default=0.0, type=float,
  72. help='Removes model bias. Start with 0.1 and adjust')
  73. args = parser.parse_args()
  74. main(args.filelist_path, args.waveglow_path, args.sigma, args.output_dir,
  75. args.sampling_rate, args.is_fp16, args.denoiser_strength)