|
|
@@ -124,7 +124,7 @@ def decode_one_token(
|
|
|
|
|
|
codebooks = torch.stack(codebooks, dim=0)
|
|
|
codebooks[1] = torch.where(
|
|
|
- codebooks[0] <= 32311, codebooks[0], codebooks[0] - 32311 + 2
|
|
|
+ codebooks[0] <= 32311, codebooks[1], codebooks[0] - 32311 + 2
|
|
|
)
|
|
|
|
|
|
return codebooks
|