|
|
@@ -109,8 +109,8 @@ if __name__ == "__main__":
|
|
|
from torch.utils.data import DataLoader
|
|
|
from transformers import GenerationConfig
|
|
|
|
|
|
- from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
|
|
|
from speech_lm.models.whisper_vq import WhisperVQ
|
|
|
+ from speech_lm.modules.flash_whisper import FlashWhisperForConditionalGeneration
|
|
|
|
|
|
dataset = WhisperVQDataset("filelists/whisper-vq.test.filelist")
|
|
|
dataloader = DataLoader(
|