train.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import argparse
  28. import json
  29. import os
  30. import torch
  31. #=====START: ADDED FOR DISTRIBUTED======
  32. from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor
  33. from torch.utils.data.distributed import DistributedSampler
  34. #=====END: ADDED FOR DISTRIBUTED======
  35. from torch.utils.data import DataLoader
  36. from glow import WaveGlow, WaveGlowLoss
  37. from mel2samp import Mel2Samp
  38. def load_checkpoint(checkpoint_path, model, optimizer):
  39. assert os.path.isfile(checkpoint_path)
  40. checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
  41. iteration = checkpoint_dict['iteration']
  42. optimizer.load_state_dict(checkpoint_dict['optimizer'])
  43. model_for_loading = checkpoint_dict['model']
  44. model.load_state_dict(model_for_loading.state_dict())
  45. print("Loaded checkpoint '{}' (iteration {})" .format(
  46. checkpoint_path, iteration))
  47. return model, optimizer, iteration
  48. def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
  49. print("Saving model and optimizer state at iteration {} to {}".format(
  50. iteration, filepath))
  51. model_for_saving = WaveGlow(**waveglow_config).cuda()
  52. model_for_saving.load_state_dict(model.state_dict())
  53. torch.save({'model': model_for_saving,
  54. 'iteration': iteration,
  55. 'optimizer': optimizer.state_dict(),
  56. 'learning_rate': learning_rate}, filepath)
  57. def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
  58. sigma, iters_per_checkpoint, batch_size, seed, fp16_run,
  59. checkpoint_path, with_tensorboard):
  60. torch.manual_seed(seed)
  61. torch.cuda.manual_seed(seed)
  62. #=====START: ADDED FOR DISTRIBUTED======
  63. if num_gpus > 1:
  64. init_distributed(rank, num_gpus, group_name, **dist_config)
  65. #=====END: ADDED FOR DISTRIBUTED======
  66. criterion = WaveGlowLoss(sigma)
  67. model = WaveGlow(**waveglow_config).cuda()
  68. #=====START: ADDED FOR DISTRIBUTED======
  69. if num_gpus > 1:
  70. model = apply_gradient_allreduce(model)
  71. #=====END: ADDED FOR DISTRIBUTED======
  72. optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  73. if fp16_run:
  74. from apex import amp
  75. model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
  76. # Load checkpoint if one exists
  77. iteration = 0
  78. if checkpoint_path != "":
  79. model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
  80. optimizer)
  81. iteration += 1 # next iteration is iteration + 1
  82. trainset = Mel2Samp(**data_config)
  83. # =====START: ADDED FOR DISTRIBUTED======
  84. train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
  85. # =====END: ADDED FOR DISTRIBUTED======
  86. train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
  87. sampler=train_sampler,
  88. batch_size=batch_size,
  89. pin_memory=False,
  90. drop_last=True)
  91. # Get shared output_directory ready
  92. if rank == 0:
  93. if not os.path.isdir(output_directory):
  94. os.makedirs(output_directory)
  95. os.chmod(output_directory, 0o775)
  96. print("output directory", output_directory)
  97. if with_tensorboard and rank == 0:
  98. from tensorboardX import SummaryWriter
  99. logger = SummaryWriter(os.path.join(output_directory, 'logs'))
  100. model.train()
  101. epoch_offset = max(0, int(iteration / len(train_loader)))
  102. # ================ MAIN TRAINNIG LOOP! ===================
  103. for epoch in range(epoch_offset, epochs):
  104. print("Epoch: {}".format(epoch))
  105. for i, batch in enumerate(train_loader):
  106. model.zero_grad()
  107. mel, audio = batch
  108. mel = torch.autograd.Variable(mel.cuda())
  109. audio = torch.autograd.Variable(audio.cuda())
  110. outputs = model((mel, audio))
  111. loss = criterion(outputs)
  112. if num_gpus > 1:
  113. reduced_loss = reduce_tensor(loss.data, num_gpus).item()
  114. else:
  115. reduced_loss = loss.item()
  116. if fp16_run:
  117. with amp.scale_loss(loss, optimizer) as scaled_loss:
  118. scaled_loss.backward()
  119. else:
  120. loss.backward()
  121. optimizer.step()
  122. print("{}:\t{:.9f}".format(iteration, reduced_loss))
  123. if with_tensorboard and rank == 0:
  124. logger.add_scalar('training_loss', reduced_loss, i + len(train_loader) * epoch)
  125. if (iteration % iters_per_checkpoint == 0):
  126. if rank == 0:
  127. checkpoint_path = "{}/waveglow_{}".format(
  128. output_directory, iteration)
  129. save_checkpoint(model, optimizer, learning_rate, iteration,
  130. checkpoint_path)
  131. iteration += 1
  132. if __name__ == "__main__":
  133. parser = argparse.ArgumentParser()
  134. parser.add_argument('-c', '--config', type=str,
  135. help='JSON file for configuration')
  136. parser.add_argument('-r', '--rank', type=int, default=0,
  137. help='rank of process for distributed')
  138. parser.add_argument('-g', '--group_name', type=str, default='',
  139. help='name of group for distributed')
  140. args = parser.parse_args()
  141. # Parse configs. Globals nicer in this case
  142. with open(args.config) as f:
  143. data = f.read()
  144. config = json.loads(data)
  145. train_config = config["train_config"]
  146. global data_config
  147. data_config = config["data_config"]
  148. global dist_config
  149. dist_config = config["dist_config"]
  150. global waveglow_config
  151. waveglow_config = config["waveglow_config"]
  152. num_gpus = torch.cuda.device_count()
  153. if num_gpus > 1:
  154. if args.group_name == '':
  155. print("WARNING: Multiple GPUs detected but no distributed group set")
  156. print("Only running 1 GPU. Use distributed.py for multiple GPUs")
  157. num_gpus = 1
  158. if num_gpus == 1 and args.rank != 0:
  159. raise Exception("Doing single GPU training on rank > 0")
  160. torch.backends.cudnn.enabled = True
  161. torch.backends.cudnn.benchmark = False
  162. train(num_gpus, args.rank, args.group_name, **train_config)