@@ -35,7 +35,7 @@ logger.add(sys.stderr, format=logger_format)
@lru_cache(maxsize=1)
def get_hubert_model():
- model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-base")
+ model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-large")
model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
model = model.half()
model.eval()