Browse Source

Move flash whisper

Lengyue 2 years ago
parent
commit
f4e8633363

+ 1 - 1
speech_lm/datasets/whisper_vq.py

@@ -109,8 +109,8 @@ if __name__ == "__main__":
     from torch.utils.data import DataLoader
     from transformers import GenerationConfig
 
-    from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
     from speech_lm.models.whisper_vq import WhisperVQ
+    from speech_lm.modules.flash_whisper import FlashWhisperForConditionalGeneration
 
     dataset = WhisperVQDataset("filelists/whisper-vq.test.filelist")
     dataloader = DataLoader(

+ 1 - 1
speech_lm/models/whisper_vq.py

@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from vector_quantize_pytorch import VectorQuantize
 
-from speech_lm.models.flash_whisper import (
+from speech_lm.modules.flash_whisper import (
     FlashWhisperEncoderLayer,
     FlashWhisperForConditionalGeneration,
 )

+ 0 - 0
speech_lm/models/flash_whisper.py → speech_lm/modules/flash_whisper.py


+ 1 - 1
tools/whisper_asr.py

@@ -16,7 +16,7 @@ from loguru import logger
 from transformers import WhisperProcessor
 from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
 
-from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
+from speech_lm.modules.flash_whisper import FlashWhisperForConditionalGeneration
 
 RANK_STR = ""