|
@@ -13,7 +13,7 @@ import torch
|
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
-from fish_speech.models.text2semantic.inference import load_model
|
|
|
|
|
|
|
+from fish_speech.models.text2semantic.inference import init_model
|
|
|
from fish_speech.models.text2semantic.llama import find_multiple
|
|
from fish_speech.models.text2semantic.llama import find_multiple
|
|
|
|
|
|
|
|
##### Quantization Primitives ######
|
|
##### Quantization Primitives ######
|
|
@@ -445,7 +445,7 @@ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -
|
|
|
print("Loading model ...")
|
|
print("Loading model ...")
|
|
|
t0 = time.time()
|
|
t0 = time.time()
|
|
|
|
|
|
|
|
- model, _ = load_model(
|
|
|
|
|
|
|
+ model, _ = init_model(
|
|
|
checkpoint_path=checkpoint_path,
|
|
checkpoint_path=checkpoint_path,
|
|
|
device=device,
|
|
device=device,
|
|
|
precision=precision,
|
|
precision=precision,
|